diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6f38cd313f115..6e83c887f89b6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -148,8 +148,9 @@ steps: - python3 cpu_offload.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - - python3 llava_example.py + - python3 offline_inference_vision_language.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference_encoder_decoder.py - label: Models Test # 1hr10min source_file_dependencies: @@ -289,6 +290,7 @@ steps: commands: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py + - pytest -v -s distributed/test_basic_distributed_correctness_enc_dec.py - pytest -v -s distributed/test_chunked_prefill_distributed.py - pytest -v -s distributed/test_multimodal_broadcast.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 64011b2db2395..63cf5d50cac75 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: def make_rand_tensors(dtype: torch.dtype, m: int, n: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 b = torch.randn((n, k), device='cuda').t() * 5 @@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int, raise ValueError("unsupported dtype") -# impl - - -def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch.mm(a, b) - - -def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype) - - -def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - use_fast_accum=True) - - -def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) - - # bench -def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, out_dtype: torch.dtype, label: str, - sub_label: str, fn: Callable, description: str) -> TMeasurement: - +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: min_run_time = 1 globals = { - "a": a, - "b": b, - "scale_a": scale_a, - "scale_b": scale_b, - "out_dtype": out_dtype, + "args": args, + "kwargs": kwargs, "fn": fn, } return TBenchmark.Timer( - stmt="fn(a, b, scale_a, scale_b, out_dtype)", + stmt="fn(*args, **kwargs)", globals=globals, label=label, sub_label=sub_label, @@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) timers = [] # pytorch impl - bfloat16 timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) # pytorch impl - float16 timers.append( - bench_fn(a.to(dtype=torch.float16, device="cuda"), - b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b, - torch.float16, label, sub_label, pytorch_mm_impl, - "pytorch_fp16_fp16_fp16_matmul-no-scales")) + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) # cutlass impl timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass with azp per-tensor + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj)) + + # cutlass with azp per-tensor + bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, None, bias)) + + # cutlass with azp per-token + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, azp)) + + # cutlass with azp per-token + bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, azp, bias)) return timers @@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) timers = [] # pytorch impl w. bf16 timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) # pytorch impl: bf16 output, without fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) # pytorch impl: bf16 output, with fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) # pytorch impl: fp16 output, without fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) # pytorch impl: fp16 output, with fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) # cutlass impl: bf16 output timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) # cutlass impl: fp16 output timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + return timers @@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - results = [] for m, k, n in MKNs: timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", @@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None): - print(f"== All Results {base_description} ====") print_timers(data) @@ -251,7 +281,6 @@ def run_range_bench(args): def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") diff --git a/csrc/ops.h b/csrc/ops.h index e9e5f79a4a6f6..023455f8a1530 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,6 +128,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); +void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias); + torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, torch::Tensor const& b_q_weight, torch::Tensor const& s_tok, diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md new file mode 100644 index 0000000000000..aae04157b10de --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -0,0 +1,147 @@ +# CUTLASS Epilogues + +## Introduction +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +Currently, we only support symmetric quantization for weights, +and symmetric and asymmetric quantization for activations. +Both can be quantized per-tensor or per-channel (weights) / per-token (activations). + +There are 4 epilogues: +1. ScaledEpilogue: symmetric quantization for activations, no bias. +1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. +1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. +1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. +Instead, if no bias is passed, the epilogue will use 0 as the bias. +That induces a redundant addition operation (and runtime check), but the performance impact is minor. + +## Underlying Linear Algebra + +More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). + +If $` \widehat X `$ is the quantized $` X `$, our matrices become the following + +```math +A = s_a (\widehat A - J_a z_a) +``` +```math +B = s_b \widehat B +``` +```math +D = A B + C +``` +```math +D = s_a s_b \widehat D + C +``` + +Here, D is the output of the GEMM, and C is the bias. +A is the activations and supports asymmetric quantization, +and B is the weights and only supports symmetric quantization. +$ s_a $ and $s_b$ are the scales for activations and weights, respectively. +$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A. +Additional epilogues would be required to support asymmetric quantization for weights. + +Expanding further, we can calculate $` \widehat D `$ as follows: + +```math +A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B +``` +```math +A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) +``` +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` + +Note that $` \widehat A \widehat B `$ is the raw output of the GEMM, +and $` J_a \widehat B `$ is known ahead of time. +Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$. + +## Epilogues + +### ScaledEpilogue +This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` +```math +D = s_a s_b \widehat D +``` +```math +D = s_a s_b \widehat A \widehat B +``` + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). + +### ScaledEpilogueBias +This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` +```math +D = s_a s_b \widehat D + C +``` +```math +D = s_a s_b \widehat A \widehat B + C +``` + + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +### ScaledEpilogueAzp +This epilogue computes the asymmetric per-tensor quantization for activations with bias. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` +```math +D = s_a s_b \widehat D + C +``` +```math +D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C +``` + +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. +That is precomputed and stored in `azp_with_adj` as a row-vector. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-tensor as the zero-points are per-tensor. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. + +### ScaledEpilogueAzpPerToken +This epilogue computes the asymmetric per-token quantization for activations with bias. + +The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. +That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-token as the zero-points are per-token. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). +- `azp` is the zero-point (`z_a`), is per-token (column-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. + +The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): +``` +out = scale_a * scale_b * (Dq - azp_adj * azp) + bias +``` diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp index c4c6b18654eed..d407d66ab2aa6 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp @@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast { }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Column vector broadcast @@ -217,7 +367,7 @@ template< > struct VisitorColOrScalarBroadcast { - // This struct has been modified to have a bool indicating that ptr_col is a + // This struct has been modified to have a bool indicating that ptr_col is a // scalar that must be broadcast. struct Arguments { Element const* ptr_col = nullptr; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 8d0dfee7bf23a..ee801e16573d4 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, } } +void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (azp) { + return cutlass_scaled_mm_sm75_epilogue( + out, a, b, a_scales, b_scales, azp_adj, *azp, bias); + } else { + return cutlass_scaled_mm_sm75_epilogue( + out, a, b, a_scales, b_scales, azp_adj, bias); + } +} + template