diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 44f47fac1c1b3..b563c96343f92 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -224,8 +224,12 @@ steps: mirror_hardwares: [amd] source_file_dependencies: - vllm/model_executor/layers + - vllm/model_executor/guided_decoding - tests/test_logits_processor - command: pytest -v -s test_logits_processor.py + - tests/model_executor/test_guided_processors + commands: + - pytest -v -s test_logits_processor.py + - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 30min source_file_dependencies: diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index c1051d10a4860..e40ceaaa8b037 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -39,67 +39,68 @@ jobs: const script = require('.github/workflows/scripts/create_release.js') await script(github, context, core) - wheel: - name: Build Wheel - runs-on: ${{ matrix.os }} - needs: release - - strategy: - fail-fast: false - matrix: - os: ['ubuntu-20.04'] - python-version: ['3.9', '3.10', '3.11', '3.12'] - pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. - cuda-version: ['11.8', '12.1'] - - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Setup ccache - uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 - with: - create-symlink: true - key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - - name: Set up Linux Env - if: ${{ runner.os == 'Linux' }} - run: | - bash -x .github/workflows/scripts/env.sh - - - name: Set up Python - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: ${{ matrix.python-version }} - - - name: Install CUDA ${{ matrix.cuda-version }} - run: | - bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - run: | - bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - - name: Build wheel - shell: bash - env: - CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - run: | - bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) - asset_name=${wheel_name//"linux"/"manylinux1"} - echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" - echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - - - name: Upload Release Asset - uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - upload_url: ${{ needs.release.outputs.upload_url }} - asset_path: ./dist/${{ env.wheel_name }} - asset_name: ${{ env.asset_name }} - asset_content_type: application/* + # NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow. + # wheel: + # name: Build Wheel + # runs-on: ${{ matrix.os }} + # needs: release + + # strategy: + # fail-fast: false + # matrix: + # os: ['ubuntu-20.04'] + # python-version: ['3.9', '3.10', '3.11', '3.12'] + # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. 
+ # cuda-version: ['11.8', '12.1'] + + # steps: + # - name: Checkout + # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + # - name: Setup ccache + # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 + # with: + # create-symlink: true + # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} + + # - name: Set up Linux Env + # if: ${{ runner.os == 'Linux' }} + # run: | + # bash -x .github/workflows/scripts/env.sh + + # - name: Set up Python + # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + # with: + # python-version: ${{ matrix.python-version }} + + # - name: Install CUDA ${{ matrix.cuda-version }} + # run: | + # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} + + # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} + # run: | + # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} + + # - name: Build wheel + # shell: bash + # env: + # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size + # run: | + # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} + # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) + # asset_name=${wheel_name//"linux"/"manylinux1"} + # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" + # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" + + # - name: Upload Release Asset + # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 + # env: + # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # with: + # upload_url: ${{ needs.release.outputs.upload_url }} + # asset_path: ./dist/${{ env.wheel_name }} + # asset_name: ${{ env.asset_name }} + # asset_content_type: application/* # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested # - name: Publish package diff --git a/CMakeLists.txt b/CMakeLists.txt index bf19b3d227171..83c8033434f3b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - GIT_TAG v3.5.1 + GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227 GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. 
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE - GIT_SHALLOW TRUE + GIT_SHALLOW FALSE ) endif() FetchContent_MakeAvailable(cutlass) @@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/awq/gemm_kernels.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" + "csrc/sparse/cutlass/sparse_compressor_entry.cu" + "csrc/cutlass_extensions/common.cpp") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -270,7 +273,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - # # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") @@ -323,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # + # 2:4 Sparse Kernels + + # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor + # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now). + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu" + "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") + message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " + "if you intend on running FP8 sparse quantized models on Hopper.") + else() + message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + # # Machete kernels @@ -404,7 +431,7 @@ define_gpu_extension_target( SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} - INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1e5967bd9bf8b..c1b10b3cf8f58 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,7 +4,8 @@ import json import random import time -from typing import List, Optional +from functools import cache +from typing import Dict, List, Optional, Tuple import torch import uvloop @@ -17,8 +18,11 @@ from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) from vllm.inputs import TextPrompt +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import BeamSearchParams +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators @@ -28,15 +32,17 @@ class SampleRequest: Attributes: prompt: The input text prompt for the model. - multi_modal_data: Optional dictionary containing multi-modal data (e.g. - images). prompt_len: The length of the prompt in tokens. 
expected_output_len: The expected length of the output in tokens. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + lora_request: Optional LoRARequest specifying the LoRA to use. """ prompt: str prompt_len: int expected_output_len: int multi_modal_data: Optional[MultiModalDataDict] = None + lora_request: Optional[LoRARequest] = None def _get_prompt_for_image_model(question: str, *, model: str) -> str: @@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str: raise ValueError(f"Unsupported model {model}") +@cache +def lora_path_on_disk(lora_path: str) -> str: + return get_adapter_absolute_path(lora_path) + + +lora_tokenizer_cache: Dict[int, AnyTokenizer] = {} + + +def get_random_lora_request( + args: argparse.Namespace +) -> Tuple[LoRARequest, Optional[AnyTokenizer]]: + global lora_tokenizer_cache + lora_id = random.randint(1, args.max_loras) + lora_request = LoRARequest(lora_name=str(lora_id), + lora_int_id=lora_id, + lora_path=lora_path_on_disk(args.lora_path)) + if lora_id not in lora_tokenizer_cache: + lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request) + return lora_request, lora_tokenizer_cache[lora_id] + + def sample_requests(tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace) -> List[SampleRequest]: + dataset_path: str = args.dataset num_requests: int = args.num_prompts fixed_output_len: Optional[int] = args.output_len @@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, # Filter out sequences that are too long or too short filtered_dataset: List[SampleRequest] = [] - for data in dataset: + for data in tqdm(dataset, + total=len(filtered_dataset), + desc="sampling requests"): if len(filtered_dataset) == num_requests: break @@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, continue prompt = _get_prompt_for_image_model(question=prompt, model=model) + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Tokenize the prompts and completions. - prompt_token_ids = tokenizer(prompt).input_ids - completion_token_ids = tokenizer(completion).input_ids + prompt_token_ids = request_tokenizer(prompt).input_ids + completion_token_ids = request_tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len @@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, SampleRequest(prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - multi_modal_data=multi_modal_data)) + multi_modal_data=multi_modal_data, + lora_request=lora_request)) return filtered_dataset @@ -146,14 +184,21 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests: Optional[List[LoRARequest]] = None + if engine_args.enable_lora: + lora_requests = [request.lora_request for request in requests] use_beam_search = False if not use_beam_search: start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) + llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) end = time.perf_counter() else: + assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] # output_len should be the same for all requests. 
output_len = requests[0][2] @@ -185,6 +230,7 @@ async def run_vllm_async( # Add the requests to the engine. prompts: List[TextPrompt] = [] sampling_params: List[SamplingParams] = [] + lora_requests: List[Optional[LoRARequest]] = [] for request in requests: prompts.append( TextPrompt(prompt=request.prompt, @@ -197,11 +243,16 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, )) + lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - generator = llm.generate(prompt, sp, request_id=f"test{i}") + for i, (prompt, sp, + lr) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, + sp, + lora_request=lr, + request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -297,6 +348,14 @@ def main(args: argparse.Namespace): vocab_size = tokenizer.vocab_size requests = [] for _ in range(args.num_prompts): + + request_tokenizer = tokenizer + lora_request: Optional[LoRARequest] = None + if args.enable_lora: + lora_request, lora_tokenizer = get_random_lora_request(args) + if lora_tokenizer: + request_tokenizer = lora_tokenizer + # Synthesize a prompt with the given input length. candidate_ids = [ random.randint(0, vocab_size - 1) @@ -305,8 +364,8 @@ def main(args: argparse.Namespace): # As tokenizer may add additional tokens like BOS, we need to try # different lengths to get the desired input length. for _ in range(5): # Max attempts to correct - candidate_prompt = tokenizer.decode(candidate_ids) - tokenized_len = len(tokenizer.encode(candidate_prompt)) + candidate_prompt = request_tokenizer.decode(candidate_ids) + tokenized_len = len(request_tokenizer.encode(candidate_prompt)) if tokenized_len == args.input_len: break @@ -323,7 +382,8 @@ def main(args: argparse.Namespace): requests.append( SampleRequest(prompt=candidate_prompt, prompt_len=args.input_len, - expected_output_len=args.output_len)) + expected_output_len=args.output_len, + lora_request=lora_request)) else: requests = sample_requests(tokenizer, args) @@ -422,6 +482,14 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + # LoRA + parser.add_argument( + "--lora-path", + type=str, + default=None, + help="Path to the lora adapters to use. 
This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: @@ -431,6 +499,8 @@ def main(args: argparse.Namespace): assert args.output_len is not None else: assert args.input_len is None + if args.enable_lora: + assert args.lora_path is not None if args.backend == "vllm": if args.hf_max_batch_size is not None: @@ -440,6 +510,9 @@ def main(args: argparse.Namespace): raise ValueError("HF max batch size is required for HF backend.") if args.quantization is not None: raise ValueError("Quantization is only for vLLM backend.") + if args.enable_lora is not None: + raise ValueError("LoRA benchmarking is only supported for vLLM" + " backend") elif args.backend == "mii": if args.dtype != "auto": raise ValueError("dtype must be auto for MII backend.") @@ -452,4 +525,7 @@ def main(args: argparse.Namespace): if args.tokenizer != args.model: raise ValueError("Tokenizer must be the same as the model for MII " "backend.") + if args.enable_lora is not None: + raise ValueError("LoRA benchmarking is only supported for vLLM" + " backend") main(args) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py new file mode 100644 index 0000000000000..3d1c5e392f9e2 --- /dev/null +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -0,0 +1,384 @@ +import argparse +import copy +import itertools +import pickle as pkl +import time +from typing import Callable, Iterable, List, Tuple + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_sparse_tensors +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] +DEFAULT_TP_SIZES = [1] + + +# bench +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: + min_run_time = 1 + + globals = { + "args": args, + "kwargs": kwargs, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(*args, **kwargs)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.int8 + b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + # pytorch impl - bfloat16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) + + # pytorch impl - float16 + timers.append( + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + + # cutlass impl + timers.append( + 
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass sparse impl + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass sparse with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + + return timers + + +def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + assert dtype == torch.float8_e4m3fn + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, + k) + scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) + scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + + out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, + torch.bfloat16) + out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) + + if not torch.allclose(out, out_ref): + print("Incorrect results") + print(out) + print(out_ref) + else: + print("Correct results") + + timers = [] + + # pytorch impl w. bf16 + timers.append( + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) + + # pytorch impl: bf16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) + + # pytorch impl: bf16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) + + # pytorch impl: fp16 output, without fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) + + # pytorch impl: fp16 output, with fp8 fast accum + timers.append( + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass impl: bf16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16)) + + # cutlass impl: fp16 output + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.bfloat16, bias)) + 
+ # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, + scale_b, torch.float16, bias.to(dtype=torch.float16))) + + return timers + + +def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, + sub_label: str) -> Iterable[TMeasurement]: + if dtype == torch.int8: + return bench_int8(dtype, m, k, n, label, sub_label) + if dtype == torch.float8_e4m3fn: + return bench_fp8(dtype, m, k, n, label, sub_label) + raise ValueError("unsupported type") + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + results = [] + for m, k, n in MKNs: + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})") + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output(data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None): + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) + n = len(dim_sizes) + Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes + Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes + Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes + MKNs = list(zip(Ms, Ks, Ns)) + data = run(args.dtype, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Cutlass GEMM. 
+ + To run square GEMMs: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter) + + parser.add_argument("--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']") + subparsers = parser.add_subparsers(dest="cmd") + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument("--dim-start", type=int, required=True) + range_parser.add_argument("--dim-end", type=int, required=True) + range_parser.add_argument("--dim-increment", type=int, required=True) + range_parser.add_argument("--m-constant", type=int, default=None) + range_parser.add_argument("--n-constant", type=int, default=None) + range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument("--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys()) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + args.func(args) diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py new file mode 100644 index 0000000000000..ef06fcd6604dd --- /dev/null +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -0,0 +1,96 @@ +# Cutlass bench utils +from typing import Iterable, Tuple + +import torch + +import vllm._custom_ops as ops + + +def to_fp8(tensor: torch.Tensor) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def to_bf16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.bfloat16) + + +def to_fp16(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(dtype=torch.float16) + + +def make_rand_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + if dtype == torch.int8: + return to_int8(a), to_int8(b) + if dtype == torch.float8_e4m3fn: + return to_fp8(a), to_fp8(b) + + raise ValueError("unsupported dtype") + + +def prune_to_2_4(tensor): + # Reshape tensor to [N, 4] 
where N is number of groups of 4 + original_shape = tensor.shape + reshaped = tensor.reshape(-1, 4) + + # Get indices of top 2 absolute values in each group of 4 + _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1) + + # Create binary mask + mask = torch.zeros_like(reshaped) + mask.scatter_(dim=1, + index=indices, + src=torch.ones_like(indices, dtype=mask.dtype)) + + # Apply mask and reshape back + pruned = reshaped * mask + + # Turn all -0.0 to 0.0 + pruned[pruned == -0.0] = 0.0 + + return pruned.reshape(original_shape) + + +def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, + k: int) -> Tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device='cuda') * 5 + b = torch.randn((n, k), device='cuda').t() * 5 + + b = prune_to_2_4(b.t()).t() + + if dtype == torch.int8: + a, b = to_int8(a), to_int8(b) + elif dtype == torch.float8_e4m3fn: + a, b = to_fp8(a), to_fp8(b) + elif dtype == torch.float16: + a, b = to_fp16(a), to_fp16(b) + elif dtype == torch.bfloat16: + a, b = to_bf16(a), to_bf16(b) + else: + raise ValueError("unsupported dtype") + + b_compressed, e = ops.cutlass_sparse_compress(b.t()) + + # Compressed B, Metadata, Original A, B + return b_compressed, e, a, b + + +def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, + m: int, n: int, k: int) -> \ + Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: + ABs = [] + for _ in range(num_tensors): + b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) + if b_comp is not None: + ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) + BComps, Es, As, Bs = zip(*ABs) + return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 63cf5d50cac75..d0353bc8cb42a 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -8,6 +8,7 @@ import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement +from utils import make_rand_tensors from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops @@ -17,31 +18,6 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] -# helpers - - -def to_fp8(tensor: torch.Tensor) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) - - -def to_int8(tensor: torch.Tensor) -> torch.Tensor: - return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) - - -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 - - if dtype == torch.int8: - return to_int8(a), to_int8(b) - if dtype == torch.float8_e4m3fn: - return to_fp8(a), to_fp8(b) - - raise ValueError("unsupported dtype") - # bench def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, @@ -386,4 +362,4 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() - args.func(args) + args.func(args) \ No newline at end of file diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 25ec9d6028627..d58fb0bf86374 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -40,4 +40,4 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], -} +} 
\ No newline at end of file
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
new file mode 100644
index 0000000000000..ba9f40a230c8e
--- /dev/null
+++ b/csrc/core/math.hpp
@@ -0,0 +1,7 @@
+#include <climits>
+#include <cstdint>
+
+inline uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp
new file mode 100644
index 0000000000000..3d2093ab94297
--- /dev/null
+++ b/csrc/cutlass_extensions/common.cpp
@@ -0,0 +1,11 @@
+#include "cutlass_extensions/common.hpp"
+
+int32_t get_sm_version_num() {
+  int32_t major_capability, minor_capability;
+  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                         0);
+  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                         0);
+  int32_t version_num = major_capability * 10 + minor_capability;
+  return version_num;
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
new file mode 100644
index 0000000000000..85e359aa57113
--- /dev/null
+++ b/csrc/cutlass_extensions/common.hpp
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include <torch/all.h>
+#include "cuda_runtime.h"
+#include <cstdint>
+
+/**
+ * Helper function for checking CUTLASS errors
+ */
+#define CUTLASS_CHECK(status)                       \
+  {                                                 \
+    cutlass::Status error = status;                 \
+    TORCH_CHECK(error == cutlass::Status::kSuccess, \
+                cutlassGetStatusString(error));     \
+  }
+
+/**
+ * Panic wrapper for unwinding CUDA runtime errors
+ */
+#define CUDA_CHECK(status)                                        \
+  {                                                               \
+    cudaError_t error = status;                                   \
+    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
+  }
+
+inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+  int max_shared_mem_per_block_opt_in = 0;
+  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                         device);
+  return max_shared_mem_per_block_opt_in;
+}
+
+int32_t get_sm_version_num();
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
index c69e87999ae71..26f7423fd7455 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

 /*
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index 95764ecddc79f..c723adf126422 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

 /*
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
   // Don't want to support nullptr by default
   template <typename T, bool EnableNullPtr = false>
   using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
       Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

   // Don't want to support nullptr by default
   template <typename T, bool EnableNullPtr = false>
   using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
       Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

   // This utility function constructs the arguments for the load descriptors
diff --git a/csrc/ops.h b/csrc/ops.h
index 816b471d062d2..347c502845d8f 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -162,6 +162,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& azp_adj,
                            c10::optional<torch::Tensor> const& azp,
                            c10::optional<torch::Tensor> const& bias);
+
+bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
+
+void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
+                              torch::Tensor const& b, torch::Tensor const& e,
+                              torch::Tensor const& a_scales,
+                              torch::Tensor const& b_scales,
+                              c10::optional<torch::Tensor> const& bias);
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+                                   torch::Tensor& e, torch::Tensor const& a);
 #endif

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
deleted file mode 100644
index bf04bb400790f..0000000000000
--- a/csrc/quantization/cutlass_w8a8/common.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "cutlass/cutlass.h"
-#include
-
-/**
- * Helper function for checking CUTLASS errors
- */
-#define CUTLASS_CHECK(status)                       \
-  {                                                 \
-    TORCH_CHECK(status == cutlass::Status::kSuccess, \
-                cutlassGetStatusString(status))     \
-  }
-
-inline uint32_t next_pow_2(uint32_t const num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
-  int max_shared_mem_per_block_opt_in = 0;
-  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
-                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
-                         device);
-  return max_shared_mem_per_block_opt_in;
-}
-
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index d03242f44ab1d..f2fae4b66d651 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -21,15 +21,16 @@
 #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
 // clang-format on

 using namespace cute;

 /*
-   Epilogue functions can be defined to post-process the output before it is
-   written to GPU memory.
-   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+   Epilogues defined in,
+   csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+   must contain a public type named EVTCompute of type Sm80EVT,
    as well as a static prepare_args function that constructs an
    EVTCompute::Arguments struct.
*/ diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 33581a63d4c3d..123f4359c0d1a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -1,384 +1,18 @@ -// clang-format will break include orders -// clang-format off #include #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -#include + #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh" + #include "scaled_mm_c3x_sm90_int8_dispatch.cuh" -#include - -#include -#include -#include - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/gemm/collective/collective_builder.hpp" - -#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" -#include "common.hpp" -// clang-format on - -using namespace cute; + #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. - - Epilogue functions can be defined to post-process the output before it is - written to GPU memory. - Epilogues must contain a public type named EVTCompute of type Sm90EVT, - as well as a static prepare_args function that constructs an - EVTCompute::Arguments struct. */ -namespace { - -// A wrapper for the GEMM kernel that is used to guard against compilation on -// architectures that will never use the kernel. The purpose of this is to -// reduce the size of the compiled binary. -// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef -// into code that will be executed on the device where it is defined. -template -struct enable_sm90_or_later : Kernel { - template - CUTLASS_DEVICE void operator()(Args&&... 
args) { - #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900 - Kernel::operator()(std::forward(args)...); - #endif - } -}; -template typename Epilogue_, - typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> -struct cutlass_3x_gemm { - using ElementAB = ElementAB_; - using ElementD = ElementD_; - using ElementAcc = - typename std::conditional, int32_t, - float>::type; - - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; - - using StrideD = Stride, Int<0>>; - using ElementC = void; - using StrideC = StrideD; - - using EVTCompute = typename Epilogue::EVTCompute; - - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, - ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, - ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4, - EpilogueSchedule, EVTCompute>::CollectiveOp; - - static constexpr size_t CEStorageSize = - sizeof(typename CollectiveEpilogue::SharedStorage); - using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(CEStorageSize)>; - - // clang-format off - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - ElementAB, cutlass::layout::RowMajor, 16, - ElementAB, cutlass::layout::ColumnMajor, 16, - ElementAcc, TileShape, ClusterShape, - Stages, - KernelSchedule>::CollectiveOp; - // clang-format on - - using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; - - struct GemmKernel : public KernelType {}; -}; - -template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... epilogue_params) { - using ElementAB = typename Gemm::ElementAB; - using ElementD = typename Gemm::ElementD; - - int32_t m = a.size(0); - int32_t n = b.size(1); - int32_t k = a.size(1); - - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = Stride, int64_t>; - using StrideB = Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, Int<1>{}, 0}; - StrideB b_stride{ldb, Int<1>{}, 0}; - StrideC c_stride{ldc, Int<1>{}, Int<0>{}}; - - using GemmKernel = typename Gemm::GemmKernel; - typename GemmKernel::ProblemShape prob_shape{m, n, k, 1}; - - auto a_ptr = static_cast(a.data_ptr()); - auto b_ptr = static_cast(b.data_ptr()); - typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr, - b_stride}; - - auto c_ptr = static_cast(out.data_ptr()); - typename GemmKernel::EpilogueArguments epilogue_args{ - Gemm::Epilogue::prepare_args( - std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; - - typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, - prob_shape, mainloop_args, epilogue_args}; - - // Launch the CUTLASS GEMM kernel. 
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; - GemmOp gemm_op; - CUTLASS_CHECK(gemm_op.can_implement(args)); - - size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); - - cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - CUTLASS_CHECK(status); -} - -template typename Epilogue> -struct sm90_fp8_config_default { - // M in (128, inf) - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M128 { - // M in (64, 128] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_fp8_config_M64 { - // M in [1, 64] - static_assert(std::is_same()); - using KernelSchedule = - cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _128>; - using ClusterShape = Shape<_1, _8, _1>; - - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_default { - // For M > 128 and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M128 { - // For M in (64, 128] and any N - static_assert(std::is_same()); - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _128>; - using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M64 { - // For M in (32, 64] and any N - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NBig { - // For M in [1, 32] and N >= 8192 - static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _128, _256>; - using ClusterShape = Shape<_1, _4, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -template typename Epilogue> -struct sm90_int8_config_M32_NSmall { - // For M in [1, 32] and N < 8192 - 
static_assert(std::is_same()); - using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; - using Cutlass3xGemm = - cutlass_3x_gemm; -}; - -} // namespace - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); - - using Cutlass3xGemmDefault = - typename sm90_fp8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_fp8_config_M128::Cutlass3xGemm; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 - - if (mp2 <= 64) { - // m in [1, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - -template typename Epilogue, - typename... EpilogueArgs> -void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - EpilogueArgs&&... args) { - static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kInt8); - TORCH_CHECK(b.dtype() == torch::kInt8); - - using Cutlass3xGemmDefault = - typename sm90_int8_config_default::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm90_int8_config_M128::Cutlass3xGemm; - using Cutlass3xGemmM64 = - typename sm90_int8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM32NBig = - typename sm90_int8_config_M32_NBig::Cutlass3xGemm; - using Cutlass3xGemmM32NSmall = - typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; - - uint32_t const n = out.size(1); - bool const is_small_n = n < 8192; - - uint32_t const m = a.size(0); - uint32_t const mp2 = - std::max(static_cast(32), next_pow_2(m)); // next power of 2 - - if (mp2 <= 32) { - // m in [1, 32] - if (is_small_n) { - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } - } else if (mp2 <= 64) { - // m in (32, 64] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } else { - // m in (128, inf) - return cutlass_gemm_caller( - out, a, b, std::forward(args)...); - } -} - template