From 5b06d5b9d05d8d93bad024f0faff9e4c8acaf57a Mon Sep 17 00:00:00 2001 From: Faraz Shahsavan Date: Sat, 7 Dec 2024 18:55:37 +0000 Subject: [PATCH] Add comprehensive heuristics to c3x kernels for h100 dense fp8 --- CMakeLists.txt | 6 +- .../cutlass_benchmarks/dense_mm/bench_v1.py | 184 ++++ .../cutlass_benchmarks/dense_mm/bench_v2.py | 608 ++++++++++++ .../dense_mm/mm_benchmarks.py | 224 +++++ .../dense_mm/stable_kernels_fp8.json | 1 + .../cutlass_benchmarks/dense_mm/utils.py | 55 ++ .../dense_mm/weight_shapes.py | 75 ++ .../cutlass_benchmarks/w8a8_benchmarks.py | 389 -------- .../cutlass_benchmarks/weight_shapes.py | 43 - .../epilogue/scaled_mm_epilogues_c3x.hpp | 4 +- .../broadcast_load_epilogue_c2x.hpp | 496 ++++++++++ .../broadcast_load_epilogue_c3x.hpp | 447 +++++++++ .../cutlass_w8a8/generator/README.md | 143 +++ .../generator/autogen_manifest.py | 167 ++++ .../cutlass_w8a8/generator/generator.py | 155 +++ .../cutlass_w8a8/generator/generator_types.py | 77 ++ .../cutlass_w8a8/generator/kernel_compiler.py | 131 +++ .../generator/kernel_generator.py | 251 +++++ .../generator/scaled_mm_c3x.jinja | 56 ++ .../generator/scaled_mm_c3x_fnprototype.jinja | 6 + .../scaled_mm_c3x_struct_prototype.jinja | 6 + .../cutlass_w8a8/generator/tools/heatmap.py | 242 +++++ .../generator/tools/select_kernels.py | 324 +++++++ .../generator/tools/test_kernel.py | 119 +++ .../generator/tools/test_utils.py | 114 +++ .../cutlass_w8a8/generator/tools/utils.py | 75 ++ .../cutlass_w8a8/generator/utils.py | 56 ++ .../cutlass_w8a8/scaled_mm_c2x.cu | 53 +- .../cutlass_w8a8/scaled_mm_c2x.cuh | 302 ++++++ .../cutlass_w8a8/scaled_mm_c3x.cu | 918 ++++++++++++------ .../cutlass_w8a8/scaled_mm_c3x.cuh | 779 +++++++++++++++ .../cutlass_w8a8/scaled_mm_c3x_configs.cuh | 223 +++++ .../cutlass_w8a8/scaled_mm_entry.cu | 8 +- 33 files changed, 5976 insertions(+), 761 deletions(-) create mode 100644 benchmarks/cutlass_benchmarks/dense_mm/bench_v1.py create mode 100644 benchmarks/cutlass_benchmarks/dense_mm/bench_v2.py create mode 100644 benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py create mode 100644 benchmarks/cutlass_benchmarks/dense_mm/stable_kernels_fp8.json create mode 100644 benchmarks/cutlass_benchmarks/dense_mm/utils.py create mode 100644 benchmarks/cutlass_benchmarks/dense_mm/weight_shapes.py delete mode 100644 benchmarks/cutlass_benchmarks/w8a8_benchmarks.py delete mode 100644 benchmarks/cutlass_benchmarks/weight_shapes.py create mode 100644 csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp create mode 100644 csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp create mode 100644 csrc/quantization/cutlass_w8a8/generator/README.md create mode 100644 csrc/quantization/cutlass_w8a8/generator/autogen_manifest.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/generator.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/generator_types.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/kernel_compiler.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/kernel_generator.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x.jinja create mode 100644 csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_fnprototype.jinja create mode 100644 csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_struct_prototype.jinja create mode 100644 csrc/quantization/cutlass_w8a8/generator/tools/heatmap.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/tools/select_kernels.py create mode 100644 
csrc/quantization/cutlass_w8a8/generator/tools/test_kernel.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/tools/test_utils.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/tools/utils.py create mode 100644 csrc/quantization/cutlass_w8a8/generator/utils.py create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh create mode 100644 csrc/quantization/cutlass_w8a8/scaled_mm_c3x_configs.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77a7e42..191c8c9923b33 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -205,7 +205,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -222,13 +222,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + # GIT_SHALLOW TRUE ) endif() FetchContent_MakeAvailable(cutlass) diff --git a/benchmarks/cutlass_benchmarks/dense_mm/bench_v1.py b/benchmarks/cutlass_benchmarks/dense_mm/bench_v1.py new file mode 100644 index 0000000000000..4373036c4518b --- /dev/null +++ b/benchmarks/cutlass_benchmarks/dense_mm/bench_v1.py @@ -0,0 +1,184 @@ +## Cutlass benchmark V1 + +from typing import Callable, Iterable + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors, to_fp16, to_bf16 + +import vllm._custom_ops as ops + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + + # Create tensors + a, b = make_rand_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect result") + exit() + + timers = [] + + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", 
torch.mm,
+                 a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
+
+    # cutlass impl: bf16 output
+    timers.append(
+        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
+                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+                 torch.bfloat16))
+
+    # cutlass with bias: bf16 output
+    timers.append(
+        bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
+                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
+                 bias))
+
+    # cutlass impl: fp16 output
+    timers.append(
+        bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_mm",
+                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+                 torch.float16))
+
+    # cutlass with bias: fp16 output
+    timers.append(
+        bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_mm_bias",
+                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
+                 bias.to(dtype=torch.float16)))
+
+    return timers
+
+
+def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+              sub_label: str) -> Iterable[TMeasurement]:
+    assert dtype == torch.float8_e4m3fn
+
+    # Create tensors
+    a, b = make_rand_tensors(dtype, m, n, k)
+    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+    bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+
+    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+    out_ref = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
+                               out_dtype=torch.bfloat16)
+
+    if not torch.allclose(out, out_ref):
+        print("Incorrect result")
+        exit()
+
+    timers = []
+
+    # pytorch impl w. bf16
+    timers.append(
+        bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
+                 torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
+                 b.to(dtype=torch.bfloat16, device="cuda")))
+
+    # pytorch impl: bf16 output, without fp8 fast accum
+    timers.append(
+        bench_fn(label,
+                 sub_label,
+                 "pytorch_fp8_fp8_bf16_scaled_mm",
+                 torch._scaled_mm,
+                 a,
+                 b,
+                 scale_a=scale_a,
+                 scale_b=scale_b,
+                 out_dtype=torch.bfloat16))
+
+    # pytorch impl: bf16 output, with fp8 fast accum
+    timers.append(
+        bench_fn(label,
+                 sub_label,
+                 "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
+                 torch._scaled_mm,
+                 a,
+                 b,
+                 scale_a=scale_a,
+                 scale_b=scale_b,
+                 out_dtype=torch.bfloat16,
+                 use_fast_accum=True))
+
+    # pytorch impl: fp16 output, without fp8 fast accum
+    timers.append(
+        bench_fn(label,
+                 sub_label,
+                 "pytorch_fp8_fp8_fp16_scaled_mm",
+                 torch._scaled_mm,
+                 a,
+                 b,
+                 scale_a=scale_a,
+                 scale_b=scale_b,
+                 out_dtype=torch.float16))
+
+    # pytorch impl: fp16 output, with fp8 fast accum
+    timers.append(
+        bench_fn(label,
+                 sub_label,
+                 "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
+                 torch._scaled_mm,
+                 a,
+                 b,
+                 scale_a=scale_a,
+                 scale_b=scale_b,
+                 out_dtype=torch.float16,
+                 use_fast_accum=True))
+
+    # cutlass impl: bf16 output
+    timers.append(
+        bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
+                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+                 torch.bfloat16))
+    # cutlass impl: fp16 output
+    timers.append(
+        bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
+                 ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
+
+    return timers
+
+
+def bench_v1(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+             sub_label: str) -> Iterable[TMeasurement]:
+    if dtype == torch.int8:
+        return bench_int8(dtype, m, k, n, label, sub_label)
+    if dtype == torch.float8_e4m3fn:
+        return bench_fp8(dtype, m, k, n, label, sub_label)
+    raise ValueError("unsupported type")
diff --git a/benchmarks/cutlass_benchmarks/dense_mm/bench_v2.py
b/benchmarks/cutlass_benchmarks/dense_mm/bench_v2.py new file mode 100644 index 0000000000000..ab29bf26d40f4 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/dense_mm/bench_v2.py @@ -0,0 +1,608 @@ +import dataclasses +import random +from typing import Any, Callable, Iterable, Optional, Tuple, Dict, List + +import multiprocessing as mp +from multiprocessing import Process, Queue +from queue import Empty + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_n_rand_tensors, to_fp8 + +import vllm._custom_ops as ops +import traceback + +import json +import os +import hashlib +from datetime import datetime +from pathlib import Path + + +@dataclasses.dataclass +class CudaGraphBenchParams: + num_ops_in_cuda_graph: int + + +@dataclasses.dataclass +class ArgPool: + ''' + When some argument of the benchmarking function is annotated with this type, + the benchmarking class (BenchMM) will collapse the argument to a pick a + single value from the given list of values, during function invocation. + + For every invocation during a benchmarking run, it will choose a + different value from the list. + ''' + values: Iterable[Any] + + +class BenchMM: + + class ArgsIterator: + + def __init__(self, args_list, kwargs_list): + assert len(args_list) == len(kwargs_list) + self.args_list = args_list + self.kwargs_list = kwargs_list + self.n = len(self.args_list) + self.idx = 0 + + def __next__(self): + while True: + yield (self.args_list[self.idx], self.kwargs_list[self.idx]) + self.idx += 1 + self.idx = self.idx % self.n + + def reset(self): + self.idx = 0 + + @property + def n_args(self): + return self.n + + def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, sub_label: str, description: str, fn: Callable, + *args, **kwargs): + + self.cuda_graph_params = cuda_graph_params + self.use_cuda_graph = self.cuda_graph_params is not None + self.label = label + self.sub_label = sub_label + self.description = description + self.fn = fn + + # Process args + self._args = args + self._kwargs = kwargs + self.args_list, self.kwargs_list = self.collapse_argpool( + *args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, + self.kwargs_list) + + # Cudagraph runner + self.g = None + if self.use_cuda_graph: + self.g = self.get_cuda_graph_runner() + + # benchmark run params + self.min_run_time = 1 + + def collapse_argpool(self, *args, **kwargs): + kwargs = kwargs if kwargs is not None else {} + assert kwargs is None or all([ + not isinstance(v, ArgPool) for k, v in kwargs.items() + ]), 'ArgPools in kwargs are not supported yet' + + arg_pool_indices = [ + i for i, x in enumerate(args) if isinstance(x, ArgPool) + ] + if len(arg_pool_indices) == 0: + return [args], [kwargs] + + # make sure all the Arg pools have the same number of choices + arg_pool_size = len(args[arg_pool_indices[0]].values) + assert all( + [len(args[i].values) == arg_pool_size for i in arg_pool_indices]) + + # create copies of the args + args_list = [] + kwargs_list = [] + for _ in range(arg_pool_size): + args_list.append(args) + kwargs_list.append(kwargs.copy()) + + # collapse the arg pools by simply choosing the ith value + for i in range(arg_pool_size): + assert isinstance(args_list[i], tuple) + # get as list + args_i = list(args_list[i]) + # collapse - make replacements + for arg_pool_idx in arg_pool_indices: + val_from_pool = args_i[arg_pool_idx].values[i] + args_i[arg_pool_idx] = val_from_pool + # store back as tuple + 
args_list[i] = tuple(args_i) + + return args_list, kwargs_list + + def get_cuda_graph_runner(self): + assert self.use_cuda_graph + assert self.args_iterator is not None + + num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph + + # warmup + args_it = self.args_iterator.__next__() + for _ in range(5): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + for _ in range(num_graph_ops): + args, kwargs = next(args_it) + self.fn(*args, **kwargs) + return g + + def run_cudagrah(self) -> TMeasurement: + assert self.use_cuda_graph + globals = {'g': self.g} + + return TBenchmark.Timer( + stmt="g.replay()", + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run_eager(self) -> TMeasurement: + setup = None + stmt = None + globals = None + + has_arg_pool = self.args_iterator.n_args > 1 + if has_arg_pool: + setup = ''' + args_iterator.reset() + args_it = args_iterator.__next__() + ''' + stmt = ''' + args, kwargs = next(args_it) + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + else: + # no arg pool. Just use the args and kwargs directly + self.args_iterator.reset() + args_it = self.args_iterator.__next__() + args, kwargs = next(args_it) + + setup = "" + stmt = ''' + fn(*args, **kwargs) + ''' + globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + + return TBenchmark.Timer( + stmt=stmt, + setup=setup, + globals=globals, + label=self.label, + sub_label=self.sub_label, + description=self.description, + ).blocked_autorange(min_run_time=self.min_run_time) + + def run(self) -> TMeasurement: + timer = None + if self.use_cuda_graph: # noqa SIM108 + timer = self.run_cudagrah() + else: + timer = self.run_eager() + #assert timer.meets_confidence() + #assert not timer.has_warnings, f"Warnings {timer._warnings}" + if not timer.meets_confidence() or timer.has_warnings: + print("Doesn't meet confidence - re-running bench ...") + return self.run() + return timer + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type: + print(f"exc type {exc_type}") + print(f"exc value {exc_value}") + print(f"exc traceback {traceback}") + + +def get_autogen_functions(dtype_str): + import importlib + from importlib.util import find_spec + import re + import sys + + # # Find all loaded modules matching the pattern + # pattern = re.compile(r'vllm\._nm_cutlass_\d+_C') + # matching_modules = [name for name in sys.modules.keys() if pattern.match(name)] + + # # Import or reimport all matching modules + # for module_name in matching_modules: + # try: + # importlib.import_module(module_name) + # except ImportError as e: + # print(f"Warning: Could not import {module_name}: {e}") + + # import vllm nm_cutlass modules so torch._C can find it + m_idx = 0 + m_name = f'vllm._nm_cutlass_{dtype_str}_{m_idx}_C' + while find_spec(m_name): + # print(f"attempting import {m_name}") + importlib.import_module(m_name) + m_idx += 1 + m_name = f'vllm._nm_cutlass_{dtype_str}_{m_idx}_C' + + dispatch_names = torch._C._dispatch_get_all_op_names() + autogen_dispatch_names = [x for x in dispatch_names if 'autogen' in x] + assert all([x.startswith('_nm_cutlass') for x in autogen_dispatch_names]) + autogen_dispatch_modules_names = [(getattr(torch.ops, + 
x.split('::')[0]), + x.split('::')[1]) + for x in autogen_dispatch_names] + name_fn = [(name, getattr(m, name)) + for m, name in autogen_dispatch_modules_names] + # print(f"#autogen functions found {len(name_fn)}") + return name_fn + + +def run_single_benchmark_process(kernel_config: Dict, gpu_id: int, queue: Queue, dtype_str: str): + """ + Run a single kernel benchmark in an isolated process. + Puts (success, result, config) tuple in the queue. + """ + try: + # torch.cuda.set_device(gpu_id) + + # Initialize CUDA tensors + m, k, n = kernel_config['m'], kernel_config['k'], kernel_config['n'] + dtype = kernel_config['dtype'] + + # Create tensors + As, Bs = make_n_rand_tensors( + kernel_config.get('arg_pool_size', 1), + dtype, m, n, k + ) + bf16_As = [x.to(dtype=torch.bfloat16) for x in As] + bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs] + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + # Because the transposed output will be computed + out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda") + + # Setup benchmark params + cuda_graph_params = None + if cgops := kernel_config.get('cuda_graph_ops'): + cuda_graph_params = CudaGraphBenchParams(cgops) + + label = kernel_config['label'] + sub_label = kernel_config['sub_label'] + + # Initialize benchmark based on kernel type + bench = None + kernel_type = kernel_config['kernel_type'] + + if kernel_type == 'pytorch_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + ArgPool(bf16_As), ArgPool(bf16_Bs)) + + elif kernel_type == 'pytorch_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + ArgPool(As), ArgPool(Bs), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16) + + elif kernel_type == 'pytorch_scaled_mm_fast': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + ArgPool(As), ArgPool(Bs), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True) + + elif kernel_type == 'cutlass_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm_default", + ops.cutlass_scaled_mm, + ArgPool(As), ArgPool(Bs), scale_a, scale_b, + torch.bfloat16) + + elif kernel_type == 'autogen_kernel': + # Get the autogen kernel + kernel_num = kernel_config['kernel_num'] + kernel_name, autogen_fn = get_autogen_functions(dtype_str)[kernel_num] + + # Create appropriate benchmark based on kernel type + bench = BenchMM(cuda_graph_params, label, sub_label, + kernel_name, autogen_fn, out, + ArgPool(As), ArgPool(Bs), scale_a, scale_b) + + # Run the benchmark + result = bench.run() + queue.put((True, result, kernel_config)) + + except Exception as e: + print(f"Error in benchmark process: {str(e)}") + print(traceback.format_exc()) + queue.put((False, None, kernel_config)) + finally: + # Explicit cleanup + torch.cuda.empty_cache() + +def benchmark_gpu_worker(gpu_id: int, task_queue: Queue, result_queue: Queue, dtype_str: str): + """Worker process that spawns individual benchmark processes for each kernel.""" + try: + while True: + try: + kernel_config = task_queue.get_nowait() + if kernel_config is None: # Poison pill + break + + # Create a new process queue for this specific benchmark + process_queue = Queue() + + # Create and start a new process for this kernel benchmark + p = 
Process(target=run_single_benchmark_process, + args=(kernel_config, gpu_id, process_queue, dtype_str)) + p.start() + + # Wait for result with timeout (5 minutes for benchmarking) + try: + success, result, config = process_queue.get(timeout=300) + result_queue.put((success, result, config)) + except Empty: + print(f"Kernel {kernel_config.get('kernel_type')} benchmark timed out") + result_queue.put((False, None, kernel_config)) + + # Cleanup + p.join(timeout=1) # Give it 1 second to join + if p.is_alive(): + p.terminate() + p.join() + + except Empty: + break + except Exception as e: + print(f"Error in GPU {gpu_id} worker: {str(e)}") + print(traceback.format_exc()) + if 'kernel_config' in locals(): + result_queue.put((False, None, kernel_config)) + + finally: + print(f"GPU {gpu_id} worker finished") + +def run_kernels_on_gpus(configs: List[Dict], dtype_str: str) \ + -> List[Tuple[bool, Optional[TMeasurement], Dict]]: + MULTI_GPU_MULTI_PROCESS = False # Set to False for single GPU testing + if MULTI_GPU_MULTI_PROCESS: + gpus_list = [5] + task_queue = Queue() + result_queue = Queue() + + configs = configs[:10] + + # Fill task queue + for config in configs: + task_queue.put(config) + for _ in gpus_list: # Add poison pills + task_queue.put(None) + + # Start GPU workers + workers = [] + for gpu_id in gpus_list: + p = Process(target=benchmark_gpu_worker, args=(gpu_id, task_queue, result_queue, dtype_str)) + p.start() + workers.append(p) + + # Collect results + results = [] + completed = 0 + total_tasks = len(configs) + + while completed < total_tasks: + success, result, config = result_queue.get() + results.append((success, result, config)) + completed += 1 + + # Print progress + if config['kernel_type'] == 'autogen_kernel': + kernel_num = config['kernel_num'] + kernel_name = get_autogen_functions(dtype_str)[kernel_num][0] + status = "Success" if success else "Failed" + print(f"{status}: autogen {kernel_num} {kernel_name}") + else: + status = "Success" if success else "Failed" + print(f"{status}: {config['kernel_type']}") + + # Cleanup workers + for worker in workers: + worker.join(timeout=1) + if worker.is_alive(): + worker.terminate() + worker.join() + + return results + else: + """Run kernel benchmarks in a single process.""" + results = [] + # configs = configs[:10] # Keep the original slice + + for config in configs: + try: + # Initialize CUDA tensors + m, k, n = config['m'], config['k'], config['n'] + dtype = config['dtype'] + + As, Bs = make_n_rand_tensors( + config.get('arg_pool_size', 1), + dtype, m, n, k + ) + bf16_As = [x.to(dtype=torch.bfloat16) for x in As] + bf16_Bs = [x.to(dtype=torch.bfloat16) for x in Bs] + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + # Because the transposed output will be computed + out = torch.empty((m, n), dtype=torch.bfloat16, device="cuda") + + # Setup benchmark params + cuda_graph_params = None + if cgops := config.get('cuda_graph_ops'): + cuda_graph_params = CudaGraphBenchParams(cgops) + + label = config['label'] + sub_label = config['sub_label'] + + # Initialize benchmark based on kernel type + bench = None + kernel_type = config['kernel_type'] + + if kernel_type == 'pytorch_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + ArgPool(bf16_As), ArgPool(bf16_Bs)) + + elif kernel_type == 'pytorch_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + 
torch._scaled_mm, + ArgPool(As), ArgPool(Bs), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16) + + elif kernel_type == 'pytorch_scaled_mm_fast': + bench = BenchMM(cuda_graph_params, label, sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + ArgPool(As), ArgPool(Bs), + scale_a=scale_a, scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True) + + elif kernel_type == 'cutlass_scaled_mm': + bench = BenchMM(cuda_graph_params, label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm_default", + ops.cutlass_scaled_mm, + ArgPool(As), ArgPool(Bs), scale_a, scale_b, + torch.bfloat16) + + elif kernel_type == 'autogen_kernel': + # Get the autogen kernel + kernel_num = config['kernel_num'] + kernel_name, autogen_fn = get_autogen_functions(dtype_str)[kernel_num] + + # Create appropriate benchmark based on kernel type + bench = BenchMM(cuda_graph_params, label, sub_label, + kernel_name, autogen_fn, out, + ArgPool(As), ArgPool(Bs), + scale_b, scale_a) + + # Run the benchmark + result = bench.run() + + # Print progress + if kernel_type == 'autogen_kernel': + kernel_num = config['kernel_num'] + kernel_name = get_autogen_functions(dtype_str)[kernel_num][0] + print(f"Success: autogen {kernel_num} {kernel_name}") + else: + print(f"Success: {kernel_type}") + + results.append((True, result, config)) + + # Cleanup + torch.cuda.empty_cache() + + except Exception as e: + print(f"Error in benchmark: {str(e)}") + print(traceback.format_exc()) + results.append((False, None, config)) + torch.cuda.empty_cache() + + return results + + +def bench_dtype(dtype: torch.dtype, with_cuda_graph: Optional[int], + with_arg_pool: Optional[int], m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + + # Check if context is not set + try: + mp.set_start_method('spawn', force=True) + except RuntimeError: + pass + + timers = [] + gpus_list = [0] # Using the same GPU list as original code + + # Base configuration for all kernels + base_config = { + 'm': m, + 'k': k, + 'n': n, + 'dtype': dtype, + 'cuda_graph_ops': with_cuda_graph, + 'arg_pool_size': with_arg_pool if with_arg_pool else 1, + 'label': label, + 'sub_label': sub_label + } + + # Prepare configs for all kernels + standard_kernels = [ + {'kernel_type': 'pytorch_mm'}, + # {'kernel_type': 'pytorch_scaled_mm'}, + {'kernel_type': 'pytorch_scaled_mm_fast'}, + {'kernel_type': 'cutlass_scaled_mm'} + ] + + # Create configs for standard kernels + standard_configs = [{**base_config, **kernel} for kernel in standard_kernels] + + # Get stable kernels (from cache or by testing) + dtype_str = 'fp8' if dtype == torch.float8_e4m3fn else \ + 'bf16' if dtype == torch.bfloat16 else \ + 'fp16' if dtype == torch.float16 else 'int8' + + # Combine all configs + all_configs = standard_configs + + # Run all kernels distributed across GPUs + print(f"Running {len(all_configs)} benchmarks across {len(gpus_list)} GPUs...") + results = run_kernels_on_gpus(all_configs, dtype_str) + + # Process results + for success, result, _ in results: + if success and result is not None: + timers.append(result) + + return timers + + +def bench_v2(dtype: torch.dtype, with_cuda_graph: Optional[int], + with_arg_pool: Optional[int], m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.float8_e4m3fn or dtype == torch.int8: + return bench_dtype(dtype, with_cuda_graph, with_arg_pool, m, k, n, label, sub_label) + raise ValueError("unsupported type") diff --git 
a/benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py b/benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py new file mode 100644 index 0000000000000..16cb68aa640de --- /dev/null +++ b/benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py @@ -0,0 +1,224 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from bench_v1 import bench_v1 +from bench_v2 import bench_v2 +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(args, MKNs: Iterable[Tuple[int, int, int]], tp_size=None) -> Iterable[TMeasurement]: + results = [] + dtype = args.dtype + + use_bench_v2 = args.with_cuda_graph or args.with_arg_pool + for m, k, n in MKNs: + if use_bench_v2: + label = f"scaled-{dtype}-gemm" + label = f"{label}-cugraph_{args.with_cuda_graph}" \ + if args.with_cuda_graph else label + label = f"{label}-argpool_{args.with_arg_pool}" \ + if args.with_arg_pool else label + timers = bench_v2(args.dtype, args.with_cuda_graph, + args.with_arg_pool, m, k, n, label, + f"MKN=({m}x{k}x{n})") + else: + timers = bench_v1(args.dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + + print_timers(timers) + results.extend(timers) + + if tp_size is not None: + with open(f"chunk_bench-{m}_{k}_{n}-{args.dtype}_{tp_size}.pkl", "wb") as f: + pkl.dump(timers, f) + else: + with open(f"chunk_bench-{m}_{k}_{n}-{args.dtype}.pkl", "wb") as f: + pkl.dump(timers, f) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + # dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + dim_sizes = [1, 16, 32, 64, 128, 256, 512, 1024] + + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + if tp_split_dim is not None: + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, 
args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args, MKNs, tp_size) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + if dt == "fp16": + return torch.float16 + if dt == "bf16": + return torch.bfloat16 + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. + + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/dense_mm/mm_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8', 'fp16', 'bf16']") + parser.add_argument( + '--with-cuda-graph', + type=int, + default=None, + help="Number of ops/matmuls in a cudagraph execution. When set" + "cuda-graphs is enabled") + parser.add_argument( + '--with-arg-pool', + type=int, + default=None, + help="Number of A and B tensors to use as arg-pool. 
When not set," + "it defaults to 1") + + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/dense_mm/stable_kernels_fp8.json b/benchmarks/cutlass_benchmarks/dense_mm/stable_kernels_fp8.json new file mode 100644 index 0000000000000..eb69d25a96745 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/dense_mm/stable_kernels_fp8.json @@ -0,0 +1 @@ +{"date": "2024-12-05T05:41:42.719869", "stable_kernels": [0, 1, 2, 4, 6, 7, 8, 9, 10, 13, 14, 16, 17, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 31, 33, 34, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 53, 55, 57, 58, 59, 60, 62, 64, 65, 66, 68, 71, 72, 76, 77, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 101, 102, 103, 104, 105, 109, 110, 112, 114, 115, 117, 118, 119, 120, 121, 123, 124, 125, 127, 129, 131, 132, 133, 134, 135, 136, 137, 138, 140, 141, 142, 143, 144, 147, 149, 151, 152, 154, 155, 156, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 169, 170, 171, 173, 174, 176, 177, 182, 183, 184, 186, 187, 189, 191, 192, 193, 194, 197, 198, 199, 200, 201, 202, 203, 204, 205, 207, 209, 210, 211, 212, 213, 214, 215, 217, 218, 219, 221, 222, 223, 225, 226, 227, 228, 229, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 246, 247, 248, 249, 251, 252, 253, 254, 255, 256, 258, 259, 260, 262, 263, 264, 266, 267, 269, 270, 272, 273, 275, 278, 279, 280, 281, 284, 285, 287, 288, 290, 292, 293, 294, 295, 296, 297, 298, 299, 302, 304, 305, 310, 311, 312, 313, 314, 316, 317, 318, 319, 320, 321, 322, 324, 326, 328, 330, 331, 333, 334, 335, 336, 337, 339, 341, 342, 343, 344, 345, 346, 347, 349, 350, 352, 353, 354, 356, 357, 358, 359, 360, 362, 363, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 377, 378, 379, 380, 381, 382, 383, 385, 387, 388, 390, 391, 393, 394, 396, 397, 399, 400, 401, 402, 403, 404, 405, 406, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 434, 436, 437, 438, 439, 440, 441, 442, 443, 444, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 460, 462, 463, 464, 466, 467, 470, 472, 473, 474, 475, 476, 477, 478, 480, 484, 486, 487, 488, 489, 491, 492, 493, 494, 495, 497, 498, 499, 500, 501, 
502, 504, 506, 507, 509, 510, 514, 515, 516, 520, 521, 522, 523, 524, 525, 526, 528, 529, 530, 532, 533, 534, 535, 536, 537, 538, 540, 541, 542, 544, 545, 547, 548, 549, 550, 551, 552, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 565, 566, 567, 568, 569, 570, 572, 573, 574, 575, 576, 578, 579, 580, 581, 584, 585, 587, 588, 594, 595, 596, 597, 598, 601, 602, 603, 605, 606, 607, 608, 609, 610, 611, 614, 616, 617, 618, 619, 621, 623, 624, 625, 626, 627, 628, 630, 631, 632, 633, 634, 635, 636, 637, 638, 640, 641, 642, 644, 645, 647, 651, 652, 653, 655, 658, 659, 660, 661, 662, 663, 664, 666, 670, 671, 672, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 689, 690, 692, 693, 694, 695, 696, 697, 698, 699, 701, 703, 705, 706, 707, 708, 709, 710, 711, 712, 714, 716, 717, 718, 719, 722, 724, 725, 726, 727, 730, 732, 733, 734, 735, 736, 738, 740, 742, 744, 746, 748, 749, 750, 751, 752, 754, 755, 756, 757, 758, 759, 760, 763, 764, 765, 766, 767, 768, 770, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 786, 787, 788, 790, 792, 793, 794, 796, 797, 798, 799, 800, 801, 803, 804, 805, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 822, 823, 824, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 842, 843, 844, 845, 846, 847, 848, 850, 851, 852, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 865, 866, 867, 868, 870, 871, 872, 874, 875, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 898, 899, 900, 901, 902, 903, 904, 907, 909, 910, 912, 915, 916, 917, 918, 920, 921, 922, 923, 925, 926, 928, 929, 933, 934, 936, 937, 941, 942, 943, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 958, 960, 961, 962, 963, 964, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 979, 980, 981, 983, 985, 986, 987, 988, 989, 990, 991, 993, 994, 995, 996, 997, 998, 1001, 1002, 1003, 1004, 1006, 1007, 1009, 1010, 1012, 1014, 1016, 1017, 1018, 1020, 1021, 1022, 1024, 1027, 1028, 1030, 1031, 1034, 1035, 1037, 1039, 1041, 1042, 1043, 1044, 1045, 1047, 1048, 1049, 1050, 1052, 1053, 1054, 1055, 1057, 1058, 1059, 1061, 1063, 1064, 1065, 1067, 1068, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1078, 1079, 1080, 1081, 1082, 1084, 1085, 1087, 1088, 1089, 1090, 1091, 1092, 1094, 1096, 1098, 1099, 1102, 1104, 1105, 1106, 1107, 1108, 1110, 1111, 1112, 1116, 1117, 1118, 1119, 1120, 1122, 1123, 1124, 1127, 1129, 1131, 1132, 1133, 1134, 1135, 1136, 1137, 1138, 1140, 1141, 1144, 1145, 1146, 1147, 1148, 1150, 1151, 1154, 1157, 1159, 1161, 1163, 1164, 1165, 1166, 1167, 1169, 1171, 1172, 1173, 1174, 1175, 1177, 1178, 1180, 1181, 1183, 1184, 1185, 1186, 1187, 1188, 1189, 1190, 1191, 1192, 1194, 1195, 1197, 1198, 1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1212, 1213, 1214, 1216, 1217, 1218, 1219, 1221, 1222, 1223, 1224, 1227, 1228, 1229, 1230, 1231, 1232, 1234, 1235, 1236, 1237, 1238, 1240, 1242, 1243, 1244, 1245, 1246, 1247, 1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259, 1260, 1261, 1262, 1263, 1264, 1265, 1267, 1268, 1269, 1270, 1271, 1272, 1273, 1274, 1275, 1277, 1278, 1279, 1281, 1283, 1284, 1285, 1286, 1288, 1290, 1291, 1293, 1294, 1297, 1298, 1299, 1300, 1301, 1302, 1303, 1304, 1306, 1307, 1308, 1309, 1310, 1311, 1316, 1317, 1318, 1319, 1320, 1321, 1323, 1324, 1325, 1326, 1329, 1330, 1335, 1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343, 1345, 1346, 1349, 1351, 1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359, 1361, 1362, 1363, 1364, 1365, 1367, 1369, 1370, 
1372, 1373, 1374, 1375, 1376, 1379, 1380, 1382, 1383, 1384, 1385, 1386, 1388, 1389, 1392, 1394, 1395, 1396, 1397, 1399, 1401, 1405, 1407, 1408, 1410, 1411, 1412, 1413, 1414, 1415, 1416, 1418, 1419, 1420, 1422, 1423, 1424, 1426, 1427, 1429, 1430, 1431, 1432, 1433, 1434, 1436, 1437, 1438, 1439, 1441, 1442, 1443, 1445, 1447, 1448, 1452, 1453, 1455, 1457, 1459, 1460, 1461, 1462, 1463, 1464, 1465, 1466, 1467, 1468, 1471, 1472, 1473, 1474, 1475, 1476, 1477, 1478, 1479, 1482, 1483, 1485, 1487, 1488, 1489, 1490, 1491, 1492, 1493, 1495, 1496, 1497, 1498, 1499, 1500, 1501, 1503, 1504, 1505, 1508, 1509, 1510, 1512, 1513, 1514, 1516, 1517, 1518, 1519, 1520, 1523, 1525, 1527, 1528, 1529, 1531, 1532, 1533, 1534, 1535, 1537, 1538, 1540, 1541, 1542, 1543, 1545, 1546, 1547, 1548, 1549, 1550, 1551, 1552, 1556, 1557, 1558, 1559, 1560, 1562, 1566, 1567, 1568, 1569, 1570, 1571, 1572, 1574, 1576, 1577, 1579, 1580, 1581, 1583, 1584, 1585, 1586, 1587, 1588, 1589, 1590, 1592, 1594, 1595, 1597, 1598, 1599, 1600, 1601, 1602, 1603, 1605, 1606, 1607, 1608, 1609, 1610, 1611, 1612, 1613, 1614, 1615, 1618, 1620, 1621, 1622, 1623, 1624, 1625, 1626, 1627, 1629, 1630, 1631, 1632, 1633, 1634, 1635, 1636, 1637, 1641, 1642, 1643, 1644, 1645, 1646, 1648, 1650, 1651, 1652, 1653, 1654, 1655, 1656, 1657, 1658, 1660, 1661, 1662, 1664, 1665, 1667, 1668, 1669, 1670, 1671, 1672, 1673]} \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/dense_mm/utils.py b/benchmarks/cutlass_benchmarks/dense_mm/utils.py new file mode 100644 index 0000000000000..d496c2c5b5d18 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/dense_mm/utils.py @@ -0,0 +1,55 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + return a, b + + +def make_n_rand_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + a, b = make_rand_tensors(dtype, m, n, k) + if a is not None: + ABs.append(make_rand_tensors(dtype, m, n, k)) + As, Bs = zip(*ABs) + return list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/dense_mm/weight_shapes.py b/benchmarks/cutlass_benchmarks/dense_mm/weight_shapes.py new file mode 100644 index 0000000000000..2999244bf9b95 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/dense_mm/weight_shapes.py @@ -0,0 +1,75 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the 
following GEMM shape,
+#  - TP1 : K = 14336, N = 4096
+#  - TP2 : K = 7168, N = 4096
+# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
+#  - TP1 : K = 4096, N = 6144
+#  - TP4 : K = 4096, N = 1536
+
+# TP1 shapes
+WEIGHT_SHAPES = {
+    "mistralai/Mistral-7B-v0.1": [
+        ([4096, 6144], 1),
+        ([4096, 4096], 0),
+        ([4096, 28672], 1),
+        ([14336, 4096], 0),
+    ],
+    "meta-llama/Llama-2-7b-hf": [
+        ([4096, 12288], 1),
+        ([4096, 4096], 0),
+        ([4096, 22016], 1),
+        ([11008, 4096], 0),
+    ],
+    "meta-llama/Llama-3-8b": [
+        ([4096, 6144], 1),
+        ([4096, 4096], 0),
+        ([4096, 28672], 1),
+        ([14336, 4096], 0),
+    ],
+    "meta-llama/Llama-2-13b-hf": [
+        ([5120, 15360], 1),
+        ([5120, 5120], 0),
+        ([5120, 27648], 1),
+        ([13824, 5120], 0),
+    ],
+    "meta-llama/Llama-2-70b-hf": [
+        ([8192, 10240], 1),
+        ([8192, 8192], 0),
+        ([8192, 57344], 1),
+        ([28672, 8192], 0),
+    ],
+    "meta-llama/Llama-2-70b-tp4-hf": [([8192, 2560], None),
+                                      ([2048, 8192], None),
+                                      ([8192, 14336], None),
+                                      ([7168, 8192], None)],
+    # The shape space is very big when benchmarking a large set of kernels.
+    # For example, let:
+    #  - #kernels to benchmark be 1700
+    #  - #models to benchmark be 4 (each model has 4 shapes)
+    #  - #batch sizes be 6 (16, 32, 64, 128, 256, 512)
+    # If 1 kernel, 1 shape and 1 batch-size takes approximately 1 second to
+    # run on H100, then the benchmark suite would take
+    # 1700 * (4 * 4) * 6 = 163200 seconds => 46 hrs.
+    # Below, we exploit some observations on the benchmark shapes to create a
+    # representative set.
+    #
+    # From previous benchmarking runs, we observe that perf is stratified as
+    # N - small, medium, large and K - small and large. We also observe that
+    # in the model shapes, when K is small, we have small, medium and large Ns;
+    # when K is large, we only have small Ns.
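+    #
+    # (Rough arithmetic for the saving, under the same ~1 second-per-run
+    # assumption as above: the 4 active shapes in the representative set
+    # below give 1700 * 4 * 6 = 40800 seconds => ~11 hrs, versus ~46 hrs
+    # for the full 4-model sweep.)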
+ # + # models : ['meta-llama/Llama-2-7b-hf', 'meta-llama/Llama-3-8b', + # 'meta-llama/Llama-2-13b-hf', 'meta-llama/Llama-2-70b-tp4-hf'] + # Ks : [2048, 4096, 5120, 7168, 8192, 11008, 13824, 14336] + # Ns : [2560, 4096, 5120, 6144, 8192, 12288, 14336, 15360, + # 22016, 27648, 28672] + "llama-representative-set": [ + # ([4096, 4096], None), # small K, small N + ([4096, 8192], None), # small K, medium N + ([4096, 22016], None), # small K, large N + ([14336, 4096], None), # large K, small N + ([8192, 14336], None), # medium K, large N (from llama-2-70b-tp4-hf + ], +} diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py deleted file mode 100644 index 63cf5d50cac75..0000000000000 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ /dev/null @@ -1,389 +0,0 @@ -import argparse -import copy -import itertools -import pickle as pkl -import time -from typing import Callable, Iterable, List, Tuple - -import torch -import torch.utils.benchmark as TBenchmark -from torch.utils.benchmark import Measurement as TMeasurement -from weight_shapes import WEIGHT_SHAPES - -from vllm import _custom_ops as ops -from vllm.utils import FlexibleArgumentParser - -DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) -DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] -DEFAULT_TP_SIZES = [1] - -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - - -# bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: - min_run_time = 1 - - globals = { - "args": args, - "kwargs": kwargs, - "fn": fn, - } - return TBenchmark.Timer( - stmt="fn(*args, **kwargs)", - globals=globals, - label=label, - sub_label=sub_label, - description=description, - ).blocked_autorange(min_run_time=min_run_time) - - -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - assert dtype == torch.int8 - a, b = make_rand_tensors(torch.int8, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) - azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) - - timers = [] - # pytorch impl - bfloat16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) - - # pytorch impl - float16 - timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, - a.to(dtype=torch.float16), b.to(dtype=torch.float16))) - - # cutlass impl - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - - # cutlass with bias 
- timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) - - # cutlass with azp per-tensor - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj)) - - # cutlass with azp per-tensor + bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, None, bias)) - - # cutlass with azp per-token - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, azp)) - - # cutlass with azp per-token + bias - timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", - ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, - torch.bfloat16, azp_adj, azp, bias)) - - return timers - - -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: - assert dtype == torch.float8_e4m3fn - a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) - scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) - scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - - timers = [] - - # pytorch impl w. bf16 - timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) - - # pytorch impl: bf16 output, without fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16)) - - # pytorch impl: bf16 output, with fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True)) - - # pytorch impl: fp16 output, without fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16)) - - # pytorch impl: fp16 output, with fp8 fast accum - timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True)) - - # cutlass impl: bf16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) - # cutlass impl: fp16 output - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) - - # cutlass impl: bf16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) - - # cutlass impl: fp16 output, with bias - timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16))) - - return timers - - -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - 
sub_label: str) -> Iterable[TMeasurement]: - if dtype == torch.int8: - return bench_int8(dtype, m, k, n, label, sub_label) - if dtype == torch.float8_e4m3fn: - return bench_fp8(dtype, m, k, n, label, sub_label) - raise ValueError("unsupported type") - - -# runner -def print_timers(timers: Iterable[TMeasurement]): - compare = TBenchmark.Compare(timers) - compare.print() - - -def run(dtype: torch.dtype, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - results = [] - for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") - print_timers(timers) - results.extend(timers) - - return results - - -# output makers -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[Tuple[int, int, int]], - base_description: str, - timestamp=None): - print(f"== All Results {base_description} ====") - print_timers(data) - - # pickle all the results - timestamp = int(time.time()) if timestamp is None else timestamp - with open(f"{base_description}-{timestamp}.pkl", "wb") as f: - pkl.dump(data, f) - - -# argparse runners - - -def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) - MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, MKNs) - - make_output(data, MKNs, f"square_bench-{args.dtype}") - - -def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) - data = run(args.dtype, MKNs) - - make_output(data, MKNs, f"range_bench-{args.dtype}") - - -def run_model_bench(args): - print("Benchmarking models:") - for i, model in enumerate(args.models): - print(f"[{i}] {model}") - - def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: - KNs = [] - for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): - KN[tp_split_dim] = KN[tp_split_dim] // tp_size - KNs.append(KN) - return KNs - - model_bench_data = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - Ms = args.batch_sizes - KNs = model_shapes(model, tp_size) - MKNs = [] - for m in Ms: - for k, n in KNs: - MKNs.append((m, k, n)) - - data = run(args.dtype, MKNs) - model_bench_data.append(data) - - # Print all results - for data, model_tp in zip(model_bench_data, models_tps): - model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") - print_timers(data) - - timestamp = int(time.time()) - - all_data = [] - for d in model_bench_data: - all_data.extend(d) - # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) - - -if __name__ == '__main__': - - def to_torch_dtype(dt): - if dt == "int8": - return torch.int8 - if dt == "fp8": - return torch.float8_e4m3fn - raise ValueError("unsupported dtype") - - parser = FlexibleArgumentParser( - description=""" -Benchmark Cutlass GEMM. 
- - To run square GEMMs: - python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 - - To run constant N and K and sweep M: - python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 - - To run dimensions from a model: - python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 - - Output: - - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. - """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") - subparsers = parser.add_subparsers(dest="cmd") - - square_parser = subparsers.add_parser("square_bench") - square_parser.add_argument("--dim-start", type=int, required=True) - square_parser.add_argument("--dim-end", type=int, required=True) - square_parser.add_argument("--dim-increment", type=int, required=True) - square_parser.set_defaults(func=run_square_bench) - - range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) - range_parser.set_defaults(func=run_range_bench) - - model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - model_parser.set_defaults(func=run_model_bench) - - args = parser.parse_args() - args.func(args) diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py deleted file mode 100644 index 25ec9d6028627..0000000000000 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ /dev/null @@ -1,43 +0,0 @@ -# Weight Shapes are in the format -# ([K, N], TP_SPLIT_DIM) -# Example: -# A shape of ([14336, 4096], 0) indicates the following GEMM shape, -# - TP1 : K = 14336, N = 4096 -# - TP2 : K = 7168, N = 4096 -# A shape of ([4096, 6144], 1) indicates the following GEMM shape, -# - TP1 : K = 4096, N = 6144 -# - TP4 : K = 4096, N = 1536 - -# TP1 shapes -WEIGHT_SHAPES = { - "mistralai/Mistral-7B-v0.1": [ - ([4096, 6144], 1), - ([4096, 4096], 0), - ([4096, 28672], 1), - ([14336, 4096], 0), - ], - "meta-llama/Llama-2-7b-hf": [ - ([4096, 12288], 1), - ([4096, 4096], 0), - ([4096, 22016], 1), - ([11008, 4096], 0), - ], - "meta-llama/Llama-3-8b": [ - ([4096, 6144], 1), - ([4096, 4096], 0), - ([4096, 28672], 1), - ([14336, 4096], 0), - ], - "meta-llama/Llama-2-13b-hf": [ - ([5120, 15360], 1), - ([5120, 5120], 0), - ([5120, 27648], 1), - ([13824, 5120], 0), - ], - "meta-llama/Llama-2-70b-hf": [ - ([8192, 10240], 1), - ([8192, 8192], 0), - ([8192, 57344], 1), - ([28672, 8192], 0), - ], -} diff --git 
a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 95764ecddc79f..fcc17c7727f94 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -36,13 +36,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp new file mode 100644 index 0000000000000..d407d66ab2aa6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp @@ -0,0 +1,496 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/visitor_load.hpp from +// https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either +// row/column or scalar broadcasting where the tensor being loaded from is +// always passed in via a device pointer. 
This lets one compiled kernel handle +// all cases of per-tensor or per-channel/per-token quantization. +// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graph +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::epilogue::threadblock { + +using namespace cute; +using namespace detail; + +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->row_broadcast) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are loading from a scalar and broadcasting + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = *(params_ptr->ptr_row); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE 
auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { 
+ Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + class ThreadMap, + class Element, + class StrideMNL = Stride<_1,_0,_0> +> +struct VisitorColOrScalarBroadcast { + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast. + struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage { }; + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gCol, + RTensor&& tC_rCol, + CTensor&& tC_cCol, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gCol(cute::forward(tC_gCol)), + tC_rCol(cute::forward(tC_rCol)), + tC_cCol(cute::forward(tC_cCol)), + m(get<0>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gCol; + RTensor tC_rCol; + CTensor tC_cCol; + Params const* params_ptr; + int m; + + // This function is modified from VisitorColBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rCol); + + Tensor pred = make_tensor(shape(tC_gCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tC_cCol(i)) < m; + } + + if (params_ptr->col_broadcast) { + // In this case we are loading from a column vector and broadcasting + copy_if(pred, tC_gCol, tC_rCol); + } else { + // In this case we are loading from a scalar and broadcasting + auto dst_v = filter(tC_rCol); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(dst_v); ++i) { + if (pred(i)) { + dst_v(i) = *(params_ptr->ptr_col); + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Array frg_col; + frg_col.fill(tC_rCol(row_idx,iter_idx)); + return frg_col; + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) 
{ + Tensor mCol = make_tensor( + make_gmem_ptr(params_ptr->ptr_col), + problem_shape, + params_ptr->dCol); + + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER + Tensor tC_gCol = group_modes<1,4>( + ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + Tensor tC_rCol = make_tensor_like(tC_gCol); + + // Generate the pred tensor + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tC_cCol = group_modes<1,4>( + ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); + + return Callbacks< + decltype(tC_gCol), decltype(tC_rCol), + decltype(tC_cCol), ProblemShape>( + cute::move(tC_gCol), + cute::move(tC_rCol), + cute::move(tC_cCol), + problem_shape, + params_ptr + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp new file mode 100644 index 0000000000000..58b1e8ff159fb --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// +// This file is a modified excerpt of +// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +// from https://github.com/NVIDIA/cutlass v3.5.0 +// It has been modified to support either row/column or scalar broadcasting +// where the tensor being loaded from is always passed in via a device pointer. +// This lets one compiled kernel handle all cases of per-tensor or +// per-channel/per-token quantization. 
+// +// This interface also allows the scales to be passed in as tensors that +// consistently reside on the device, which avoids an issue with a previous +// implementation where scalars needed to be on the CPU since they +// were passed in via float values. This created a potential performance hazard +// if scales were initially on the device, and caused torch.compile graphs +// breaks when moving scales to the CPU. +// +#pragma once + +// Turn off clang-format for the entire file to keep it close to upstream +// clang-format off + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90RowOrScalarBroadcast { + static_assert(Stages == 0, "Row broadcast doesn't support smem usage"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + // This struct has been modified to have a bool indicating that ptr_row is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_row is null. + struct Arguments { + Element const* ptr_row = nullptr; + bool row_broadcast = true; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) + , smem(const_cast(shared_storage.smem.data())) { } + + Params params; + Element *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.row_broadcast && *(params.ptr_row) == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , tCcRow(tCcRow_) + , residue_tCcRow(residue_tCcRow_) + , params(params_) {} + + 
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcRow; // (m, n) + ThrNum thr_num; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (!params.row_broadcast) { + fill(tSR_rRow, *(params.ptr_row)); + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds. + } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0) { // Assumes M-major subtile loop + if (!params.row_broadcast) return; // Do not issue LDS when row is scalar + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + copy(tSR_sRow_flt, tSR_rRow_flt); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); + Tensor tGS_cRow = thr_g2s.partition_S(cRow); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.tCcD, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v +> +struct Sm90ColOrScalarBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + // This struct has been modified to have a bool indicating that ptr_col is a + // scalar that must be broadcast, instead of containing a scalar that is + // valid if ptr_col is null. 
+ struct Arguments { + Element const* ptr_col = nullptr; + bool col_broadcast = true; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (!params.col_broadcast && *(params.ptr_col) == Element(0)); + } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensor&& tCgCol, + RTensor&& tCrCol, + CTensor&& tCcCol, + ProblemShape problem_shape, + Params const& params + ): + tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + tCcCol(cute::forward(tCcCol)), + m(get<0>(problem_shape)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + int m; + + CUTLASS_DEVICE void + begin() { + Tensor pred = make_tensor(shape(tCgCol)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol(i)) < m; + } + + if (!params.col_broadcast) { + fill(tCrCol, *(params.ptr_col)); + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + copy_if(pred, filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + // Generate an identity tensor matching the shape of the global tensor and + // partition the same way, this will be used to generate the predicate + // tensor for loading + Tensor cCol = make_identity_tensor(mCol.shape()); + Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCgCol), + cute::move(tCrCol), + cute::move(tCcCol), + args.problem_shape_mnkl, + params + ); + } +}; + +} diff --git a/csrc/quantization/cutlass_w8a8/generator/README.md b/csrc/quantization/cutlass_w8a8/generator/README.md new file mode 100644 index 0000000000000..523d767074820 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/README.md @@ -0,0 +1,143 @@
+## Cutlass Kernel Generator and Benchmark Sweeps
+
+#### Basic Idea
+ - Expose a C++ interface for the function to benchmark. The interface must be
+   templated with the hyper-parameters we desire to sweep over.
+ - Generate .cu files using jinja templates that use the exposed interface.
+   Look at `scaled_mm_c3x.jinja`.
+ - Generate torch bindings for the functions in the .cu files.
+ - Build vllm to include all the generated .cu files. Look at `nm_cutlass_c.cmake`.
+ - Run the benchmarking script to sweep over problem shapes and all the generated
+   cutlass kernels. Look at `benchmarks/cutlass_benchmarks/bench_v2.py`.
+
+#### Important Files
+ - scaled_mm_c3x.jinja / simple_gemm_c3x.jinja : Jinja templated files for the functions to generate.
+ - scaled_mm_c3x_fnprototype.jinja / simple_gemm_c3x_fnprototype.jinja : Jinja templated files for the C++ function declarations.
+ - generator_types.py : This file contains all the information regarding the function types we intend to generate.
+   For example, at the time of writing, we have ScaledMMGeneratorType and SimpleGemmGeneratorType.
+   The ScaledMMGeneratorType points to the correct jinja templates to use and also defines the
+   correct torch binding `ops.impl` and `ops.def` strings. This is where we register new GeneratorTypes
+   if we add more function-generators in the future.
+ - autogen_manifest.py : Defines the hyper-parameter sets.
+ - kernel_generator.py : All utilities that are responsible for filling out the jinja templates
+   based on a given set of hyper-parameter args.
+ - generator.py : Bridges autogen_manifest.py and kernel_generator.py. This is the `main` driver
+   script that we use to generate kernels.
+ - kernel_compiler.py : Not all sets of hyper-parameters are valid. The KernelCompiler attempts an
+   nvcc compile of the generated kernel file, and kernel_generator/generator accepts or rejects
+   the generated kernel based on this compilation status.
+
+#### Adding a new function to generate
+
+##### Step 1
+ - As mentioned above, expose a C++ interface for the function to generate. The interface
+   must be templated with the hyper-parameters we desire to sweep over.
+
+##### Step 2
+ - Create jinja templates.
+   1. Create a jinja template file that is representative of the kernel we wish to generate.
+   2. Create a separate jinja template file that has the function declaration.
+ - Refer to `scaled_mm_c3x.jinja` and `scaled_mm_c3x_fnprototype.jinja`.
+
+##### Step 3
+ - Create a GeneratorType in generator_types.py.
+ - The GeneratorType is the data structure that communicates:
+   1. What jinja template files to use.
+   2. What the torch_bindings `ops.def` and `ops.impl` arguments are.
+ - Refer to ScaledMMGeneratorType.
+
+##### Step 4
+ - In autogen_manifest.py, create a list of hyper-parameter sets that are to be translated into kernel files.
+ - Look at the construction of Cutlass3xArgsTest in autogen_manifest.py.
+
+##### Commands to generate kernels:
+ - Example command:
+   python3 csrc/quantization/cutlass_w8a8/generator/generator.py --generator-type scaled_mm --vllm-root-dir ${HOME}/code/nm-vllm-ent/nm-vllm-ent/ --py-venv-dir ${HOME}/code/nm-vllm-ent/nm-vllm-ent/vllm-test --cuda-dir /usr/local/cuda-12.5 --cutlass-args-list Cutlass3xArgsTest
+
+   Here:
+   - --generator-type : the name of the desired GeneratorType in generator_types.py
+   - --vllm-root-dir : the root dir of your vllm project
+   - --py-venv-dir : the root dir of your python environment
+   - --cuda-dir : the cuda dir to use
+   - --cutlass-args-list : the name of the list of hyper-parameter sets that you created in autogen_manifest.py
+
+   Expectations:
+   The generator attempts to generate one kernel for every hyper-parameter set.
+   - The generator generates the kernel file.
+   - The generator attempts to compile the generated kernel file.
+   - If compilation succeeds, it keeps the generated kernel file; otherwise it deletes it.
+
+   The generator records the compilation status of each kernel it tries to compile. If a kernel is known to
+   have succeeded in a previous run, it simply generates it and does not attempt a re-compile.
+
+##### Commands to build
+ - The normal vllm build command should work,
+   i.e. either `pip3 install -e .` or `python3 setup.py build_ext --inplace`.
+
+   Expectation:
+   Compilation should succeed and you should see .so files like `_nm_cutlass_*_C.so` in the vllm folder.
+
+##### How to benchmark
+The benchmarking scripts have been updated to grab all the auto-generated cutlass kernels. Look at
+`get_autogen_functions` in `benchmarks/cutlass_benchmarks/bench_v2.py`.
+
+Example command:
+python3 benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 --with-arg-pool 32 --with-cuda-graph 32 square_bench --dim-start 128 --dim-end 256 --dim-increment 128
+
+Expectations:
+ You should see output similar to:
+ ```
+ attempting import vllm._nm_cutlass_0_C
+ #autogen functions found 3
+ Bench autogen autogen_scaled_mm_90_64x64x32_1x1x1_KernelTmaWarpSpecializedFP8FastAccum_TmaWarpSpecializedCooperative_PersistentScheduler_kGemm_float_fp8
+ Bench autogen autogen_scaled_mm_90_64x64x32_1x1x1_KernelTmaWarpSpecializedFP8FastAccum_TmaWarpSpecialized_PersistentScheduler_kGemm_float_fp8
+ Bench autogen autogen_scaled_mm_90_64x64x32_1x1x1_KernelTmaWarpSpecializedPingpongFP8FastAccum_TmaWarpSpecialized_PersistentScheduler_kGemm_float_fp8
+ ```
+
+##### Benchmark Heatmaps and Optimal Kernel Set Selection
+Typically a hyper-parameter sweep produces hundreds of kernels, and the terminal output of the
+benchmarking scripts can be hard to read. The w8a8_benchmarks.py script, when used with the model_bench
+command, produces a pickle file that contains the benchmark information for all the {kernel, gemm-shape}
+pairs benchmarked.
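+
+A minimal sketch (illustrative only, not part of this patch) of how such a pickle file can be inspected is
+shown below. It assumes the pickled entries are `torch.utils.benchmark` Measurement objects whose
+`task_spec` carries the GEMM shape in `sub_label` (e.g. "MKN=(MxKxN)") and the implementation name in
+`description`; the helper name `load_bench_results` is hypothetical.
+
+```python
+import pickle
+from collections import defaultdict
+
+
+def load_bench_results(pkl_path):
+    """Group raw benchmark Measurements by GEMM shape and kernel name."""
+    with open(pkl_path, "rb") as f:
+        measurements = pickle.load(f)  # list of torch.utils.benchmark Measurements
+
+    by_shape = defaultdict(dict)
+    for m in measurements:
+        # sub_label identifies the GEMM shape, description the implementation.
+        by_shape[m.task_spec.sub_label][m.task_spec.description] = m.median
+    return dict(by_shape)
+
+
+# Example (file name is a placeholder):
+# results = load_bench_results("./model_bench-torch.float8_e4m3fn-1729989172.pkl")
+```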
+
+###### Kernel Selection Problem
+When we run a hyper-parameter sweep, we are interested in finding a minimal set of kernels that is
+optimal for the gemm-shapes benchmarked. `tools/select_kernels.py` solves this optimization problem.
+
+Example:
+    python3 select_kernels.py --input-pkl ./model_bench-torch.float8_e4m3fn-1729989172.pkl --min-gemm-efficiency 0.98
+
+    This example invocation of the select_kernels.py script:
+    - Reads the input pickle file and gathers the benchmark information of all the {kernel, gemm-shape} pairs.
+    - Normalizes the benchmark information with respect to gemm shapes, i.e. the best performing
+      kernel for a given gemm-shape is given a value of 1.0. A kernel with a value of `x` (`x` < 1.0)
+      indicates that its performance is `x` times that of the optimal kernel.
+    - Ignores all the {kernel, gemm-shape} pairs where the kernel efficiency is < min_gemm_efficiency.
+      In this case the script only considers the {kernel, gemm-shape} pairs whose normalized value
+      is in the range [0.98, 1.0].
+    - Determines the optimal and minimal kernel set (a sketch of this procedure appears at the end of
+      this section).
+
+###### Visualization Problem
+Reading the w8a8_benchmarks.py terminal output can be overwhelming. The script `tools/heatmap.py`
+consumes a model_bench pickle file and produces a heatmap for easier inspection of the results.
+
+Example:
+    python3 heatmap.py --input-pkl ./model_bench-torch.float8_e4m3fn-1730295961-selected.pkl --plot-all-ops
+
+    Normalizes all the {kernel, gemm-shape} information in the model_bench pickle file (refer to
+    "Kernel Selection Problem" for how the data is normalized) and renders the normalized benchmark
+    information as a heatmap.
+
+Example:
+    python3 heatmap.py --input-pkl ./model_bench-torch.float8_e4m3fn-1730295961-selected.pkl --select-kernels
+
+    Effectively runs select_kernels.py on the input pkl file and renders the selected kernels as a heatmap.
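+
+For reference, a minimal sketch of the normalization and selection described under "Kernel Selection
+Problem" is given below. It consumes the per-shape timing dictionary from the earlier sketch and uses a
+simple greedy cover as one plausible strategy; the actual `tools/select_kernels.py` may solve the
+optimization differently.
+
+```python
+def normalize(by_shape):
+    # Per GEMM shape, the fastest kernel gets efficiency 1.0; a slower kernel
+    # gets best_time / its_time, a value x < 1.0 as described above.
+    return {
+        shape: {k: min(times.values()) / t for k, t in times.items()}
+        for shape, times in by_shape.items()
+    }
+
+
+def select_kernels(efficiency, min_gemm_efficiency=0.98):
+    # Greedy cover: repeatedly pick the kernel that meets the efficiency
+    # threshold on the largest number of still-uncovered GEMM shapes.
+    uncovered = set(efficiency)
+    selected = []
+    while uncovered:
+        candidates = {k for s in uncovered for k in efficiency[s]}
+        best, best_cover = None, set()
+        for k in candidates:
+            cover = {s for s in uncovered
+                     if efficiency[s].get(k, 0.0) >= min_gemm_efficiency}
+            if len(cover) > len(best_cover):
+                best, best_cover = k, cover
+        if not best_cover:
+            break  # no kernel meets the threshold for the remaining shapes
+        selected.append(best)
+        uncovered -= best_cover
+    return selected
+```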
+ + + + + + + +tools/select_kernel.py : + + + + diff --git a/csrc/quantization/cutlass_w8a8/generator/autogen_manifest.py b/csrc/quantization/cutlass_w8a8/generator/autogen_manifest.py new file mode 100644 index 0000000000000..8592fa3f093cc --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/autogen_manifest.py @@ -0,0 +1,167 @@ +import copy +from dataclasses import dataclass +from itertools import product +from typing import Tuple + + +@dataclass +class Cutlass3xArgs: + dtype_str: str + arch: int + tile_shape: Tuple[int, int, int] + cluster_shape: Tuple[int, int, int] + kernel_schedule: str + epilogue_schedule: str + tile_schedule: str + gemm_mode: str + acc_type: str + + def with_tile_shape(self, ts): + clone = copy.deepcopy(self) + clone.tile_shape = ts + return clone + + def with_cluster_shape(self, cs): + clone = copy.deepcopy(self) + clone.cluster_shape = cs + return clone + + def with_tile_schedule(self, ts): + clone = copy.deepcopy(self) + clone.tile_schedule = ts + return clone + + def with_kernel_schedule(self, ks): + clone = copy.deepcopy(self) + clone.kernel_schedule = ks + return clone + + def with_epilogue_schedule(self, es): + clone = copy.deepcopy(self) + clone.epilogue_schedule = es + return clone + + def with_gemm_mode(self, gm): + clone = copy.deepcopy(self) + clone.gemm_mode = gm + return clone + + def with_acc_type(self, acc): + clone = copy.deepcopy(self) + clone.acc_type = acc + return clone + + def with_dtype_str(self, dtype_str): + clone = copy.deepcopy(self) + clone.dtype_str = dtype_str + return clone + + +DefaultCutlass3xArgsFP8 = Cutlass3xArgs( + "fp8", 90, (128, 128, 128), (1, 2, 1), + "cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative", + "cutlass::epilogue::TmaWarpSpecializedCooperative", + "cutlass::gemm::PersistentScheduler", + "cutlass::gemm::GemmUniversalMode::kGemm", "float") + +DefaultCutlass3xArgsINT8 = Cutlass3xArgs( + "int8", 90, (128, 128, 128), (2, 1, 1), + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + "cutlass::epilogue::TmaWarpSpecialized", + "cutlass::gemm::PersistentScheduler", + "cutlass::gemm::GemmUniversalMode::kGemm", "int32_t") + +## Kernel Schedules +## All +# struct KernelMultistage { }; +# struct KernelCpAsyncWarpSpecialized { }; +# struct KernelCpAsyncWarpSpecializedPingpong { }; +# struct KernelCpAsyncWarpSpecializedCooperative { }; +# struct KernelTma { }; +# struct KernelTmaWarpSpecialized { }; +# struct KernelTmaWarpSpecializedPingpong { }; +# struct KernelTmaWarpSpecializedCooperative { }; +# struct KernelPtrArrayTmaWarpSpecializedCooperative { }; +## FP8 +# struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; +# struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; # noqa +# struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; #noqa +# struct KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum : KernelPtrArrayTmaWarpSpecializedCooperative { }; #noqa + +## Epilogue policies +# struct NoSmemWarpSpecialized {}; +# struct PtrArrayNoSmemWarpSpecialized {}; +# struct TmaWarpSpecialized {}; +# struct TmaWarpSpecializedCooperative {}; + +## Tile scheduler +# struct PersistentScheduler { }; +# struct StreamKScheduler { }; + +## Kgemms +# kGemm +# kGemmSplitKParallel, +# kBatched, +# kArray, +# kGrouped, +# kInvalid + +cluster_shapes = [(1, 1, 1), (2, 1, 1), (1, 2, 1), (2, 2, 1), (4, 1, 1), + (1, 4, 1), (8, 1, 1), (1, 8, 1), (4, 4, 1)] +tile_shapes_m = [64, 128, 256] +tile_shapes_n = [64, 128, 256] 
+tile_shapes_k = [32, 64, 128, 256] +tile_shapes = list(product(tile_shapes_m, tile_shapes_n, tile_shapes_k)) + +kernel_schedules_fp8 = [ + "cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum", + "cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum", + "cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum" +] + +kernel_schedules = [ + "cutlass::gemm::KernelTmaWarpSpecialized", + "cutlass::gemm::KernelTmaWarpSpecializedPingpong", + "cutlass::gemm::KernelTmaWarpSpecializedCooperative" +] + +epilogue_schedules = [ + "cutlass::epilogue::TmaWarpSpecialized", + "cutlass::epilogue::TmaWarpSpecializedCooperative" +] + +tile_schedules = [ + "cutlass::gemm::PersistentScheduler", "cutlass::gemm::StreamKScheduler" +] + +gemm_modes = ["cutlass::gemm::GemmUniversalMode::kGemm"] + +acc_types = ["float"] +acc_types_int = ["int32_t"] + +## Make Cutlass3xArgsTest + +Cutlass3xArgsTestFP8 = [] +Cutlass3xArgsTestINT8 = [] + +for ts, cs, ks, es, tile_schedule, gm, at in product( + tile_shapes, cluster_shapes, kernel_schedules_fp8, epilogue_schedules, + tile_schedules, gemm_modes, acc_types): + + Cutlass3xArgsTestFP8.append( + DefaultCutlass3xArgsFP8.with_tile_shape(ts).with_cluster_shape(cs). + with_kernel_schedule(ks).with_epilogue_schedule(es).with_tile_schedule( + tile_schedule).with_gemm_mode(gm).with_acc_type(at)) + +for ts, cs, ks, es, tile_schedule, gm, at in product( + tile_shapes, cluster_shapes, kernel_schedules, epilogue_schedules, + tile_schedules, gemm_modes, acc_types_int): + + Cutlass3xArgsTestINT8.append( + DefaultCutlass3xArgsINT8.with_tile_shape(ts).with_cluster_shape(cs). + with_kernel_schedule(ks).with_epilogue_schedule(es).with_tile_schedule( + tile_schedule).with_gemm_mode(gm).with_acc_type(at)) + +# Cutlass3xArgsTestFP8 = Cutlass3xArgsTestFP8[:5] +# Cutlass3xArgsTestFP16.append(DefaultCutlass3xArgsFP16) +# Cutlass3xArgsTestBF16.append(DefaultCutlass3xArgsBF16) \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/generator/generator.py b/csrc/quantization/cutlass_w8a8/generator/generator.py new file mode 100644 index 0000000000000..eb56bdf8e0576 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/generator.py @@ -0,0 +1,155 @@ +import pprint +from dataclasses import dataclass +from multiprocessing.pool import ThreadPool +from typing import List, Optional + +import autogen_manifest +from autogen_manifest import Cutlass3xArgs +from generator_types import GeneratorType, GeneratorTypes +from kernel_compiler import KernelCompiler +from kernel_generator import GeneratorOutput, KernelGenerator +from tqdm import tqdm + + +@dataclass +class GenerateFromArgInput: + generator_type: Optional[GeneratorType] = None + args: Optional[Cutlass3xArgs] = None + kernel_compiler: Optional[KernelCompiler] = None + + +def generate_from_arg(input: GenerateFromArgInput) -> GeneratorOutput: + """ + Kernel generation for a single Cutlass3xArg + """ + generator_type, args, kernel_compiler = (input.generator_type, input.args, + input.kernel_compiler) + return KernelGenerator.generate(generator_type, args, kernel_compiler) + + +def generate_from_args_mt(generator_type: GeneratorType, + args: List[Cutlass3xArgs], + kernel_compiler: KernelCompiler, + num_threads: int = 32) -> GeneratorOutput: + """ + Kernel generator for a list of Cutlass3xArgs with multi-threading. 
+ """ + generator_outputs = GeneratorOutput() + # create thread pool with {num_threads} threads + pool = ThreadPool(processes=num_threads) + inputs = [ + GenerateFromArgInput(generator_type, x, kernel_compiler) for x in args + ] + result = pool.map_async(generate_from_arg, inputs) + for r in result.get(): + generator_outputs.merge(r) + return generator_outputs + + +def main(args): + pprint.pprint(args) + + cutlass_args_list = getattr(autogen_manifest, args.cutlass_args_list) + print(f"Generating {len(cutlass_args_list)} cuda files ...") + + generator_type: GeneratorType = GeneratorType.from_str(args.generator_type) + + additional_compile_args = [x.strip() for x in args.additional_compile_args] + kernel_compiler: KernelCompiler = KernelCompiler( + vllm_root_dir=args.vllm_root_dir, + py_venv_dir=args.py_venv_dir, + cuda_dir=args.cuda_dir, + py_version=args.py_version, + additional_args=additional_compile_args, + test_compile=args.test_compile, + cache_write_only=args.cache_write_only) + kernel_compiler.init_compile_cache() + + generator_outputs = GeneratorOutput() + batch_size = 100 # Compile-and-Generate batch_size items at a time + for idx in tqdm(range(0, len(cutlass_args_list), batch_size)): + print(f"Total {len(cutlass_args_list)}" + f" | Success {len(generator_outputs.success_file_names)}" + f"| Fail {len(generator_outputs.failed_file_names)}") + + chunk_generator_output = generate_from_args_mt( + generator_type, cutlass_args_list[idx:idx + batch_size], + kernel_compiler) + generator_outputs.merge(chunk_generator_output) + + # Store intermediate results + # fill-out ops.h + KernelGenerator.write_ops(generator_type, args, + generator_outputs.file_paths, + generator_outputs.fn_names, + generator_outputs.fn_decls) + # store result batch + kernel_compiler.cache.add(generator_outputs.success_file_names, + generator_outputs.failed_file_names) + kernel_compiler.cache.store() + + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser(description=''' + Autogen cutlass kernels + Example: + python3 csrc/quantization/cutlass_w8a8/generator/generator.py \ + --generator-type scaled_mm \ + --vllm-root-dir ${HOME}/code/nm-vllm-ent/nm-vllm-ent/ \ + --py-venv-dir ${HOME}/code/nm-vllm-ent/nm-vllm-ent/vllm-test \ + --cuda-dir /usr/local/cuda-12.5 + ''') + + parser.add_argument("--generator-type", + required=True, + choices=[x.description() for x in GeneratorTypes]) + parser.add_argument("--cutlass-args-list", + required=True, + type=str, + default=None, + help=''' + The cutlass args list variable name constructed in + autogen_manifest.py. The variable name is imported + as, + getattr(autogen_manifest, args.cutlass_args_list) + ''') + parser.add_argument('--test-compile', + action='store_true', + help=''' + Runs as usual but, + - Prints compiler errors + - Doesn't update the kernel compiler cache. + ''') + parser.add_argument("--vllm-root-dir", + required=True, + type=str, + default=None, + help="Root directory of vllm source code") + parser.add_argument("--py-venv-dir", + required=True, + type=str, + default=None, + help="py venv root directory") + parser.add_argument("--cuda-dir", + type=str, + default=None, + help="CUDA dir example: /usr/local/cuda-12.5") + parser.add_argument("--dtype-str", + required=True, + type=str, + choices=["int8", "fp8", "fp16", "bf16"], + help="Data type string. 
Example: fp8") + parser.add_argument("--cache-write-only", + action='store_true', + help="Don't read from cache, only write to cache") + parser.add_argument("--additional-compile-args", nargs='*', default=[]) + parser.add_argument( + "--py-version", + type=str, + default="3.10", + help="Python version to use. Used in fetching the python includes") + + args = parser.parse_args() + main(args) diff --git a/csrc/quantization/cutlass_w8a8/generator/generator_types.py b/csrc/quantization/cutlass_w8a8/generator/generator_types.py new file mode 100644 index 0000000000000..76fc4f6aefa28 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/generator_types.py @@ -0,0 +1,77 @@ +""" +Generator function types. + +Defines necessary information about each function type to generate. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List + +from utils import get_script_dir + + +class GeneratorType(ABC): + SCRIPT_DIR = get_script_dir() + + @staticmethod + def description() -> str: + raise NotImplementedError + + @abstractmethod + def fn_defn_jinja_filepath(self) -> Path: + # Function definition jinja - the entrypoint to the function to + # generate. + # Refer to csrc/quantization/cutlass_w8a8/scaled_mm_c3x.jinja for + # an example. + raise NotImplementedError + + @abstractmethod + def fn_decl_jinja_filepath(self) -> Path: + # Function decl jinja - the c++ function declaration of the function + # to generate. + # Refer to csrc/quantization/cutlass_w8a8/scaled_mm_c3x_fnprototype.jinja #noqa + # for an example. + + raise NotImplementedError + + @abstractmethod + def ops_def(self, fn_name: str) -> str: + # torch binding ops.def template. + raise NotImplementedError + + @abstractmethod + def ops_impl(self, fn_name: str) -> str: + # torch binding ops.impl template. + raise NotImplementedError + + @staticmethod + def from_str(s: str) -> "GeneratorType": + if ScaledMMGenerator.description() == s: + return ScaledMMGenerator() + raise ValueError("Unknown generator type string {s}") + + +class ScaledMMGenerator(GeneratorType): + + def __init__(self): + super().__init__() + + @staticmethod + def description(): + return "scaled_mm" + + def fn_defn_jinja_filepath(self): + return GeneratorType.SCRIPT_DIR / "scaled_mm_c3x.jinja" + + def fn_decl_jinja_filepath(self): + return GeneratorType.SCRIPT_DIR / "scaled_mm_c3x_fnprototype.jinja" + + def ops_def(self, fn_name: str) -> str: + return f'ops.def("{fn_name}(Tensor! out, Tensor a, Tensor b, Tensor a_scales, Tensor b_scales) -> ()");' #noqa + + def ops_impl(self, fn_name: str) -> str: + return f'ops.impl("{fn_name}", torch::kCUDA, &{fn_name});' + + +GeneratorTypes: List[GeneratorType] = [ScaledMMGenerator] diff --git a/csrc/quantization/cutlass_w8a8/generator/kernel_compiler.py b/csrc/quantization/cutlass_w8a8/generator/kernel_compiler.py new file mode 100644 index 0000000000000..9531b4328ac8f --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/kernel_compiler.py @@ -0,0 +1,131 @@ +""" +Utilities to invoke the kernel compiler. +When generating cutlass kernels, we attempt an nvcc compile to make sure that +there won't be any issues at vllm build time. +""" + +import pickle as pkl +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional + +# Global compile cache path that stores information about which kernels +# compiled successfully and which failed. 
+CACHE_FILE_PATH = Path('./kernels_compile_cache.pkl') + + +class KernelCompileCache: + + def __init__(self, test_compile=False, write_only=False): + # If test_compile is true, we override the cache operations so it + # is a no-op. + self.test_compile = test_compile + self.write_only = write_only + + # self.bad_kernels are kernels that failed compilation + # self.good_kernels are kernels that succeeded compilation + if not CACHE_FILE_PATH.exists() or self.test_compile: + self.bad_kernels = [] + self.good_kernels = [] + else: + # Load from cache + data = None + with open(str(CACHE_FILE_PATH), 'rb') as f: + data = pkl.load(f) + self.bad_kernels, self.good_kernels = data + print(f"#bad kernels {len(self.bad_kernels)}," + f"#good kernels {len(self.good_kernels)} loaded from cache ...") + + def is_bad_kernel(self, kernel_file_name: str): + if self.test_compile or self.write_only: + return False + return kernel_file_name in self.bad_kernels + + def is_good_kernel(self, kernel_file_name: str): + if self.test_compile or self.write_only: + return False + return kernel_file_name in self.good_kernels + + def add(self, success: List[str], fail: List[str]): + self.good_kernels.extend(success) + self.bad_kernels.extend(fail) + # Remove duplicates + self.good_kernels = list(set(self.good_kernels)) + # Remove good kernels from bad kernels + self.bad_kernels = list(set(self.bad_kernels).difference(set(self.good_kernels))) + + def store(self): + if self.test_compile: + return + print(f"Storing #badkernels {len(self.bad_kernels)}, " + f"#goodkernels {len(self.good_kernels)}") + with open(str(CACHE_FILE_PATH), 'wb+') as f: + pkl.dump((self.bad_kernels, self.good_kernels), f) + + +@dataclass +class KernelCompiler: + # vllm source code directory path + vllm_root_dir: Optional[str] = None + # python venv directory path + py_venv_dir: Optional[str] = None + # cuda directory path. example : /usr/local/cuda-12.5 + cuda_dir: Optional[str] = None + #python version + py_version: str = '3.10' + # any additional flags + additional_args: List[str] = field(default_factory=lambda: []) + # kernel compile cache. Cache that holds history of which kernels + # succeeded and failed compilation. + cache: Optional[KernelCompileCache] = None + # Print nvcc compile information and override cache updates. 
+ test_compile: bool = False + cache_write_only: bool = False + + def init_compile_cache(self): + self.cache = KernelCompileCache(self.test_compile, self.cache_write_only) + + def compile(self, cu_file: str, gencode_arch: str) -> bool: + compile_command_base = [ + 'nvcc', + '-DCUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1', + f'-I{self.vllm_root_dir}/csrc', + f'-I{self.vllm_root_dir}/.deps/cutlass-src/include', #noqa + '-isystem', + f'/usr/include/python{self.py_version}', + '-isystem', + f'{self.py_venv_dir}/lib/python3.10/site-packages/torch/include', + '-isystem', + f'{self.py_venv_dir}/lib/python3.10/site-packages/torch/include/torch/csrc/api/include', #noqa + '-isystem', + f'{self.cuda_dir}/include', + '-gencode', + f'arch=compute_{gencode_arch},code=sm_{gencode_arch}', + '-DONNX_NAMESPACE=onnx_c2', + '-Xcudafe', + '-DNDEBUG', + '-std=c++17', + '-Xcompiler=-fPIC', + '--expt-relaxed-constexpr', + '--threads=1', + '-D_GLIBCXX_USE_CXX11_ABI=0'] + self.additional_args + if gencode_arch == 90: + compile_command_base += ['-gencode', 'arch=compute_90a,code=sm_90a'] + + result = subprocess.run(compile_command_base + ['-c', cu_file], + capture_output=True) + + if self.test_compile: + print(f"Compiling {cu_file} : \n" + f" Successful compilation: {result.returncode == 0}\n" + f" stdout : {result.stdout}\n" + f" stderr : {result.stderr}\n") + + if result.returncode == 0: + # Cleanup generated object code on successful compile. + object_file_path = Path("./" + Path(cu_file).stem + '.o') + assert object_file_path.exists(), object_file_path + object_file_path.unlink() + + return result.returncode == 0 diff --git a/csrc/quantization/cutlass_w8a8/generator/kernel_generator.py b/csrc/quantization/cutlass_w8a8/generator/kernel_generator.py new file mode 100644 index 0000000000000..ad03b0075b1eb --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/kernel_generator.py @@ -0,0 +1,251 @@ +""" +Kernel Generator classes / functions. 
+""" + +import shutil +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Tuple + +import jinja2 +import utils +from autogen_manifest import Cutlass3xArgs +from generator_types import GeneratorType +from kernel_compiler import KernelCompiler + + +@dataclass +class GeneratorOutput: + # Used in torch_bindings generation + file_paths: List[str] = field(default_factory=lambda: []) + fn_names: List[str] = field(default_factory=lambda: []) + fn_decls: List[str] = field(default_factory=lambda: []) + # Used in cache update + failed_file_names: List[str] = field(default_factory=lambda: []) + success_file_names: List[str] = field(default_factory=lambda: []) + + def merge(self, output: "GeneratorOutput"): + self.file_paths.extend(output.file_paths) + self.fn_names.extend(output.fn_names) + self.fn_decls.extend(output.fn_decls) + self.failed_file_names.extend(output.failed_file_names) + self.success_file_names.extend(output.success_file_names) + + +## Abstract generator + + +class KernelGenerator_(ABC): + SCRIPT_DIR = utils.get_script_dir() + GENERATE_DIR = SCRIPT_DIR / "generated" + + @staticmethod + def write_torch_bindings(generator_type: GeneratorType, + fn_names: List[str], fn_decls: List[str], + ops_macro: str, dir_path: str): + s = "#pragma once\n" + s += "#include\n" + s += f"#define {ops_macro} \\\n" + for fn_name in fn_names: + s += generator_type.ops_def(fn_name) + '\\\n' + s += generator_type.ops_impl(fn_name) + '\\\n' + s += "\n" + + for fn_decl in fn_decls: + s += f'{fn_decl}\n' + + # write ops.h + file_path = Path(dir_path) / "ops.h" + with open(str(file_path), 'w+') as f: + f.write(s) + + # write torch_bindings.cpp + s = "" + s += '\n#include "core/registration.h"' + s += '\n#include ' + s += '\n#include "ops.h"' + s += '\nTORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {' + s += f'\n {ops_macro}' + s += '\n}' + s += '\nREGISTER_EXTENSION(TORCH_EXTENSION_NAME)' + s += '\n' + + tb_path = Path(dir_path) / "torch_bindings.cpp" + with open(str(tb_path), 'w+') as f: + f.write(s) + + @staticmethod + def write_ops(generator_type: GeneratorType, + args: Cutlass3xArgs, + file_paths: List[str], + fn_names: List[str], + fn_decls: List[str], + ops_macro: str, + batch_size: int = 100): + """ + batch_size defines the number of files per .so. + If there are a 1000 filenames, then with batch_size 100, we generate + 10 directories, each directory containing 100 kernels. Each directory + is converted into a .so during vllm compile. + """ + + assert len(file_paths) == len(fn_names) + assert len(file_paths) == len(fn_decls) + + dir_name = 0 + for i in range(0, len(file_paths), batch_size): + + dir_path: Path = KernelGenerator_.GENERATE_DIR / f'{args.dtype_str}_{dir_name}' + dir_path.mkdir(exist_ok=True) + + # Move files to dir + for file_path in file_paths[i:i + batch_size]: + if Path(file_path).exists(): + try: + shutil.move(file_path, str(dir_path)) + except shutil.Error: + # File already exists + pass + + KernelGenerator_.write_torch_bindings(generator_type, + fn_names[i:i + batch_size], + fn_decls[i:i + batch_size], + ops_macro, dir_path) + + dir_name += 1 #noqa + + @staticmethod + def last_namespace(s): + return s.split('::')[-1] + + @staticmethod + @abstractmethod + def generate(generator_type: GeneratorType, args: Cutlass3xArgs, + kernel_compiler: KernelCompiler) -> GeneratorOutput: + ... 
+ + +class KernelGenerator(KernelGenerator_): + OPS_MACRO = "CUTLASS_DEFS" + + @staticmethod + def generate_name(description: str, args: Cutlass3xArgs): + + return 'autogen_{}_{}_{}x{}x{}_{}x{}x{}_{}_{}_{}_{}_{}_{}'.format( + description, args.arch, args.tile_shape[0], args.tile_shape[1], + args.tile_shape[2], args.cluster_shape[0], args.cluster_shape[1], + args.cluster_shape[2], + KernelGenerator_.last_namespace(args.kernel_schedule), + KernelGenerator_.last_namespace(args.epilogue_schedule), + KernelGenerator_.last_namespace(args.tile_schedule), + KernelGenerator_.last_namespace(args.gemm_mode), + KernelGenerator_.last_namespace(args.acc_type), args.dtype_str) + + @staticmethod + def generate_filename(description: str, args: Cutlass3xArgs): + + f = '{}/autogen_{}_{}x{}x{}_{}x{}x{}_{}_{}_{}_{}_{}_{}_{}'.format( + KernelGenerator_.GENERATE_DIR, description, args.tile_shape[0], + args.tile_shape[1], args.tile_shape[2], args.cluster_shape[0], + args.cluster_shape[1], args.cluster_shape[2], + KernelGenerator_.last_namespace(args.kernel_schedule), + KernelGenerator_.last_namespace(args.epilogue_schedule), + KernelGenerator_.last_namespace(args.tile_schedule), + KernelGenerator_.last_namespace(args.gemm_mode), + KernelGenerator_.last_namespace(args.acc_type), args.dtype_str, + args.arch) + + f = f + ".cu" + return f + + @staticmethod + def generate_kernel_file(generator_type: GeneratorType, + args: Cutlass3xArgs) -> Tuple[str, str]: + """ + Generate a .cu file that respects args and return, + - The function name of the generated function. + - The c++ function declaration of the generated function. + The return values are used in generating the torch bindings. + """ + + # Make the generate dir + KernelGenerator_.GENERATE_DIR.mkdir(exist_ok=True) + + # Get jinja templates + jenv = jinja2.Environment(loader=jinja2.FileSystemLoader("/")) + fn_defn_template = jenv.get_template( + str(generator_type.fn_defn_jinja_filepath())) + fn_decl_template = jenv.get_template( + str(generator_type.fn_decl_jinja_filepath())) + + # Generate code + fn_name = KernelGenerator.generate_name(generator_type.description(), + args) + fn_decl = fn_decl_template.render(_name=fn_name) + code: str = fn_defn_template.render( + _name=fn_name, + _torch_input_dtype=utils.to_torch_dtype_str(args.dtype_str), + _cutlass_input_dtype=utils.to_cutlass_dtype_str(args.dtype_str), + _tile_shape=utils.get_as_cutlass3x_gemm_shape(args.tile_shape), + _cluster_shape=utils.get_as_cutlass3x_gemm_shape( + args.cluster_shape), + _kernel_schedule=args.kernel_schedule, + _epilogue_schedule=args.epilogue_schedule, + _tile_schedule=args.tile_schedule, + _gemm_mode=args.gemm_mode, + _acc_type=args.acc_type) + + filename = KernelGenerator.generate_filename( + generator_type.description(), args) + if utils.file_contents_same(filename, code): + return (fn_name, fn_decl) + + # write code + with open(filename, "w+") as f: + f.write(code) + + return (fn_name, fn_decl) + + @staticmethod + def generate(generator_type: GeneratorType, args: Cutlass3xArgs, + kernel_compiler: KernelCompiler) -> GeneratorOutput: + generator_output = GeneratorOutput() + + filepath = KernelGenerator.generate_filename( + generator_type.description(), args) + filename = Path(filepath).name + + if kernel_compiler.cache.is_bad_kernel(filename): + # We know that this kernel wouldn't compile. 
Abort + return generator_output + + fn_name, fn_decl = KernelGenerator.generate_kernel_file( + generator_type, args) + + if not kernel_compiler.cache.is_good_kernel(filename): + # We dont have any information about this kernel in the cache. + # try compiling + compile_success = kernel_compiler.compile(filepath, + gencode_arch=args.arch) + if compile_success: + generator_output.success_file_names.append(filename) + else: + generator_output.failed_file_names.append(filename) + if not kernel_compiler.test_compile: + # Remove generated file + Path(filepath).unlink() + return generator_output + + generator_output.file_paths.append(filepath) + generator_output.fn_names.append(fn_name) + generator_output.fn_decls.append(fn_decl) + + return generator_output + + @staticmethod + def write_ops(generator_type: GeneratorType, args: Cutlass3xArgs, + file_paths: List[str], fn_names: List[str], + fn_decls: List[str]): + return KernelGenerator_.write_ops(generator_type, args, file_paths, fn_names, + fn_decls, KernelGenerator.OPS_MACRO) diff --git a/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x.jinja b/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x.jinja new file mode 100644 index 0000000000000..5b203ff3bb812 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x.jinja @@ -0,0 +1,56 @@ +#include +#include +#include "cutlass/cutlass.h" +#include "quantization/cutlass_w8a8/scaled_mm_c3x.cuh" + +void {{ _name }}(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales) { + + using TileShape = {{ _tile_shape }}; + using ClusterShape = {{ _cluster_shape }}; + using KernelSchedule = typename {{ _kernel_schedule }}; + using EpilogueSchedule = typename {{ _epilogue_schedule }}; + using TileSchedule = typename {{ _tile_schedule }}; + using AccType = {{ _acc_type }}; + static constexpr cutlass::gemm::GemmUniversalMode Mode = {{ _gemm_mode }}; + + TORCH_CHECK(a.dtype() == {{ _torch_input_dtype }}); + TORCH_CHECK(b.dtype() == {{ _torch_input_dtype}}); + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (out.dtype() == torch::kBFloat16) { + using Cutlass3xGemm = + cutlass_3x_gemm<{{ _cutlass_input_dtype }}, + cutlass::bfloat16_t, + ScaledEpilogue, + TileShape, + ClusterShape, + KernelSchedule, + EpilogueSchedule, + AccType, + TileSchedule, + Mode>; + + return cutlass_gemm_caller( + out, a, b, a_scales, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + using Cutlass3xGemm = + cutlass_3x_gemm<{{ _cutlass_input_dtype }}, + cutlass::half_t, + ScaledEpilogue, + TileShape, + ClusterShape, + KernelSchedule, + EpilogueSchedule, + AccType, + TileSchedule, + Mode>; + + return cutlass_gemm_caller( + out, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_fnprototype.jinja b/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_fnprototype.jinja new file mode 100644 index 0000000000000..c671bfc155c09 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_fnprototype.jinja @@ -0,0 +1,6 @@ + + +void {{ _name }}(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); diff --git a/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_struct_prototype.jinja b/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_struct_prototype.jinja new file mode 100644 index 0000000000000..c671bfc155c09 --- 
/dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/scaled_mm_c3x_struct_prototype.jinja @@ -0,0 +1,6 @@ + + +void {{ _name }}(torch::Tensor &out, torch::Tensor const &a, + torch::Tensor const &b, + torch::Tensor const &a_scales, + torch::Tensor const &b_scales); diff --git a/csrc/quantization/cutlass_w8a8/generator/tools/heatmap.py b/csrc/quantization/cutlass_w8a8/generator/tools/heatmap.py new file mode 100644 index 0000000000000..2abe98934cb33 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/tools/heatmap.py @@ -0,0 +1,242 @@ +import pickle as pkl +from pathlib import Path +from typing import List, Optional + +import matplotlib.pyplot as plt +import numpy as np +from select_kernels import select_kernels +from utils import Data, make_heatmap_data, measurement_to_data + + +def plot_heatmap(data: np.array, + y_labels: List[str], + x_labels: List[str], + save_filename='heatmap.png'): + # min because of some matplotlib render restrictions. + fig_size_x = min(len(x_labels), 320) + fig_size_y = len(y_labels) + fig, ax = plt.subplots(figsize=(fig_size_x, fig_size_y)) + im = ax.imshow(data, cmap="Reds", vmin=0.0, vmax=1.0, interpolation=None) + + cbar = ax.figure.colorbar(im, ax=ax, cmap="Reds") + cbar.ax.set_ylabel("Hot == Closer to peak perf.", rotation=90, va="top") + + # Show all ticks and label them with the respective list entries + ax.set_xticks(np.arange(len(x_labels)), labels=x_labels) + ax.set_yticks(np.arange(len(y_labels)), labels=y_labels) + + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=90) + + # Loop over data dimensions and create text annotations. + for i in range(len(y_labels)): + for j in range(len(x_labels)): + ax.text(j, + i, + data[i, j], + ha="center", + va="center", + color="w", + fontsize=6.0) + + #ax.colorbar() + + ax.set_title("GEMM shape vs Best cutlass op") + #ax.set_aspect('equal') + fig.tight_layout() + + #fig.set_dpi(300) + #plt.show() + print(f"Save location : {save_filename}") + fig.savefig(save_filename, dpi=100) + #fig.savefig(save_filename, dpi=10) + + +def select_top_k_kernels(gemm_ops: np.array, + gemm_problems: List[str], + ops: List[str], + k: int = 100) -> List[str]: + """ + Simple top_k kernel selection. + Gather the top-k best performing kernels for each gemm problem and + return the union. + """ + n_rows = len(gemm_problems) + + max_kernels_per_gemm_shape = 100 # k-value + gemm_efficiency_threshold = 0.90 + + selected_ops = [] + for r in range(n_rows): + gemm_ops_list = np.copy(gemm_ops[r]) + sorted_indices = list(reversed(np.argsort(gemm_ops_list).tolist())) + + selected_shape_ops = [] + for x in sorted_indices: + if 'autogen' not in ops[x]: + # select only autogen kernels/ops + continue + if len(selected_shape_ops) >= max_kernels_per_gemm_shape: + break + # we have reached the min requirement. Decide to break based on + # the gemm_efficiency threshold. 
+ if gemm_ops_list[x] < gemm_efficiency_threshold: + break + else: + selected_shape_ops.append(ops[x]) + + selected_ops.append(selected_shape_ops) + + op_scores = [] + for idx in range(len(selected_shape_ops)): + if 'autogen' not in ops[sorted_indices[idx]]: + continue + op_scores.append(gemm_ops_list[sorted_indices[idx]]) + print(f"Gemm problem {gemm_problems[r]} " + f"- #kernels {len(selected_shape_ops)} " + f"- selected kernel range [ {min(op_scores)} , " + f"{max(op_scores)} ] ") + + # Merge all ops to create a final list + selected_ops = [set(x) for x in selected_ops] + selected_ops_set = set() + for x in selected_ops: + selected_ops_set = selected_ops_set.union(x) + + print(f"#Selected ops set {len(selected_ops_set)}") + for x in selected_ops_set: + print(x) + return list(selected_ops_set) + + +def remove_less_performant_kernels(gemm_ops: np.array, ops: List[str]): + """ + Removes kernel that are relatively less performant from gemm_ops. + """ + n_ops = gemm_ops.shape[1] + assert n_ops == len(ops) + + gemm_ops_predicated = gemm_ops < 0.75 + ops_predicated = np.all(gemm_ops_predicated, axis=0) + + bad_cols = list(range(n_ops)) + bad_cols = list(filter(lambda x: ops_predicated[x], bad_cols)) + bad_cols = sorted(list(set(bad_cols)), reverse=True) + for bc in bad_cols: + ops.pop(bc) + gemm_ops = np.delete(gemm_ops, bc, 1) + + return gemm_ops, ops + + +def plot(gemm_ops: np.array, + gemm_problems: List[str], + ops: List[str], + save_filename: str, + prune_ops: bool = False): + if prune_ops: + gemm_ops, ops = remove_less_performant_kernels(gemm_ops, ops) + print(f"Pruned gemm_ops {gemm_ops.shape}") + + plot_heatmap(gemm_ops, gemm_problems, ops, save_filename) + + +def select_kernels_and_plot(gemm_problems: List[str], ops: List[str], + data: List[str], save_filename: str): + + autogen_ops = list(filter(lambda x: x.startswith('autogen'), ops)) + cutlass_ops = list(filter(lambda x: x.startswith('cutlass'), ops)) + pytorch_ops = list(filter(lambda x: x.startswith('pytorch'), ops)) + assert len(autogen_ops) + len(cutlass_ops) + len(pytorch_ops) == len(ops) + + print("Selecting the autogen kernels ..") + # select the best autogen kernels + gemm_autogenops = make_heatmap_data(gemm_problems, autogen_ops, data) + selected_autogen_ops = select_kernels(gemm_autogenops, + gemm_problems, + autogen_ops, + min_gemm_efficiency=0.95) + + # prepare plot data + selected_ops = selected_autogen_ops + cutlass_ops + pytorch_ops + gemm_ops = make_heatmap_data(gemm_problems, selected_ops, data) + print("Plotting autogen kernels ...") + plot(gemm_ops, gemm_problems, selected_ops, save_filename) + + +def from_measurements(args): + pkl_files: List[str] = args.input_pkl + save_file: Optional[str] = args.save_file + data: List[Data] = [] + + for pkl_file in pkl_files: + with open(pkl_file, 'rb') as f: + pkl_data = pkl.load(f) + data.extend(list(map(lambda x: measurement_to_data(x), pkl_data))) + + ops: List[str] = list(map(lambda x: x.description, data)) + ops = sorted(list(set(ops))) + + gemm_problems: List[str] = list(map(lambda x: (x.m, x.n, x.k), data)) + gemm_problems = sorted(list(set(gemm_problems))) + + print(f"#gemm_problems {len(gemm_problems)}") + print(f"#gemm_ops {len(ops)}") + + # plot all data as heat map + if args.plot_all_ops: + gemm_ops: np.array = make_heatmap_data(gemm_problems, ops, data) + out_file: str = pkl_file.replace( + '.pkl', '_heatmap.png') if save_file is None else save_file + plot(gemm_ops, gemm_problems, ops, save_filename=out_file) + + if args.select_kernels: + out_file = None + if 
save_file: + out_file = Path(save_file).with_suffix("_selected.png") + else: + out_file = pkl_file.replace('.pkl', 'selected_heatmap.png') + select_kernels_and_plot(gemm_problems, ops, data, out_file) + + +def main(args): + from_measurements(args) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description=''' + Plot bench measurements pkl. + Example invocation: + Plot all the ops in model bench pickle file: + python3 heatmap.py \ + --input-pkl ./model_bench-torch.float8_e4m3fn-1730295961.pkl \ + --plot-all-ops + Run select kernel on the input-pkl and plot the selected ops. + python3 heatmap.py \ + --input-pkl ./model_bench-torch.float8_e4m3fn-1730295961.pkl \ + --select-kernels + ''') + + parser.add_argument("--input-pkl", + "-i", + nargs="+", + required=True, + type=str, + help=("This is typically the pickle file output by " + "mm_benchmarks.py 's model_bench command")) + parser.add_argument("--save-file", "-o", required=False, type=str) + parser.add_argument("--select-kernels", + action='store_true', + help="Run kernel selection and plot the heatmap " + "for the selected kernels") + parser.add_argument("--plot-all-ops", + action='store_true', + help="plot heatmap for all ops") + args = parser.parse_args() + + if not args.plot_all_ops and not args.select_kernels: + print("Argument error : Please provide at least one argument among" + "[--plot-all-ops, --select-kernels]") + + main(args) diff --git a/csrc/quantization/cutlass_w8a8/generator/tools/select_kernels.py b/csrc/quantization/cutlass_w8a8/generator/tools/select_kernels.py new file mode 100644 index 0000000000000..55fb50342dbf6 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/tools/select_kernels.py @@ -0,0 +1,324 @@ +import pickle as pkl +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import numpy as np +from utils import Data, make_heatmap_data, measurement_to_data, to_cutlass_dtype_str + + +@dataclass +class Interval: + s: int # start of interval + e: int # end of interval + eff: float # efficiency of the kernel in that range. + + def x_in_interval(self, x: int) -> bool: + return self.s <= x and x <= self.e + + def is_overlap(self, s, e): + return s <= self.e and self.s <= e + + +@dataclass +class KernelIntervals: + name: str + intervals: List[Interval] + + def spanning_interval(self, pi: int) -> Optional[Interval]: + for i in self.intervals: + if i.x_in_interval(pi): + return i + return None + + +class SelectKernelMeta: + + def __init__(self, gemm_ops: np.array, gemm_problems: List[str], + ops: List[str], min_gemm_efficiency: float): + self.gemm_ops = np.copy(gemm_ops) + self.gemm_problems = gemm_problems + self.ops = ops + self.min_gemm_efficiency = min_gemm_efficiency + + self.n_problems = len(self.gemm_problems) + self.n_kernels = len(self.ops) + + # Convert to kernel ranges + self.problem_indices = {x: idx for idx, x in enumerate(gemm_problems)} + self.kernel_indices = {x: idx for idx, x in enumerate(ops)} + + self.kernel_intervals: List[KernelIntervals] = [] + for ki in range(self.n_kernels): + self.kernel_intervals.append(self.make_kernel_intervals(ki)) + + def avg_efficiency(self, p_s: int, p_e: int, ki: int) -> float: + """ + Average efficiency of the ki kernel for the gemm shapes in + range [p_s, p_e] + """ + vals = self.gemm_ops[:, ki].tolist()[p_s:p_e + 1] + return sum(vals) / len(vals) + + # TODO (varun) : Revisit kernel scores to use only the intervals we actually + # use for specific kernels. 
+ def kernel_set_score(self, p_s: int, p_e: int, kernel_indices: set[int]): + """ + Compute a score for a set of kernels for the gemm shape indices in + range [p_s, p_e] + """ + if len(kernel_indices) == 0: + return 0.0 + ki_scores = [] + for ki in kernel_indices: + interval_scores = [] + for i in self.kernel_intervals[ki].intervals: + if i.is_overlap(p_s, p_e): + interval_scores.append(i.eff) + assert len(interval_scores) > 0 + ki_scores.append(sum(interval_scores) / len(interval_scores)) + assert len(ki_scores) > 0 + return sum(ki_scores) / len(ki_scores) + + def make_kernel_intervals(self, ki: int) -> KernelIntervals: + s = None + e = None + kernel_intervals: KernelIntervals = KernelIntervals(self.ops[ki], []) + for pi in range(self.n_problems): + if self.gemm_ops[pi][ki] < self.min_gemm_efficiency: + # record range + if e: + assert s is not None + kernel_intervals.intervals.append( + Interval(s, e, eff=self.avg_efficiency(s, e, ki))) + s, e = None, None + else: + s = pi if s is None else s + e = pi + if e: + assert s is not None + kernel_intervals.intervals.append( + Interval(s, e, eff=self.avg_efficiency(s, e, ki))) + # sort intervals in the kernel + kernel_intervals.intervals = sorted(kernel_intervals.intervals, + key=lambda x: x.s) + return kernel_intervals + + +def map_gemm_to_kernel(kernel_indices: List[int], + meta: SelectKernelMeta) -> Dict[int, int]: + """ + For every gemm problem in meta.gemm_problems, select a kernel from + kernel_indices and return as a dict. + """ + gemm_to_kernel_map = {} + + for pi in range(meta.n_problems): + kernels_for_pi = [] + for ki in kernel_indices: + if meta.kernel_intervals[ki].spanning_interval(pi): + kernels_for_pi.append(ki) + assert len(kernels_for_pi) != 0 + + # select the kernel with max efficiency + eff_ki = [(meta.gemm_ops[pi][ki], ki) for ki in kernels_for_pi] + max_eff_ki = max(eff_ki, key=lambda x: x[0])[1] + gemm_to_kernel_map[pi] = max_eff_ki + + return gemm_to_kernel_map + + +def select_kernels_dp( + p_s: int, + p_e: int, # Problem start index and problem end index + meta: SelectKernelMeta, + solution_cache: Dict[Tuple[int, int], set]) -> set[int]: + """ + Compute the best set of kernels for the gemm problem shapes, + meta.gemm_problems[p_s:p_e]. + """ + if p_s > p_e: + return set([]) + assert p_s <= p_e + assert p_s >= 0 and p_e >= 0 + assert p_s < meta.n_problems and p_e < meta.n_problems + + if solution_cache.get((p_s, p_e), None) is not None: + return solution_cache.get((p_s, p_e)) + + spanning_kernels: List[Tuple[int, Interval]] = [] + for ki in range(meta.n_kernels): + span_i = meta.kernel_intervals[ki].spanning_interval(p_s) + assert span_i is None or (span_i.s <= p_s and span_i.e >= p_s) + if span_i is not None: + spanning_kernels.append((ki, span_i)) + + assert len(spanning_kernels) != 0, \ + (f"Cannot find a spanning kernel in range ({p_s}, {p_e})" + f"- gemm {meta.gemm_problems[p_s]} to {meta.gemm_problems[p_e]}" + f". Try reducing the min_gemm_efficiency") + ki_solutions: List[set[int]] = [] + for ki, span in spanning_kernels: + ki_solutions.append( + set([ki]).union( + select_kernels_dp(span.e + 1, p_e, meta, solution_cache))) + + # find the solution with minimum number of kernels. 
+ sol = min(ki_solutions, key=lambda x: len(x)) + solution_cache[(p_s, p_e)] = sol + return sol + + +def generate_struct_from_kernel_name(kernel_name: str, kernels_dict: dict) -> str: + # Sample kernel name: + # autogen_scaled_mm_90_64x64x256_8x1x1_KernelTmaWarpSpecializedPingpong_\ + # TmaWarpSpecializedCooperative_PersistentScheduler_kGemm_int32_t_int8 + import re + pattern = r'(\d+)x(\d+)x(\d+)_(\d+)x(\d+)x(\d+)_([A-Za-z0-9]+)_([A-Za-z0-9]+)_([A-Za-z0-9]+)_([A-Za-z0-9]+)_([a-z][a-z0-9_]+)_([a-z][a-z0-9_]+)' + match = re.search(pattern, kernel_name) + if not match: + raise ValueError(f"Cannot parse kernel name {kernel_name}") + + tile_dims = match.group(1, 2, 3) + cluster_dims = match.group(4, 5, 6) + kernel_schedule = match.group(7) + epilogue_schedule = match.group(8) + tile_schedule = match.group(9) + mode = match.group(10) + acc_type = match.group(11) + input_type = match.group(12) + + # Check if the kernel is already in the kernels_dict + if (tile_dims, cluster_dims, kernel_schedule, epilogue_schedule, tile_schedule, mode, acc_type, input_type) in kernels_dict: + kernel_idx = kernels_dict[(tile_dims, cluster_dims, kernel_schedule, epilogue_schedule, tile_schedule, mode, acc_type, input_type)] + return f'sm90_{input_type}_config_{kernel_idx}', "" + + # Create new kernel + kernel_idx = len(kernels_dict) + kernels_dict[(tile_dims, cluster_dims, kernel_schedule, epilogue_schedule, tile_schedule, mode, acc_type, input_type)] = kernel_idx + + # Create struct template + struct_template = f""" +template typename Epilogue> +struct sm90_{input_type}_config_{kernel_idx} {{ + static_assert(std::is_same()); + using TileShape = Shape<_{tile_dims[0]}, _{tile_dims[1]}, _{tile_dims[2]}>; + using ClusterShape = Shape<_{cluster_dims[0]}, _{cluster_dims[1]}, _{cluster_dims[2]}>; + using KernelSchedule = typename cutlass::gemm::{kernel_schedule}; + using EpilogueSchedule = typename cutlass::epilogue::{epilogue_schedule}; + using TileSchedule = typename cutlass::gemm::{tile_schedule}; + using AccType = {acc_type}; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::{mode}; + using Cutlass3xGemm = + cutlass_3x_gemm; +}}; +""" + return f'sm90_{input_type}_config_{kernel_idx}', struct_template + + +def print_conditionals(meta: SelectKernelMeta, gemm_to_kernel_map: Dict[int, int]): + # Sort the rows based on the first dimension of the gemm shape, then second, then third. 
+ meta.gemm_problems = sorted(meta.gemm_problems, key=lambda x: (x[0], x[1], x[2])) + prev_m = -1 + configs_str = "" + conds_str = "" + kernels_dict = {} + for pi in range(meta.n_problems): + config_name, config_body = generate_struct_from_kernel_name(meta.ops[gemm_to_kernel_map[pi]], kernels_dict) + configs_str += config_body + if prev_m != meta.gemm_problems[pi][0]: + if prev_m == -1: + conds_str += f"if (m == {meta.gemm_problems[pi][0]}) {{\n" + elif meta.gemm_problems[pi][0] == meta.gemm_problems[-1][0]: + conds_str += f"}} else {{ // m{meta.gemm_problems[pi][0]} kernels\n" + else: + conds_str += f"}} else if (m <= {meta.gemm_problems[pi][0]}) {{\n" + prev_m = meta.gemm_problems[pi][0] + conds_str += f" if (n == {meta.gemm_problems[pi][1]} && k == {meta.gemm_problems[pi][2]})\n" + conds_str += f" return cutlass_gemm_caller::Cutlass3xGemm >(\n" + conds_str += f" out, a, b, std::forward(args)...);\n" + # conds_str += f" }}\n" + conds_str += f"}}\n" + + print(configs_str) + print(conds_str) + + +def select_kernels(gemm_ops: np.array, gemm_problems: List[str], + ops: List[str], min_gemm_efficiency: float) -> List[str]: + """ + Given a list of gemm problem shapes, gemm_problems, a list of autogen + kernel operations ops, normalized benchmarking information and a + minimum operation efficiency to consider, this function, finds that + smallest set of kernels such that kernels in the satisfies the + min_gemm_efficiency for all the gemm shapes. + """ + solution_cache = {} + meta = SelectKernelMeta(gemm_ops, gemm_problems, ops, min_gemm_efficiency) + kernels = select_kernels_dp(0, meta.n_problems - 1, meta, solution_cache) + + gemm_to_kernel_map = map_gemm_to_kernel(list(kernels), meta) + + print(f"#kernels found {len(kernels)}") + for pi in range(meta.n_problems): + print(f"Problem {meta.gemm_problems[pi]} - " + f"Kernel {meta.ops[gemm_to_kernel_map[pi]]} " + f"eff. 
({gemm_ops[pi][gemm_to_kernel_map[pi]]}) ") + + print_conditionals(meta, gemm_to_kernel_map) + + kernel_names = [ops[ki] for ki in kernels] + return kernel_names + + +def from_measurements(pkl_files: List[str], min_gemm_efficiency: float): + data: List[Data] = [] + + for pkl_file in pkl_files: + with open(pkl_file, 'rb') as f: + pkl_data = pkl.load(f) + data.extend(list(map(lambda x: measurement_to_data(x), pkl_data))) + + ops = list(map(lambda x: x.description, data)) + ops = sorted(list(set(ops))) + # have only autogen kernels + ops = list(filter(lambda x: 'autogen' in x, ops)) + + gemm_problems = list(map(lambda x: (x.m, x.n, x.k), data)) + gemm_problems = sorted(list(set(gemm_problems))) + + print(f"#gemm_problems {len(gemm_problems)}") + print(f"#gemm_ops {len(ops)}") + + gemm_ops: np.array = make_heatmap_data(gemm_problems, ops, data) + select_kernels(gemm_ops, gemm_problems, ops, min_gemm_efficiency) + + +def main(pkl_files: List[str], min_gemm_efficiency: float): + from_measurements(pkl_files, min_gemm_efficiency) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description=("Select minimal set of kernels in some model_bench " + "pkl file such that the set of kernels satisfy" + "the min-gemm-efficiency for all the gemm shapes in" + "the model_bench")) + parser.add_argument("--input-pkl", + "-i", + nargs="+", + required=True, + type=str) + parser.add_argument( + "--min-gemm-efficiency", + type=float, + default=0.95, + help="Gemms that are less than this for a particular gemm shape is" + "disregarded") + args = parser.parse_args() + + main(args.input_pkl, args.min_gemm_efficiency) diff --git a/csrc/quantization/cutlass_w8a8/generator/tools/test_kernel.py b/csrc/quantization/cutlass_w8a8/generator/tools/test_kernel.py new file mode 100644 index 0000000000000..538df4659b08a --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/tools/test_kernel.py @@ -0,0 +1,119 @@ +import dataclasses +from functools import cache +from itertools import product +from typing import Callable, List, Type + +import torch +import tqdm +from test_utils import autogen_scaled_mm_fp8_gemm_test + + +@dataclasses.dataclass +class TestArgs: + m: int + n: int + k: int + out_dtype: Type[torch.dtype] = dataclasses.field(default=torch.bfloat16) + device: str = dataclasses.field(default='cuda') + + @staticmethod + @cache + def default_test_args() -> List["TestArgs"]: + Ms = [1, 16, 32, 64, 128, 256, 512, 222, 100, 33] + Ns = [2048, 4096, 8192, 16384, 24576, 256, 1024] + Ks = [128, 496, 1024] + out_dtypes = [torch.bfloat16] + + args = [] + for m, n, k, out_dtype in product(Ms, Ns, Ks, out_dtypes): + args.append(TestArgs(m, n, k, out_dtype)) + return args + + +@cache +def get_autogen_functions(): + import importlib + from importlib.util import find_spec + + # import vllm nm_cutlass modules so torch._C can find it + m_idx = 0 + m_name = f'vllm._nm_cutlass_{m_idx}_C' + while find_spec(m_name): + print(f"attempting import {m_name}") + importlib.import_module(m_name) + m_idx += 1 + m_name = f'vllm._nm_cutlass_{m_idx}_C' + + dispatch_names = torch._C._dispatch_get_all_op_names() + autogen_dispatch_names = [x for x in dispatch_names if 'autogen' in x] + assert all([x.startswith('_nm_cutlass') for x in autogen_dispatch_names]) + autogen_dispatch_modules_names = [(getattr(torch.ops, + x.split('::')[0]), + x.split('::')[1]) + for x in autogen_dispatch_names] + name_fn = [(name, getattr(m, name)) + for m, name in autogen_dispatch_modules_names] + print(f"#autogen functions found 
{len(name_fn)}") + return name_fn + + +@cache +def test_kernel_function(name: str, + fn: Callable, + verbose: bool = False) -> bool: + test_args: List[TestArgs] = TestArgs.default_test_args() + for x in test_args: + success = autogen_scaled_mm_fp8_gemm_test( + fn, + m=x.m, + n=x.n, + k=x.k, + per_token_act_quant=False, + per_out_channel_weight_quant=False, + out_dtype=x.out_dtype, + device=x.device) + if not success: + # Early exit + if verbose: + print(f"Test Fail : {name} failed for MNK : {x.m} {x.n} {x.k}") + return False + return True + + +@cache +def test_kernel(kernel_name: str) -> bool: + name_fn = get_autogen_functions() + name_fn = list(filter(lambda x: x[0] == kernel_name, name_fn)) + assert len(name_fn) == 1 + fn = name_fn[0][1] + return test_kernel_function(kernel_name, fn) + + +def main(args): + name_fn = get_autogen_functions() + print(f"#{len(name_fn)} autogen functions found.") + if args.pattern: + name_fn = list(filter(lambda x: args.pattern in x[0], name_fn)) + print(f"${len(name_fn)} autogen functions match the pattern.") + + good_functions = [] + # Test each kernel one after another for correctness + for name, fn in tqdm.tqdm(name_fn): + test_kernel_function(name, fn, verbose=True) + good_functions.append((name, fn)) + + print(f"#{len(good_functions)} good functions found.") + print(f"good functions \n{good_functions}") + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser(description=''' + Test autogen cutlass kernels + ''') + parser.add_argument( + '--pattern', + default=None, + help='Checks for this pattern in the autogen kernel name') + args = parser.parse_args() + main(args) diff --git a/csrc/quantization/cutlass_w8a8/generator/tools/test_utils.py b/csrc/quantization/cutlass_w8a8/generator/tools/test_utils.py new file mode 100644 index 0000000000000..32327ee8547ad --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/tools/test_utils.py @@ -0,0 +1,114 @@ +from typing import Callable, Optional, Type, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor): + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def rand_int8(shape: tuple, device: str = "cuda"): + return to_int8(torch.rand(shape, device=device) * 255 - 128) + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + # # Initialize a to all ones + # a = torch.ones((m, k), device='cuda') + # # Initialize b to all ones + # b = torch.ones((n, k), device='cuda') + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = 
to_fp8(a), to_fp8(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_compress_entry(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def baseline_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: Type[torch.dtype], + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + output = (scale_a * (scale_b * (torch.mm( + a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype) + if bias is not None: + output = output + bias + + return output + + +def autogen_scaled_mm_fp8_gemm_test( + fn: Callable, + m: int, + n: int, + k: int, + per_token_act_quant: bool, + per_out_channel_weight_quant: bool, + out_dtype: Type[torch.dtype] = torch.bfloat16, + device: str = "cuda"): + # Test for a cutlass kernel with per-token activation quantization + # and per-output channel weight quantization. + a = torch.randn((m, k), device=device) + b = torch.randn((n, k), device=device).t() + + b = prune_to_2_4(b.t()).t() + + a, b = to_fp8(a), to_fp8(b) + + b_compressed, e = ops.cutlass_compress_entry(b.t()) + + m_a_scales = m if per_token_act_quant else 1 + n_b_scales = n if per_out_channel_weight_quant else 1 + + scale_a = (torch.randn((m_a_scales, 1), device=device, + dtype=torch.float32)) + scale_b = (torch.randn((1, n_b_scales), device=device, + dtype=torch.float32)) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + fn(out, b_compressed, e, a.t(), scale_a, scale_b) + # TODO (varun) : cache baseline scaled_mm results so we dont recompute. + baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype) + return torch.allclose(out, baseline, rtol=1e-2, atol=5e-2) diff --git a/csrc/quantization/cutlass_w8a8/generator/tools/utils.py b/csrc/quantization/cutlass_w8a8/generator/tools/utils.py new file mode 100644 index 0000000000000..1ca432cdfae8f --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/tools/utils.py @@ -0,0 +1,75 @@ +from dataclasses import dataclass +from typing import List + +import numpy as np +from torch.utils.benchmark import Measurement as TMeasurement + + +@dataclass +class Data: + m: int + k: int + n: int + description: str + time: float + tflops: float + + +def parse_mkn(mkn_str: str): + # mkn_str : MKN=(16x1024x512) + mkn_tuple = mkn_str.split("=")[1] + # mkn_tuple : (16x1024x512) + mkn_prod = mkn_tuple[1:-1] + # mkn_prod : 16x1024x512 + mkn_tuple = tuple(mkn_prod.split("x")) + return (int(mkn_tuple[0]), int(mkn_tuple[1]), int(mkn_tuple[2])) + + +def measurement_to_data(measurement: TMeasurement) -> Data: + m, k, n = parse_mkn(measurement.sub_label) + t_ops = 2 * m * k * n / 1024 / 1024 / 1024 / 1024 + tflops = t_ops / measurement.median + return Data(m, k, n, measurement.task_spec.description, measurement.median, + tflops) + + +def make_heatmap_data(gemm_problems: List[str], ops: List[str], + data: List[Data]) -> np.array: + """ + gemm_problems : List of gemm problem shapes + ops : List of operations (kernels) + data : List of Data that contains benchmark information for all + op-gemmshape pairs. + Normalize all the benchmark information w.r.t. to its gemm-shape + and return the normalized benchmark information as a numpy array. 
+ """ + gemm_ops: List[List[float]] = [[0.0] * len(ops) + for _ in range(len(gemm_problems))] + for op_idx, op in enumerate(ops): + op_data = list(filter(lambda x: x.description == op, data)) + for gemm_idx, gemm in enumerate(gemm_problems): + m, n, k = gemm + selected = list( + filter(lambda x: x.m == m and x.n == n and x.k == k, op_data)) + if len(selected) >= 1: + gemm_ops[gemm_idx][op_idx] = float(selected[0].tflops) + + for gemm_idx in range(len(gemm_problems)): + max_tflops = max(gemm_ops[gemm_idx]) + for op_idx in range(len(ops)): + gemm_ops[gemm_idx][op_idx] = round( + gemm_ops[gemm_idx][op_idx] / max_tflops, 2) + + return np.array(gemm_ops) + + +def to_cutlass_dtype_str(dtype_str): + if dtype_str == "int8": + return "int8_t" + if dtype_str == "fp8": + return "cutlass::float_e4m3_t" + if dtype_str == "fp16": + return "cutlass::half_t" + if dtype_str == "bf16": + return "cutlass::bfloat16_t" + raise ValueError("unknown type") diff --git a/csrc/quantization/cutlass_w8a8/generator/utils.py b/csrc/quantization/cutlass_w8a8/generator/utils.py new file mode 100644 index 0000000000000..7011b9f26663a --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/generator/utils.py @@ -0,0 +1,56 @@ +""" +Utils used in generating cutlass kernels. +""" + +import os +from pathlib import Path +from typing import Tuple + +## Utilities #### + + +def to_torch_dtype_str(dtype_str): + if dtype_str == "int8": + return "torch::kInt8" + if dtype_str == "fp8": + return "torch::kFloat8_e4m3fn" + if dtype_str == "fp16": + return "torch::kFloat16" + if dtype_str == "bf16": + return "torch::kBFloat16" + raise ValueError("unknown type") + + +def to_cutlass_dtype_str(dtype_str): + if dtype_str == "int8": + return "int8_t" + if dtype_str == "fp8": + return "cutlass::float_e4m3_t" + if dtype_str == "fp16": + return "cutlass::half_t" + if dtype_str == "bf16": + return "cutlass::bfloat16_t" + raise ValueError("unknown type") + + +def get_script_dir() -> Path: + return Path(os.path.dirname(os.path.realpath(__file__))) + + +def get_as_cutlass_gemm_shape(shape: Tuple[int, int, int]): + return f'cutlass::gemm::GemmShape<{shape[0]}, {shape[1]}, {shape[2]}>' + + +def get_as_cutlass3x_gemm_shape(shape: Tuple[int, int, int]): + return f'Shape<_{shape[0]}, _{shape[1]}, _{shape[2]}>' + + +def file_contents_same(filepath, contents): + if not Path(filepath).exists(): + return + + f_contents = None + with open(filepath, "r") as f: + f_contents = f.read() + + return f_contents == contents diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index dbb72e8bbd3f5..ee801e16573d4 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -8,10 +8,6 @@ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" - -using namespace vllm; - /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). 
@@ -26,11 +22,12 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return cutlass_gemm_sm75_dispatch( + return vllm::cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return cutlass_gemm_sm75_dispatch( + return vllm::cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -45,10 +42,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales); } } @@ -64,10 +61,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -81,11 +78,12 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return cutlass_gemm_sm80_dispatch( + return vllm::cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return cutlass_gemm_sm80_dispatch( + return vllm::cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -100,10 +98,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales); } } @@ -119,10 +117,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -136,12 +134,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return cutlass_gemm_sm89_int8_dispatch( + return vllm::cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { assert(out.dtype() == torch::kFloat16); - return cutlass_gemm_sm89_int8_dispatch( + return vllm::cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } } else { @@ -149,13 +148,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { - return cutlass_gemm_sm89_fp8_dispatch( + return 
vllm::cutlass_gemm_sm89_fp8_dispatch< + cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return cutlass_gemm_sm89_fp8_dispatch( + return vllm::cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -171,10 +170,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales); } } @@ -190,10 +189,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index d03242f44ab1d..6329ff63623e2 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,6 +21,7 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "broadcast_load_epilogue_c2x.hpp" #include "common.hpp" // clang-format on @@ -70,6 +71,307 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. 
+ template + static auto args_from_tensor(c10::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. 
+ */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports 
per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..e111ab6074626 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,291 +1,9 @@ -// clang-format will break include orders -// clang-format off -#include - -#if defined CUDA_VERSION && CUDA_VERSION >= 12000 - +#include #include - -#include - -#include -#include -#include - #include "cutlass/cutlass.h" +#include "scaled_mm_c3x.cuh" -#include "cute/tensor.hpp" -#include 
"cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; -using namespace vllm; - -/* - This file defines quantized GEMM operations using the CUTLASS 3.x API, for - NVIDIA GPUs with sm90a (Hopper) or later. - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm90EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. -*/ - -namespace { - -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); - #endif - } -}; -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; - - struct GemmKernel : public KernelType {}; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... 
epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. - using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; - - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_default { - // For M > 128 and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M128 { - // For M in (64, 128] and any N - static_assert(std::is_same()); - using KernelSchedule = - typename 
cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M64 { - // For M in (32, 64] and any N - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NBig { - // For M in [1, 32] and N >= 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NSmall { - // For M in [1, 32] and N < 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -} // namespace +#include "scaled_mm_c3x_configs.cuh" template typename Epilogue, @@ -297,6 +15,552 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + uint32_t const m = out.size(0); + uint32_t const n = out.size(1); + uint32_t const k = b.size(0); + + if (m == 1) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 16) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k 
== 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 32) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 64) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, 
std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 128) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 256) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 512) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 
8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 1024) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 2048) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, 
std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else if (m <= 4096) { + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} else { // m8192 kernels + if (n == 2560 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 4096 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 5120 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 6144 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 2048) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 7168) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 8192 && k == 14336) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 14336 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 4096) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); + if (n == 28672 && k == 8192) + return cutlass_gemm_caller::Cutlass3xGemm >( + out, a, b, std::forward(args)...); +} + using Cutlass3xGemmDefault = typename sm90_fp8_config_default::Cutlass3xGemm; @@ -305,7 +569,6 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, using Cutlass3xGemmM128 = typename sm90_fp8_config_M128::Cutlass3xGemm; - uint32_t const m = a.size(0); uint32_t const mp2 = std::max(static_cast(64), next_pow_2(m)); // next power of 2 @@ -423,11 +686,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == 
c.dtype(), "currently bias dtype must match output dtype ", c.dtype()); - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( c, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_epilogue( - c, a, b, a_scales, b_scales); + return cutlass_scaled_mm_sm90_epilogue(c, a, b, a_scales, + b_scales); } } @@ -442,12 +705,83 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } -#endif +// hyper-parameter sweep kernels + +void cutlass_scaled_mm_sm90_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + c10::optional const& bias) { + assert(!bias); + + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using AccType = float; + + if (out.dtype() == torch::kBFloat16) { + using Cutlass3xGemm = + cutlass_3x_gemm; + + return cutlass_gemm_caller(out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + using Cutlass3xGemm = + cutlass_3x_gemm; + + return cutlass_gemm_caller(out, a, b, a_scales, b_scales); + } +} + +void cutlass_simple_gemm_sm90_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b) { + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using AccType = float; + + if (out.dtype() == torch::kBFloat16) { + using Cutlass3xGemm = + cutlass_3x_simple_gemm; + + return cutlass_simple_gemm_caller(out, a, b); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + + using Cutlass3xGemm = + cutlass_3x_simple_gemm; + + return cutlass_simple_gemm_caller(out, a, b); + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh new file mode 100644 index 0000000000000..b44b7cbf65080 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh @@ -0,0 +1,779 @@ +#pragma once + +// clang-format will break include orders +// clang-format off +#include + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + +#include + +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "broadcast_load_epilogue_c3x.hpp" +#include "common.hpp" +// clang-format on + +using namespace cute; + +/* + This file defines quantized GEMM operations using the CUTLASS 3.x API, for + NVIDIA GPUs with sm90a (Hopper) or 
later. + + Epilogue functions can be defined to post-process the output before it is + written to GPU memory. + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +// A wrapper for the GEMM kernel that is used to guard against compilation on +// architectures that will never use the kernel. The purpose of this is to +// reduce the size of the compiled binary. +// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef +// into code that will be executed on the device where it is defined. +template +struct enable_sm90_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { + #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 + Kernel::operator()(std::forward(args)...); + #endif + } +}; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. 
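+ For example (a worked case implied by the broadcasting rule above): with
+ per-token a_scales of shape (m, 1) and per-channel b_scales of shape (1, n),
+ the result is D[i, j] = a_scales[i] * b_scales[j] * (A @ B)[i, j].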
These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
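+ *
+ * In effect the fused epilogue computes
+ *   D = a_scales * (b_scales * float(acc - azp_adj)) + bias,
+ * with the subtraction performed on int32 operands before the conversion
+ * to float.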
+ */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. 
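+ *
+ * In effect the fused epilogue computes
+ *   D = a_scales * (b_scales * float(acc - azp * azp_adj)) + bias,
+ * where azp * azp_adj is the rank-1 correction term evaluated on the fly
+ * rather than materialized as an (m, n) matrix.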
+ */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule, typename AccType, + typename TileSchedule = cutlass::gemm::PersistentScheduler, + GemmUniversalMode Mode_ = GemmUniversalMode::kGemm> +struct cutlass_3x_gemm { + static const GemmUniversalMode Mode = Mode_; + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = AccType; + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + using Epilogue = Epilogue_; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using EVTCompute = typename Epilogue::EVTCompute; + + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentCD = 128 / cutlass::sizeof_bits::value; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, 
cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, AlignmentCD, ElementD, + StrideD, AlignmentCD, EpilogueSchedule, EVTCompute>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, AlignmentA, + ElementAB, cutlass::layout::ColumnMajor, AlignmentB, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + TileSchedule>>; + + struct GemmKernel : public KernelType {}; +}; + +template +inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... epilogue_params) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = b.size(0); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, + epilogue_args}; + + // Launch the CUTLASS GEMM kernel. + using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +using ReductionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::ReductionMode; +using DecompositionMode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; +using RasterOrderOptions = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + +template +inline void cutlass_gemm_caller_streamk(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + ReductionMode reduction_mode, + DecompositionMode decomposition_mode, + EpilogueArgs&&... 
epilogue_params) { + + static_assert(std::is_same::value, "Must be streamk scheduler"); + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + Gemm::Epilogue::prepare_args( + std::forward(epilogue_params)...), + c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::TileSchedulerArguments tile_scheduler_args( + 1, + 1, + RasterOrderOptions::Heuristic, + decomposition_mode + ); + tile_scheduler_args.reduction_mode = reduction_mode; + + // Copied from examples... + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, + epilogue_args, hw_info, tile_scheduler_args}; + + // Launch the CUTLASS GEMM kernel. 
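+  // GemmUniversalAdapter wraps GemmKernel into a host-callable operator:
+  // can_implement() validates the problem shape and alignments, the
+  // workspace (which stream-K uses for cross-CTA reduction) is allocated as
+  // a uint8 torch tensor on A's device, and the kernel is launched on the
+  // current CUDA stream.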
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M128 { + // M in (64, 128] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in [1, 64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_default { + // For M > 128 and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M128 { + // For M in (64, 128] and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M64 { + // For M in (32, 64] and any N + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NBig { + // For M in [1, 32] and N >= 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NSmall { + // For M in [1, 32] and N < 8192 + 
static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template +struct cutlass_3x_simple_gemm { + static const GemmUniversalMode Mode = Mode_; + using ElementAB = ElementAB_; + using ElementD = ElementD_; + + using ElementAcc = + typename std::conditional, AccType, + AccType>::type; + + using StrideD = Stride, Int<0>>; + using ElementC = void; + using StrideC = StrideD; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + static constexpr size_t CEStorageSize = + sizeof(typename CollectiveEpilogue::SharedStorage); + using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(CEStorageSize)>; + + // clang-format off + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementAB, cutlass::layout::RowMajor, 16, + ElementAB, cutlass::layout::ColumnMajor, 16, + ElementAcc, TileShape, ClusterShape, + Stages, + KernelSchedule>::CollectiveOp; + // clang-format on + + using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, + TileSchedule>>; + + struct GemmKernel : public KernelType {}; +}; + +template +inline void cutlass_simple_gemm_caller(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b) { + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + using StrideA = Stride, int64_t>; + using StrideB = Stride, int64_t>; + using StrideC = typename Gemm::StrideC; + + StrideA a_stride{lda, Int<1>{}, 0}; + StrideB b_stride{ldb, Int<1>{}, 0}; + StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; + + using GemmKernel = typename Gemm::GemmKernel; + typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, + b_stride}; + + auto c_ptr = static_cast(out.data_ptr()); + + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + + typename GemmKernel::Arguments args{Gemm::Mode, prob_shape, mainloop_args, + epilogue_args}; + + // Launch the CUTLASS GEMM kernel. 
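+  // Same launch sequence as cutlass_gemm_caller, but the epilogue fusion
+  // arguments are an empty tuple since this plain GEMM has no scale or bias
+  // visitors.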
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; + GemmOp gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + CUTLASS_CHECK(status); +} + +#endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_configs.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_configs.cuh new file mode 100644 index 0000000000000..c77adbe888cda --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_configs.cuh @@ -0,0 +1,223 @@ +template typename Epilogue> +struct sm90_fp8_config_0 { + static_assert(std::is_same()); + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_1 { + static_assert(std::is_same()); + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_2 { + static_assert(std::is_same()); + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_3 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _8, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_4 { + static_assert(std::is_same()); + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = 
typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_5 { + static_assert(std::is_same()); + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_6 { + static_assert(std::is_same()); + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_7 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_8 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_9 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_10 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _256, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename 
cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_11 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename cutlass::gemm::StreamKScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_12 { + static_assert(std::is_same()); + using TileShape = Shape<_256, _128, _128>; + using ClusterShape = Shape<_1, _2, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecializedCooperative; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_13 { + static_assert(std::is_same()); + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _4, _1>; + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileSchedule = typename cutlass::gemm::PersistentScheduler; + using AccType = float; + static constexpr cutlass::gemm::GemmUniversalMode Mode = cutlass::gemm::GemmUniversalMode::kGemm; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 97a969cf5e3e0..1657f7d0b16e8 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -137,11 +137,9 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, return; } - if (version_num >= 75) { - // Turing - cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); - return; - } + // Turing + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); #endif TORCH_CHECK_NOT_IMPLEMENTED(