Skip to content

Commit

Permalink
[ Kernel ] FP8 Dynamic-Per-Token Quant Kernel (vllm-project#6511)
Browse files Browse the repository at this point in the history
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
  • Loading branch information
2 people authored and fialhocoelho committed Jul 19, 2024
1 parent d9124a4 commit 6834854
Show file tree
Hide file tree
Showing 7 changed files with 271 additions and 40 deletions.
10 changes: 7 additions & 3 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,

void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);

void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);

void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scale);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
Expand Down
144 changes: 124 additions & 20 deletions csrc/quantization/fp8/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "../../reduction_utils.cuh"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
Expand Down Expand Up @@ -88,25 +90,48 @@ typedef struct __align__(4) {
float8x4_t;

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);

// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
int const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

#pragma unroll 4
for (int i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}

return absmax_val;
}

template <typename scalar_t>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const inverted_scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
const vec4_t<scalar_t>* vectorized_in =
reinterpret_cast<const vec4_t<scalar_t>*>(input);
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int num_vec_elems = num_elems >> 2;
int const num_vec_elems = num_elems >> 2;

#pragma unroll 4
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
for (int i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

Expand All @@ -118,17 +143,74 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems;
i += blockDim.x * gridDim.x) {
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
const float* __restrict__ scale,
int64_t num_elems) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;

// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);

scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;

scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];

// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;

float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
}
}

float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
}
__syncthreads();

float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
if (can_vectorize) {
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
}
}
}

} // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
Expand All @@ -144,9 +226,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}

void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
Expand All @@ -163,3 +245,25 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale.data_ptr<float>(), num_elems);
});
}

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());

int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));

const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(), hidden_size);
});
}
10 changes: 9 additions & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

// Compute FP8 quantized tensor and scaling factor.
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
"scale) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);

// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
ops.def(
Expand Down
56 changes: 56 additions & 0 deletions tests/kernels/quant_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Tuple, Union

import torch


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')

def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype) \
-> Tuple[torch.tensor, torch.tensor]:

assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)

# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.

# Compute scales
x_token_max, _ = x.abs().max(dim=-1)
x_token_max = as_float32_tensor(x_token_max)
scales = (x_token_max / qtype_max)[:, None]

# Quant
iscales = (qtype_max / x_token_max)[:, None]
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)

return torch_out, scales


# The int8 version is very similar. Incorporate the int8 version, like in
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
-> Tuple[torch.tensor, torch.tensor]:

fp8_traits = torch.finfo(torch.float8_e4m3fn)
fp8_max = as_float32_tensor(fp8_traits.max)
one = as_float32_tensor(1.0)

# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
# rounding errors.

x_max = as_float32_tensor(x.abs().max())
ref_scale = x_max / fp8_max
ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
return ref_out, ref_scale
54 changes: 54 additions & 0 deletions tests/kernels/test_fp8_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans

ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)

assert torch.allclose(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")

ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x)

assert torch.allclose(ref_scale, ops_scale)
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
26 changes: 10 additions & 16 deletions tests/kernels/test_int8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

# ruff: noqa: F401
import vllm._C
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
from vllm._custom_ops import scaled_int8_quant

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
Expand All @@ -21,23 +23,16 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

x_token_max, _ = x.max(dim=1)
x_token_max = x_token_max.to(dtype=torch.float32)
scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
dtype=torch.float32)
torch_out = (x / scales).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)

ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out)
# reference
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
# kernel
ops_out, ops_scales = scaled_int8_quant(x)

assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out,
assert torch.allclose(ops_scales, ref_scales)
assert torch.allclose(ops_out, ref_out,
atol=1) # big atol to account for rounding errors


Expand All @@ -55,12 +50,11 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale = torch.tensor([scale], dtype=torch.float32, device="cuda")

out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
out2, _ = scaled_int8_quant(x, scale)

torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument)
assert torch.allclose(out1, out2,
atol=1) # big atol to account for rounding errors
Loading

0 comments on commit 6834854

Please sign in to comment.