[ Kernel ] FP8 Dynamic Per Token Quant - Add scale_ub #6593

Merged. 2 commits merged on Jul 20, 2024.
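This PR threads an optional scale_ub (a per-token upper bound on the dynamic quantization scale) through the FP8 dynamic per-token quantization path, from the Python entry point down to the CUDA kernel. A minimal usage sketch of the updated entry point, assuming a CUDA build of vLLM with this change; the tensor shapes and bound value are illustrative:

import torch
from vllm import _custom_ops as ops

x = torch.rand(16, 1024, dtype=torch.float16, device="cuda") + 1e-6  # avoid nans

# Optional upper bound applied to each token's absmax before the scale is
# computed; pass None to keep the previous (unbounded) behavior.
scale_ub = torch.tensor(0.5, dtype=torch.float32, device="cuda")

out, scales = ops.scaled_fp8_quant(x,
                                   scale_ub=scale_ub,
                                   use_per_token_if_dynamic=True)
# out is a float8_e4m3fn tensor with the same shape as x; scales holds one
# float32 scale per token.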
6 changes: 3 additions & 3 deletions csrc/ops.h
@@ -134,9 +134,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor& scale);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
c10::optional<torch::Tensor> const& scale_ub);

void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
73 changes: 48 additions & 25 deletions csrc/quantization/fp8/common.cu
@@ -23,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float inverted_scale) {
float x = static_cast<float>(val) * inverted_scale;
float const val, float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
}
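The conversion helper is now templated on is_scale_inverted: the per-tensor kernel keeps passing a precomputed reciprocal scale and multiplying, while the per-token path passes the scale itself and divides (to match the FBGEMM implementation, per the comment further down in this file). A rough Python equivalent of the saturating conversion, assuming torch's float8_e4m3fn dtype; the helper name is illustrative:

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def ref_scaled_fp8_conversion(val: torch.Tensor, scale: torch.Tensor,
                              is_scale_inverted: bool) -> torch.Tensor:
    # Multiply by a reciprocal scale or divide by the scale directly, then
    # saturate to the representable fp8 e4m3 range before casting.
    x = val * scale if is_scale_inverted else val / scale
    return x.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)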
@@ -117,10 +123,10 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
return absmax_val;
}

template <typename scalar_t>
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const inverted_scale,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
@@ -135,16 +141,21 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}

@@ -158,15 +169,17 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);

scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
blockDim.x * gridDim.x);
scaled_fp8_conversion_vec<scalar_t, true>(
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, const int hidden_size) {
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

int const tid = threadIdx.x;
int const token_idx = blockIdx.x;

@@ -188,20 +201,27 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}

float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
__shared__ float token_scale;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();

float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
// Note that we don't use inverted scales so we can match FBGemm impl.
if (can_vectorize) {
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
hidden_size, tid, blockDim.x);
scaled_fp8_conversion_vec<scalar_t, false>(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
token_output[i] = scaled_fp8_conversion<false>(
static_cast<float>(token_input[i]), token_scale);
}
}
}
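The per-token scale is the block-reduced absmax, optionally clamped from above by *scale_ub, divided by FP8_E4M3_MAX, and clamped from below by min_scaling_factor = 1 / (FP8_E4M3_MAX * 512); the elements are then divided by that scale (no reciprocal) so the outputs match FBGEMM. A per-token sketch of the same arithmetic in Python, mirroring the reference helper in tests/kernels/quant_utils.py; the function name is illustrative:

from typing import Optional

import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
MIN_SCALING_FACTOR = 1.0 / (FP8_E4M3_MAX * 512.0)

def ref_token_scales(x: torch.Tensor,
                     scale_ub: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Per-token absmax over the hidden dimension.
    absmax = x.abs().amax(dim=-1, keepdim=True).float()
    if scale_ub is not None:
        absmax = absmax.clamp(max=scale_ub)  # apply the upper bound
    # Convert to a scale and enforce the kernel's lower bound.
    return (absmax / FP8_E4M3_MAX).clamp(min=MIN_SCALING_FACTOR)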
@@ -246,9 +266,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}

void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales) {
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());

@@ -264,6 +285,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(), hidden_size);
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
});
}
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
@@ -188,7 +188,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
"scale) -> "
"scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
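The schema change makes scale_ub an optional tensor argument (Tensor? in the op signature), so Python callers pass either a float32 CUDA tensor or None. A hedged sketch of the out-variant call that vllm/_custom_ops.py makes, assuming the _C extension is already loaded; the buffer shapes and bound are illustrative:

import torch
import vllm  # noqa: F401  (loads the _C extension that registers the op)

num_tokens, hidden_size = 8, 1024
x = torch.rand(num_tokens, hidden_size, dtype=torch.float16, device="cuda")

# Pre-allocated output and per-token scale buffers; the op writes into both.
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
scales = torch.empty((num_tokens, 1), dtype=torch.float32, device="cuda")
scale_ub = torch.tensor(0.5, dtype=torch.float32, device="cuda")  # or None

torch.ops._C.dynamic_per_token_scaled_fp8_quant(out, x, scales, scale_ub)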
30 changes: 23 additions & 7 deletions tests/kernels/quant_utils.py
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Optional, Tuple, Union

import torch

@@ -7,13 +7,19 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')

def ref_dynamic_per_token_quant(x: torch.tensor,
quant_dtype: torch.dtype) \
quant_dtype: torch.dtype,
scale_ub: Optional[torch.tensor] = None) \
-> Tuple[torch.tensor, torch.tensor]:

assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn

qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
else torch.finfo(quant_dtype)
qtype_max = as_float32_tensor(qtype_traits.max)
s_1 = as_float32_tensor(1.0)
s_512 = as_float32_tensor(512.0)

# For fp8, in order to match the cuda kernel output, we have to do exactly
# the same operations as in the corresponding fp8 kernel to prevent
@@ -22,14 +28,24 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# Compute scales
x_token_max, _ = x.abs().max(dim=-1)
x_token_max = as_float32_tensor(x_token_max)
if scale_ub is not None:
x_token_max = x_token_max.clamp(max=scale_ub)
scales = (x_token_max / qtype_max)[:, None]

# Quant
iscales = (qtype_max / x_token_max)[:, None]
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
if quant_dtype == torch.int8:
iscales = as_float32_tensor(s_1 / scales)
torch_out = as_float32_tensor(x) * iscales
torch_out = torch_out.round()
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)
else:
assert quant_dtype == torch.float8_e4m3fn
min_scaling_factor = s_1 / (qtype_max * s_512)
scales = scales.clamp(min=min_scaling_factor)
torch_out = as_float32_tensor(x) / scales
torch_out = torch_out.clamp(qtype_traits.min,
qtype_traits.max).to(quant_dtype)

return torch_out, scales

11 changes: 9 additions & 2 deletions tests/kernels/test_fp8_quant.py
@@ -10,24 +10,31 @@
8193] # Arbitrary values for testing
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SCALE_UBS = [True, False]
SEEDS = [0]


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans

ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
if scale_ub else None
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)

assert torch.allclose(ref_scales, ops_scales)
3 changes: 2 additions & 1 deletion vllm/_custom_ops.py
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
batch_dim_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -336,7 +337,7 @@
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input, scale)
output, input, scale, scale_ub)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)