-
-
Notifications
You must be signed in to change notification settings - Fork 5.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kernel][Triton] Add Triton implementation for scaled_mm_triton to support fp8 and int8 SmoothQuant, symmetric case #9857
Changes from 2 commits
4624680
242e6d1
fa38282
7cb6c6e
a875da8
0fb4836
7c8865a
5abafe4
d5e390d
f003676
4f1e62e
fa6b1cc
5a04f41
69cd7a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,96 @@ | ||||||
"""Tests for the scaled_mm_triton kernel | ||||||
|
||||||
Run `pytest tests/kernels/test_scaled_mm_triton.py`. | ||||||
""" | ||||||
import importlib | ||||||
from typing import Optional, Type | ||||||
|
||||||
import pytest | ||||||
import torch | ||||||
|
||||||
from vllm.utils import seed_everything | ||||||
|
||||||
# Device on which every tensor in this test module is allocated.
device = "cuda"
|
||||||
|
||||||
def scaled_mm_torch(a: torch.Tensor,
                    b: torch.Tensor,
                    scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: Type[torch.dtype],
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference ("golden") scaled matrix multiply written in plain PyTorch.

    Both operands are promoted to float32 and multiplied; the product is
    scaled by ``scale_a`` and then by ``scale_b`` transposed, and the result
    is cast to ``out_dtype``.  When ``bias`` is given it is added *after* the
    cast, i.e. the addition happens in the output precision.

    Args:
        a: Left operand; must be valid as the first argument of ``torch.mm``.
        b: Right operand; must be valid as the second argument of
            ``torch.mm``.
        scale_a: Scale broadcast against the product as-is.
        scale_b: Scale that is transposed before being broadcast against the
            product.
        out_dtype: dtype the scaled product is converted to.
        bias: Optional tensor added to the converted result.

    Returns:
        The scaled (and optionally biased) product.
    """
    product = torch.mm(a.to(torch.float32), b.to(torch.float32))
    # Keep the original evaluation order: scale_a first, then scale_b.T.
    scaled = (scale_b.T * (scale_a * product)).to(out_dtype)
    return scaled if bias is None else scaled + bias
|
||||||
|
||||||
@pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 256, 512, 222, 33])
@pytest.mark.parametrize("N", [2048, 8192, 16384, 256, 1024])
@pytest.mark.parametrize("K", [128, 496, 1024])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("in_dtype", [torch.int8])
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
                   use_scalar_scale_b, use_bias):
    """Compare ``scaled_mm_triton`` against the PyTorch reference.

    Random operands of shape (M, K) and (K, N) are generated on ``device``,
    multiplied with the Triton implementation, and checked against
    ``scaled_mm_torch`` evaluated on CPU copies of the same tensors.

    The ``use_scalar_scale_*`` flags choose between a per-tensor scale of
    shape (1, 1) and a per-row scale of shape (M, 1) for ``a`` or a
    per-column scale of shape (N, 1) for ``b``.
    """
    seed_everything(0)

    # NOTE: With large enough matrices an output close to the float16
    # maximum (e.g. 65504.4) can be produced, which easily becomes inf in
    # float16/bfloat16.  The function under test and the golden function may
    # then disagree on whether a value is inf, and assert_close would fail
    # even though the values would have been "close" without the overflow.
    #
    # So, the values here are kept small enough to avoid this situation.
    if in_dtype.is_floating_point:
        a = (0.25 * torch.rand(
            (M, K), dtype=torch.float32, device=device)).to(in_dtype)
        b = (0.25 * torch.rand(
            (K, N), dtype=torch.float32, device=device)).to(in_dtype)
    else:
        a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
        b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)

    if use_scalar_scale_a:
        scale_a = torch.rand((1, 1), device=device)
    else:
        scale_a = 0.25 * torch.rand((M, 1), device=device)

    if use_scalar_scale_b:
        scale_b = torch.rand((1, 1), device=device)
    else:
        # Raised in review ("should this be?") and confirmed fixed by the
        # author: the per-channel scale for ``b`` must have shape (N, 1).
        # The earlier (1, 1) shape meant this branch only re-tested the
        # scalar case.
        scale_b = 0.25 * torch.rand((N, 1), device=device)

    bias = None
    if use_bias:
        bias = torch.rand((N, ), device=device, dtype=out_dtype)

    scaled_mm_triton_module = importlib.import_module(
        "vllm.model_executor.layers.quantization.compressed_tensors."
        "scaled_mm_triton")
    scaled_mm_triton = scaled_mm_triton_module.scaled_mm_triton

    triton_out = scaled_mm_triton(a, b, scale_a, scale_b, out_dtype, bias)

    # The golden result is computed on CPU copies of the same inputs.
    reference_out = scaled_mm_torch(
        a.cpu(), b.cpu(), scale_a.cpu(), scale_b.cpu(), out_dtype,
        None if bias is None else bias.cpu())

    torch.testing.assert_close(triton_out.cpu(),
                               reference_out,
                               rtol=1e-1,
                               atol=1e-1)
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,207 @@ | ||||||
from typing import Optional, Type | ||||||
|
||||||
import torch | ||||||
import triton | ||||||
import triton.language as tl | ||||||
|
||||||
|
||||||
# This function handles some cases that can cause certain failure, e.g. | ||||||
# a tensor that has shape = (72, 48) but stride = (5120, 1). It can happen, | ||||||
# for example by saving a tensor using torch.save() and then adjusting its | ||||||
# size afterwards and then trying to use it. Unfortunately, | ||||||
# torch.is_contiguous() doesn't help since a transposed tensor doesn't return | ||||||
# True, even though it can be stored contiguously in memory. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a fundamental reason we can't handle this case? This case can arise when working with slices of a tensor, and I went out of my way to support it for the cutlass_scaled_mm kernels. Supporting this is definitely a Nice To Have rather than a requirement but would like to know what the problem is. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought it might have been occurring in vLLM, but I haven't seen it happen. I did it to myself when I was debugging though. I did a torch.save() when I found a tensor that didn't have the correct output, but the tensor was huge, so I changed the size of the tensor to (72, 48) while the old size was something like (131072, 5120). However, the strides remained (5120, 1) and my kernel got the incorrect result, while torch._int_mm(a, b) would get the correct result. I was wondering why that was, so I looked at the pytorch code to figure it out and found out what they were doing. I put it there, just in case, but I actually haven't seen this happen in vLLM though. I thought about removing the prepare_matrix_for_triton() function, but this scaled_mm_triton() function is general enough that it could be used elsewhere (it handles any input dtype and output dtype I've tried so far and can mix and match per-tensor, row-wise, and column-wise scaling), so I left it in. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will try taking this out and test, and put an assert where it was used instead, per suggestion on the other comment. There was a problem hiding this comment. 
Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just using assert now. |
||||||
# | ||||||
# There is a way to handle this case, which I learned about from here: | ||||||
# | ||||||
# https://github.com/pytorch/pytorch/blob/ | ||||||
# a874ec85e83cfe75e7238296022d53d7e20860df/aten/src/ATen/native/ | ||||||
# cuda/Blas.cpp#L58 | ||||||
# | ||||||
# This doesn't happen very often fortunately, because the only solution is | ||||||
# inefficient. | ||||||
def prepare_matrix_for_triton(x: torch.Tensor):
    """Return ``x`` itself or a contiguous clone, depending on its strides.

    ``x`` is returned unchanged when one of its first two dimensions has
    stride 1 and the stride of the other dimension is at least
    ``max(1, extent of the unit-stride dimension)``.  Any other stride
    pattern yields a clone made with ``torch.contiguous_format``.
    """
    stride = x.stride()
    shape = x.shape
    unit_stride_dim0 = stride[0] == 1 and stride[1] >= max(1, shape[0])
    if unit_stride_dim0:
        return x
    unit_stride_dim1 = stride[1] == 1 and stride[0] >= max(1, shape[1])
    if unit_stride_dim1:
        return x
    # Neither layout matched: fall back to a (costly) dense copy.
    return torch.clone(x, memory_format=torch.contiguous_format)
|
||||||
|
||||||
@triton.jit
def scaled_mm_kernel(a_ptr, b_ptr, scale_a_ptr, scale_b_ptr, c_ptr, bias_ptr,
                     M, N, K, stride_am, stride_ak, stride_bk, stride_bn,
                     stride_cm, stride_cn, ACCUMULATOR_DTYPE: tl.constexpr,
                     BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                     BLOCK_SIZE_K: tl.constexpr,
                     BLOCK_SIZE_SCALE_A: tl.constexpr,
                     BLOCK_SIZE_SCALE_B: tl.constexpr):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of the scaled product.

    The tile of A @ B is accumulated over K in steps of BLOCK_SIZE_K using
    ACCUMULATOR_DTYPE, multiplied by scale_a and by scale_b (transposed),
    converted to the element type of c_ptr, optionally offset by a bias
    indexed along N, and stored with a mask so that rows >= M and columns
    >= N are never written.

    a_ptr, b_ptr and c_ptr are addressed through the supplied strides.  The
    scale pointers and bias_ptr are addressed by plain element index, i.e.
    with unit stride.
    """
    pid = tl.program_id(axis=0)

    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # The 1-D program id is split into a tile row and a tile column;
    # consecutive program ids advance along N first.
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = ACCUMULATOR_DTYPE
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
                           dtype=accumulator_dtype)

    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
    # eventually occur.

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
    offsets_a = (stride_am * offsets_am[:, None] +
                 stride_ak * offsets_k[None, :])
    offsets_b = (stride_bk * offsets_k[:, None] +
                 stride_bn * offsets_bn[None, :])

    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
    # appropriate offsets and masks for each case. Same goes for
    # BLOCK_SIZE_SCALE_B.
    # When the scale block has a single element the tile offset is multiplied
    # by 0, so every program reads element 0 of the scale.
    offsets_scale_am = (tl.arange(0, BLOCK_SIZE_SCALE_A) +
                        (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M)
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (tl.arange(0, BLOCK_SIZE_SCALE_B) +
                        (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N)
    masks_scale_bn = offsets_scale_bn < N

    # Adding a length-1 arange along a new trailing axis turns the 1-D scale
    # offsets into column vectors of shape (BLOCK_SIZE_SCALE_*, 1).
    offsets_scale_a = (offsets_scale_am[:, None].to(tl.int64) +
                       tl.arange(0, 1)[None, :].to(tl.int64))
    offsets_scale_b = (offsets_scale_bn[:, None].to(tl.int64) +
                       tl.arange(0, 1)[None, :].to(tl.int64))
    # NOTE(review): a reviewer asked whether this could be simplified by
    # keeping the scale offsets 1-D (broadcasting at load time) and leaving
    # them as int32 since they are vectors; the author replied that it was
    # simplified in a later revision.

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = scale_a_ptr + offsets_scale_a
    scale_b_ptrs = scale_b_ptr + offsets_scale_b

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # offsets_k is advanced at the end of each iteration, so this mask
        # only removes lanes in the final, partial K block.
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        # NOTE(review): no `other=` value is given for the masked-out lanes
        # of these two loads, yet those lanes still take part in tl.dot.
        # Confirm against the tl.load documentation that the K-tail padding
        # cannot perturb the accumulator, or pass an explicit zero.
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        offsets_k += BLOCK_SIZE_K
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply scale at end.
    # NOTE(review): the length-1 arange is indexed with [:, None] here but
    # with [None, :] for scale_b below; both give shape (1, 1), so the
    # difference looks cosmetic.
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    scale_a = tl.load(scale_a_ptrs, masks_scale_a)
    # Need to broadcast to the appropriate size, if scale_a is already
    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
    # for scale_b below.
    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
    accumulator = scale_a * accumulator.to(tl.float32)

    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    scale_b = tl.load(scale_b_ptrs, masks_scale_b)
    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
    # scale_b is held as a column, so it is transposed to scale per column.
    accumulator = scale_b.T * accumulator.to(tl.float32)

    # Convert to output format.
    c = accumulator.to(c_ptr.type.element_ty)

    # Add bias, it's already in output format, so add it after conversion.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    # Save output
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    # NOTE(review): these two casts look redundant, since the aranges above
    # were already converted to int64.
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = (c_ptr + stride_cm * offs_cm[:, None] +
              stride_cn * offs_cn[None, :])
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)
|
||||||
|
||||||
# input - [M, K] | ||||||
# weight - [K, N] | ||||||
def scaled_mm_triton(input: torch.Tensor,
                     weight: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: Type[torch.dtype],
                     bias: Optional[torch.Tensor] = None,
                     block_size_m: int = 32,
                     block_size_n: int = 32,
                     block_size_k: int = 32) -> torch.Tensor:
    """Launch ``scaled_mm_kernel`` to compute a scaled matrix product.

    Args:
        input: Left operand of shape (M, K).
        weight: Right operand of shape (K, N); same dtype as ``input``.
        scale_a: Floating-point scale of shape (1, 1) or (M, 1).
        scale_b: Floating-point scale of shape (1, 1) or (N, 1); same dtype
            as ``scale_a``.
        out_dtype: Floating-point dtype of the returned tensor.
        bias: Optional floating-point tensor holding exactly N elements.
        block_size_m: Tile extent along M handed to the kernel.
        block_size_n: Tile extent along N handed to the kernel.
        block_size_k: Tile extent along K handed to the kernel.

    Returns:
        A new (M, N) tensor of dtype ``out_dtype`` on ``input``'s device.

    Raises:
        AssertionError: If any of the shape or dtype requirements above is
            violated.
    """
    # Review nit on this function: a rename was suggested and marked done;
    # the proposed name is not visible here, so the name is left unchanged.
    M, K = input.shape
    N = weight.shape[1]

    assert N > 0 and K > 0 and M > 0
    assert weight.shape[0] == K
    assert input.dtype == weight.dtype
    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape in ((1, 1), (M, 1))
    assert scale_b.shape in ((1, 1), (N, 1))
    # Ask the dtype directly instead of allocating a tensor to probe it
    # (suggested in review).
    assert out_dtype.is_floating_point
    assert bias is None or bias.is_floating_point()
    # The kernel loads bias elements [0, N) by plain element index, so a
    # bias of any other size would be read out of bounds or misapplied.
    assert bias is None or bias.numel() == N

    # The kernel indexes the scales and the bias with unit stride; make sure
    # a strided view cannot silently produce wrong results.  These calls are
    # no-ops for tensors that are already contiguous.
    scale_a = scale_a.contiguous()
    scale_b = scale_b.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    # One program per output tile, laid out along a single grid axis.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(
        N, META['BLOCK_SIZE_N']), )

    result = torch.empty((M, N), dtype=out_dtype, device=input.device)

    # A per-tensor scale is loaded as a single element; otherwise one scale
    # element is loaded per row/column of the tile.
    block_size_sa = 1 if scale_a.numel() == 1 else block_size_m
    block_size_sb = 1 if scale_b.numel() == 1 else block_size_n

    # NOTE(review): this may clone an operand on every call.  A reviewer
    # suggested preparing the weights once up front, or replacing this with
    # an assert/warning so the cost is not a silent performance regression;
    # the author later switched to an assert.  The clone is kept here so
    # existing callers keep working.
    input = prepare_matrix_for_triton(input)
    weight = prepare_matrix_for_triton(weight)

    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32

    # A = input, B = weight, C = result
    # A = M x K, B = K x N, C = M x N
    scaled_mm_kernel[grid](input,
                           weight,
                           scale_a,
                           scale_b,
                           result,
                           bias,
                           M,
                           N,
                           K,
                           input.stride(0),
                           input.stride(1),
                           weight.stride(0),
                           weight.stride(1),
                           result.stride(0),
                           result.stride(1),
                           accumulator_dtype,
                           BLOCK_SIZE_M=block_size_m,
                           BLOCK_SIZE_N=block_size_n,
                           BLOCK_SIZE_K=block_size_k,
                           BLOCK_SIZE_SCALE_A=block_size_sa,
                           BLOCK_SIZE_SCALE_B=block_size_sb)

    # ``result`` was allocated with ``out_dtype``; no further cast is needed.
    return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does
scaled_mm_triton
support fp8_e4m3_fnuz
? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, scaled_mm_triton supports pretty much everything, as long as the hardware and Triton support it. It's just that the tests will take super long to run if I run them on all types. The fp8 types aren't supported uniformly, e.g. e5m2 seems to only have some limited support on AMD hardware right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, could you expand the test coverage to cover fp8? vLLM doesn't use e5m2 for linear layers currently, but will use
fp8_e4m3
extensively. Ideally we would detect whether we're on a CUDA or ROCm system and test using fp8_e4m3_fn
or fp8_e4m3_fnuz
accordingly. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tests cover fp8 now.