Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Kernel ] FP8 Dynamic-Per-Token Quant Kernel #6511

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
Expand Down
144 changes: 124 additions & 20 deletions csrc/quantization/fp8/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "../../reduction_utils.cuh"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
Expand Down Expand Up @@ -88,25 +90,48 @@ typedef struct __align__(4) {
float8x4_t;

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);

// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
int const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

#pragma unroll 4
for (int i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}

return absmax_val;
}

template <typename scalar_t>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const inverted_scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
const vec4_t<scalar_t>* vectorized_in =
reinterpret_cast<const vec4_t<scalar_t>*>(input);
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int num_vec_elems = num_elems >> 2;
int const num_vec_elems = num_elems >> 2;

#pragma unroll 4
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
for (int i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

Expand All @@ -118,17 +143,74 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems;
i += blockDim.x * gridDim.x) {
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;

// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);

scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;

scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];

// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;

float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
}
}

float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
}
__syncthreads();

float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
if (can_vectorize) {
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
}
}
}

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
Expand All @@ -144,9 +226,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}

void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
Expand All @@ -163,3 +245,25 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale.data_ptr<float>(), num_elems);
});
}

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(), hidden_size);
});
}
10 changes: 9 additions & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

// Compute FP8 quantized tensor and scaling factor.
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
"scale) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);

// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops.def(
Expand Down
56 changes: 56 additions & 0 deletions tests/kernels/quant_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Tuple, Union

import torch


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')

def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype) \
-> Tuple[torch.tensor, torch.tensor]:

assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)

# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.

# Compute scales
x_token_max, _ = x.abs().max(dim=-1)
x_token_max = as_float32_tensor(x_token_max)
scales = (x_token_max / qtype_max)[:, None]

# Quant
iscales = (qtype_max / x_token_max)[:, None]
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)

return torch_out, scales


# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:

fp8_traits = torch.finfo(torch.float8_e4m3fn)
fp8_max = as_float32_tensor(fp8_traits.max)
one = as_float32_tensor(1.0)

# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.

x_max = as_float32_tensor(x.abs().max())
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
return ref_out, ref_scale
54 changes: 54 additions & 0 deletions tests/kernels/test_fp8_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans

ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)

assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x)

assert torch.allclose(ref_scale, ops_scale)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
26 changes: 10 additions & 16 deletions tests/kernels/test_int8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly refactors in this file!

# ruff: noqa: F401
import vllm._C
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
Expand All @@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

x_token_max, _ = x.max(dim=1)
x_token_max = x_token_max.to(dtype=torch.float32)
scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
dtype=torch.float32)
torch_out = (x / scales).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)

ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
# reference
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
# kernel
ops_out, ops_scales = scaled_int8_quant(x)

assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out,
assert torch.allclose(ops_scales, ref_scales)
assert torch.allclose(ops_out, ref_out,
atol=1) # big atol to account for rounding errors


Expand All @@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale = torch.tensor([scale], dtype=torch.float32, device="cuda")

out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
out2, _ = scaled_int8_quant(x, scale)

torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors
Loading
Loading