[ Kernel ] FP8 Dynamic-Per-Token Quant Kernel #6511

Merged
Changes from 9 commits
4 changes: 4 additions & 0 deletions csrc/ops.h
@@ -129,6 +129,10 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
100 changes: 84 additions & 16 deletions csrc/quantization/fp8/common.cu
@@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "../../reduction_utils.cuh"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -88,25 +90,20 @@ typedef struct __align__(4) {
float8x4_t;

template <typename scalar_t>
- __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
- const scalar_t* __restrict__ input,
- const float* __restrict__ scale,
- int64_t num_elems) {
- int tid = blockDim.x * blockIdx.x + threadIdx.x;
-
- // Invert the scale so that we can use multiplications to avoid expensive
- // division.
- const float inverted_scale = 1.0f / (*scale);
-
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const inverted_scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
- const vec4_t<scalar_t>* vectorized_in =
- reinterpret_cast<const vec4_t<scalar_t>*>(input);
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

- int num_vec_elems = num_elems >> 2;
int const num_vec_elems = num_elems >> 2;

#pragma unroll 4
- for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
for (int i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

@@ -118,12 +115,62 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
}

// Handle the remaining elements if num_elems is not divisible by 4
- for (int i = num_vec_elems * 4 + tid; i < num_elems;
- i += blockDim.x * gridDim.x) {
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;

// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);

scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, const int hidden_size,
bool const vectorize_conversions) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
float absmax_val = 0.0f;

for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(input[token_idx * hidden_size + i]);
absmax_val = max(absmax_val, fabs(x));
}

float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
}
__syncthreads();

float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
if (vectorize_conversions) {
scalar_t const* token_input = &input[token_idx * hidden_size];
c10::Float8_e4m3fn* token_output = &out[token_idx * hidden_size];
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = scaled_fp8_conversion(
input[token_idx * hidden_size + i], inverted_scale);
}
}
}

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
@@ -163,3 +210,24 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale.data_ptr<float>(), num_elems);
});
}

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scales) {
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
bool const vectorize_conversions =
(hidden_size % 4 == 0) && input.is_contiguous() && out.is_contiguous();

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(), hidden_size, vectorize_conversions);
});
}
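
For reference, the per-token scheme this kernel implements is simple to state in PyTorch: each token row gets its own scale, computed from that row's absolute maximum. The sketch below is illustrative only and is not part of this PR; FP8_MAX stands in for the FP8_E4M3_MAX constant used in the kernel, and it assumes every row has a nonzero absmax, as the kernel implicitly does.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def per_token_fp8_quant_sketch(x: torch.Tensor):
    # x: [num_tokens, hidden_size]; compute one scale per token row.
    x = x.float()
    absmax = x.abs().max(dim=-1, keepdim=True).values        # [num_tokens, 1]
    scales = absmax / FP8_MAX                                 # scale[token] = absmax / FP8_E4M3_MAX
    out = (x * (FP8_MAX / absmax)).clamp(-FP8_MAX, FP8_MAX)   # multiply by the inverted scale
    return out.to(torch.float8_e4m3fn), scales
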
10 changes: 9 additions & 1 deletion csrc/torch_bindings.cpp
@@ -175,12 +175,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

- // Compute FP8 quantized tensor and scaling factor.
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
"scale) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);

// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops.def(
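
In the schemas above, Tensor! marks an argument the op writes into, so callers preallocate out and scale. Purely as an illustration of that calling convention (the Python wrapper added to vllm/_custom_ops.py later in this diff does the same thing), a direct call through the op registry could look like:

import torch
import vllm._C  # noqa: F401 -- registers the _C custom ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda")
torch.ops._C.dynamic_per_token_scaled_fp8_quant(out, x, scales)
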
56 changes: 56 additions & 0 deletions tests/kernels/quant_utils.py
@@ -0,0 +1,56 @@
from typing import Tuple, Union

import torch


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')

def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype) \
-> Tuple[torch.tensor, torch.tensor]:

assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)

# For fp8, in order to match the CUDA kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent rounding
# errors.

# Compute scales
x_token_max, _ = x.abs().max(dim=-1)
x_token_max = as_float32_tensor(x_token_max)
scales = (x_token_max / qtype_max)[:, None]

# Quant
iscales = (qtype_max / x_token_max)[:, None]
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)

return torch_out, scales


# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:

fp8_traits = torch.finfo(torch.float8_e4m3fn)
fp8_max = as_float32_tensor(fp8_traits.max)
one = as_float32_tensor(1.0)

# For fp8, in order to match the CUDA kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent rounding
# errors.

x_max = as_float32_tensor(x.abs().max())
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
return ref_out, ref_scale
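
A quick worked example of the reference math above (hypothetical values, using the helper just defined): for a single token row [0.5, -2.0, 1.0], the row absmax is 2.0 and the FP8 E4M3 max is 448.0, so the scale is 2.0 / 448.0 and the quantized values are x * (448.0 / 2.0) = [112, -448, 224].

import torch
from tests.kernels.quant_utils import ref_dynamic_per_token_quant

x = torch.tensor([[0.5, -2.0, 1.0]], dtype=torch.float16, device="cuda")
out, scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
# scales == [[2.0 / 448.0]]; out == [[112., -448., 224.]] as float8_e4m3fn
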
54 changes: 54 additions & 0 deletions tests/kernels/test_fp8_quant.py
@@ -0,0 +1,54 @@
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans

ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)

assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x)

assert torch.allclose(ref_scale, ops_scale)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
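
The extra hidden sizes in range(1024, 1033) straddle the vectorization boundary: the launcher in common.cu only takes the vectorized path when hidden_size is a multiple of 4 and both tensors are contiguous, so these sizes exercise both code paths. A trivial illustration of which sizes qualify (mirroring just the divisibility part of that condition):

for hidden_size in range(1024, 1033):
    vectorized = hidden_size % 4 == 0  # the launcher also requires contiguous in/out
    print(hidden_size, "vectorized" if vectorized else "scalar")
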
26 changes: 10 additions & 16 deletions tests/kernels/test_int8_quant.py
@@ -3,6 +3,8 @@

Contributor Author:
Mostly refactors in this file!

# ruff: noqa: F401
import vllm._C
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
- int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

- x_token_max, _ = x.max(dim=1)
- x_token_max = x_token_max.to(dtype=torch.float32)
- scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
- dtype=torch.float32)
- torch_out = (x / scales).round().clamp(int8_traits.min,
- int8_traits.max).to(torch.int8)
-
- ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
- scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
- torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
# reference
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
# kernel
ops_out, ops_scales = scaled_int8_quant(x)

- assert torch.allclose(scales_out, scales)
- assert torch.allclose(torch_out, ops_out,
assert torch.allclose(ops_scales, ref_scales)
assert torch.allclose(ops_out, ref_out,
atol=1) # big atol to account for rounding errors


@@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale = torch.tensor([scale], dtype=torch.float32, device="cuda")

out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
- out2 = torch.empty_like(x, dtype=torch.int8)
- scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
out2, _ = scaled_int8_quant(x, scale)

- torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors
11 changes: 11 additions & 0 deletions vllm/_custom_ops.py
@@ -320,6 +320,17 @@ def scaled_fp8_quant(
return output, scale


def dynamic_per_token_scaled_fp8_quant(
input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Contributor Author:
To fit the existing interface, where all quantization schemes go through scaled_fp8_quant, this should probably also be merged into scaled_fp8_quant. How do we want to handle this?
I believe the choice between dynamic-per-token and dynamic-per-tensor quantization will come from the model itself, and we can plumb it down into the scaled_fp8_quant function. I'd appreciate any pointers on how to plumb this down. Thanks.
@robertgshaw2-neuralmagic @comaniac @tlrmchlsmth ?

Collaborator:
I think this looks good as is. I can handle the plumbing from here in a follow-up with some end-to-end model tests.

Collaborator:
Definitely prefer to have a unified API.


output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
return output, scales


# int8
def scaled_int8_quant(
input: torch.Tensor,
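
Following up on the unified-API discussion in the vllm/_custom_ops.py thread above: one possible shape for a merged entry point is sketched below. This is purely hypothetical (the function name, the use_per_token flag, and its default are assumptions, not something this PR adds); it only reuses the ops already registered in the bindings shown above.

from typing import Optional, Tuple

import torch

from vllm._custom_ops import dynamic_per_token_scaled_fp8_quant


def scaled_fp8_quant_unified(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None,
        use_per_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical wrapper: routes to the per-token kernel when requested,
    # otherwise keeps the existing static / dynamic per-tensor behavior.
    if scale is not None:
        # Static per-tensor scale supplied by the caller.
        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        torch.ops._C.static_scaled_fp8_quant(output, input, scale)
        return output, scale
    if use_per_token:
        return dynamic_per_token_scaled_fp8_quant(input)  # added in this PR
    # Dynamic per-tensor scale computed by the kernel.
    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
    scale = torch.zeros(1, dtype=torch.float32, device=input.device)
    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
    return output, scale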