Support INT4 Dequant onto GPU for Seq INT TBE look up (#3584)
Summary:

Seq INT4 -> INT4 STBE lookup is supported in the diff stack: https://www.internalfb.com/diff/D61305978.

This diff supports:

1. Dequantization of the INT4 -> INT4 STBE lookup output on CUDA for all floating-point output types (FP32/FP16/BF16)
2. Extending the dequantization of the INT4 -> INT4 STBE lookup output on CPU to BF16

For the GPU path, the main gap is handling dequantization when the scale and bias of the INT4-quantized tensor are stored at the front of each row rather than at the end. For the CPU path, we only need to add BF16 dequantization based on the requested output dtype.

This will enable us to reduce both the network overhead to the remote embedding server and the D2H data transfer overhead on the GPU host.
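
A rough usage sketch (not part of this diff; the op name and keyword arguments follow the schemas registered below, and output_dtype=0 selects FP32 as in the updated test):

import torch
import fbgemm_gpu  # noqa: F401  (loads and registers the fbgemm operators)

def dequant_seq_int4_cpu(fc2: torch.Tensor) -> torch.Tensor:
    # fc2 is assumed to be the torch.quint4x2 output of a sequence INT4 TBE
    # lookup, i.e. int4-quantized rows with the half scale/bias at the FRONT.
    # output_dtype=0 selects FP32, matching the updated test in this diff.
    return torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(
        fc2.cpu(), bit_rate=4, output_dtype=0
    )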

Differential Revision: D68187234
faran928 authored and facebook-github-bot committed Jan 17, 2025
1 parent 3e0db25 commit 867b7f7
Showing 9 changed files with 186 additions and 55 deletions.
3 changes: 2 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -432,7 +432,8 @@ at::Tensor fusednbitrowwise_to_half_cpu(
at::Tensor fusednbitrowwise_to_float_or_half_cpu(
const at::Tensor& input,
const int64_t bit_rate,
const int64_t output_dtype);
const int64_t output_dtype,
const bool scale_bias_last);

at::Tensor quantize_mx_cuda(
const at::Tensor& input,
58 changes: 39 additions & 19 deletions fbgemm_gpu/src/quantize_ops/quantize_fused_nbit_rowwise.cu
@@ -74,7 +74,7 @@ __global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
}

// Fused 4/2-bit rowwise -> FP32/FP16 kernel
template <typename output_t>
template <typename output_t, bool scale_bias_last>
__global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
const int bit_rate,
const std::uint8_t* input,
@@ -83,7 +83,6 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
output_t* const output) {
const int num_elem_per_byte = 8 / bit_rate;
const int output_columns = (ncols - 2 * sizeof(__half)) * num_elem_per_byte;

int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
const int row_incre = blockDim.y * gridDim.y;
@@ -92,9 +91,14 @@ __global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
const std::uint8_t* input_row = input + row * ncols;
const __half* input_row_scale_bias = reinterpret_cast<const __half*>(
input_row +
(output_columns + num_elem_per_byte - 1) / num_elem_per_byte);
(!scale_bias_last
? 0
: (output_columns + num_elem_per_byte - 1) / num_elem_per_byte));
float scale = __half2float(input_row_scale_bias[0]);
float bias = __half2float(input_row_scale_bias[1]);
if constexpr (!scale_bias_last) {
input_row += 2 * sizeof(__half);
}
output_t* output_row = output + row * output_columns;

std::uint8_t quantized = input_row[col / num_elem_per_byte];
@@ -215,7 +219,8 @@ DLL_PUBLIC Tensor _single_or_half_precision_to_fusednbitrowwise_gpu(
template <typename output_t>
Tensor _fusednbitrowwise_to_float_gpu_t(
const Tensor& input,
const int64_t bit_rate) {
const int64_t bit_rate,
const bool scale_bias_last) {
TENSOR_ON_CUDA_GPU(input);
TENSOR_NDIM_EQUALS(input, 2);
CUDA_DEVICE_GUARD(input);
@@ -245,7 +250,9 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
{nrows, output_columns}, // 2 = sizeof(bfloat16)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(false, "Unsupported output dtype");
TORCH_CHECK(
false,
"Unsupported output dtype within _fusednbitrowwise_to_float_gpu_t");
}
if (nrows == 0 || output_columns == 0) {
@@ -260,18 +267,25 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
const dim3 gridDim(gridDim_x, gridDim_y);
#define DEQUANT_LAUNCH_NBIT(scale_bias_last) \
_fusednbitrowwise_to_float_cuda_kernel<scalar_t, scale_bias_last> \
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>( \
bit_rate, \
input.data_ptr<std::uint8_t>(), \
nrows, \
ncols, \
output.data_ptr<scalar_t>())
FBGEMM_DISPATCH_FLOATING_TYPES(
output.scalar_type(), "fusednbitrowwise_to_float_cuda_kernel", [&] {
_fusednbitrowwise_to_float_cuda_kernel<scalar_t>
<<<gridDim, blockDim, 0, at::cuda::getCurrentCUDAStream()>>>(
bit_rate,
input.data_ptr<uint8_t>(),
nrows,
ncols,
output.data_ptr<scalar_t>());
if (scale_bias_last) {
DEQUANT_LAUNCH_NBIT(true);
} else {
DEQUANT_LAUNCH_NBIT(false);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
#undef DEQUANT_LAUNCH_NBIT
return output;
}
@@ -286,7 +300,8 @@ Tensor _fusednbitrowwise_to_float_gpu_t(
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
const at::Tensor& input,
const int64_t bit_rate) {
return _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
return _fusednbitrowwise_to_float_gpu_t<float>(
input, bit_rate, true /* scale_bias_last */);
}
/// @ingroup quantize-ops-cuda
@@ -301,7 +316,8 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_float_gpu(
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
const at::Tensor& input,
const int64_t bit_rate) {
return _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
return _fusednbitrowwise_to_float_gpu_t<at::Half>(
input, bit_rate, true /* scale_bias_last */);
}
/// @ingroup quantize-ops-cuda
@@ -321,19 +337,23 @@ DLL_PUBLIC at::Tensor _fusednbitrowwise_to_half_gpu(
DLL_PUBLIC at::Tensor _fusednbitrowwise_to_single_or_half_precision_gpu(
const at::Tensor& input,
const int64_t bit_rate,
const int64_t output_dtype) {
const int64_t output_dtype,
const bool scale_bias_last) {
Tensor output;
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
output = _fusednbitrowwise_to_float_gpu_t<float>(input, bit_rate);
output = _fusednbitrowwise_to_float_gpu_t<float>(
input, bit_rate, scale_bias_last);
break;
case SparseType::FP16:
output = _fusednbitrowwise_to_float_gpu_t<at::Half>(input, bit_rate);
output = _fusednbitrowwise_to_float_gpu_t<at::Half>(
input, bit_rate, scale_bias_last);
break;
case SparseType::BF16:
output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(input, bit_rate);
output = _fusednbitrowwise_to_float_gpu_t<at::BFloat16>(
input, bit_rate, scale_bias_last);
break;
default:
TORCH_CHECK(false);
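For illustration only (not part of the diff): a minimal Python sketch of the new GPU dequant path, assuming the FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf schema below dispatches to _fusednbitrowwise_to_single_or_half_precision_gpu for CUDA tensors and that output_dtype=1 means FP16 per the SparseType enum.

import torch
import fbgemm_gpu  # noqa: F401

def dequant_front_layout_gpu(packed: torch.Tensor) -> torch.Tensor:
    # packed is assumed to be a 2D uint8 tensor of int4-quantized rows whose
    # half scale/bias pair is stored at the FRONT of each row, hence
    # scale_bias_last=False; output_dtype=1 is assumed to mean FP16.
    return torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
        packed.cuda(), bit_rate=4, output_dtype=1, scale_bias_last=False
    )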
67 changes: 54 additions & 13 deletions fbgemm_gpu/src/quantize_ops/quantize_ops_cpu.cpp
@@ -150,7 +150,9 @@ Tensor _fusednbitrowwise_to_float_cpu(
return output;
}

Tensor _fusednbitrowwise_sbfront_to_float_cpu(
// Both float16 and bfloat16 are backed by the same underlying type, uint16_t
template <typename output_t>
Tensor _fusednbitrowwise_sbfront_to_float_or_half_cpu(
const Tensor& input,
const int64_t bit_rate) {
TENSOR_ON_CPU(input);
@@ -165,15 +167,36 @@ Tensor _fusednbitrowwise_sbfront_to_float_cpu(
(ncols - 2 * sizeof(at::Half)) * num_elem_per_byte;

Tensor output;
output = at::empty(
{nrows, output_columns}, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
if (std::is_same<output_t, float>::value) {
output = at::empty(
{nrows, output_columns}, // 4 = sizeof(float)
input.options().dtype(at::kFloat));
} else if (std::is_same<output_t, at::Half>::value) {
output = at::empty(
{nrows, output_columns}, // 2 = sizeof(half)
input.options().dtype(at::kHalf));
} else if (std::is_same<output_t, at::BFloat16>::value) {
output = at::empty(
{nrows, output_columns}, // 2 = sizeof(bfloat16)
input.options().dtype(at::kBFloat16));
} else {
TORCH_CHECK(
false,
"Unsupported output dtype for _fusednbitrowwise_sbfront_to_float_or_half_cpu");
}

float* output_data = static_cast<float*>(
using output_ty = std::conditional_t<
std::is_same<output_t, float>::value,
float,
fbgemm::float16>;
output_ty* output_data = static_cast<output_ty*>(
output.data_ptr()); // output.data_ptr<output_t>(); -> Yields
// unresolved data_ptr symbol.

fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<float>(
constexpr bool is_float16_bf16 = std::is_same<output_t, at::BFloat16>::value;
fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef<
output_ty,
is_float16_bf16>(
bit_rate,
input.data_ptr<uint8_t>(),
nrows,
@@ -311,7 +334,7 @@ Tensor fusednbitrowwise_to_float_cpu(

/// @ingroup quantize-data-cpu
/// @brief Dequantize int4/int2 rows with scale and bias stored in the front
/// into float32.
/// into float32/float16/BFloat16.
/// @param input Tensor of int4/int2 rows with scale and bias stored in the
/// front.
/// @param bit_rate Bit rate of each element. Should be 4 or 2.
@@ -323,8 +346,25 @@ Tensor fusednbitrowwise_to_float_cpu(
/// purpose because its kernel is reference implementation and not optimized.
Tensor fusednbitrowwise_sbfront_to_float_cpu(
const Tensor& input,
const int64_t bit_rate) {
return _fusednbitrowwise_sbfront_to_float_cpu(input, bit_rate);
const int64_t bit_rate,
const int64_t output_dtype) {
SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
switch (output_sparse_dtype) {
case SparseType::FP32:
return _fusednbitrowwise_sbfront_to_float_or_half_cpu<float>(
input, bit_rate);
break;
case SparseType::FP16:
return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::Half>(
input, bit_rate);
break;
case SparseType::BF16:
return _fusednbitrowwise_sbfront_to_float_or_half_cpu<at::BFloat16>(
input, bit_rate);
break;
default:
TORCH_CHECK(false);
}
}

/// @ingroup quantize-data-cpu
@@ -340,7 +380,8 @@ Tensor fusednbitrowwise_to_half_cpu(
Tensor fusednbitrowwise_to_float_or_half_cpu(
const Tensor& input,
const int64_t bit_rate,
const int64_t output_dtype) {
const int64_t output_dtype,
[[maybe_unused]] const bool scale_bias_last) {
Tensor output;

SparseType output_sparse_dtype = static_cast<SparseType>(output_dtype);
@@ -520,11 +561,11 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"FusedNBitRowwiseQuantizedSBHalfToFloat(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat(Tensor input, int bit_rate) -> Tensor");
"FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(Tensor input, int bit_rate, int output_dtype) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToHalf(Tensor input, int bit_rate) -> Tensor");
m.def(
"FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0) -> Tensor");
"FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(Tensor input, int bit_rate, int output_dtype=0, bool scale_bias_last=True) -> Tensor");
m.def(
"FloatToHFP8Quantized(Tensor input, int ebits, int exponent_bias, float max_pos) -> Tensor");
m.def(
@@ -542,7 +583,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {

TORCH_LIBRARY_IMPL(fbgemm, QuantizedCPU, m) {
DISPATCH_TO_QUANTIZED_CPU(
"FusedNBitRowwiseQuantizedSBHalfFrontToFloat",
"FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf",
fbgemm_gpu::fusednbitrowwise_sbfront_to_float_cpu);
}

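For illustration only (not part of the diff): a sketch of the new CPU BF16 output path added above, assuming BF16 maps to output_dtype=5 in the SparseType enum (only output_dtype=0, FP32, is exercised by the updated test in this diff).

import torch
import fbgemm_gpu  # noqa: F401

def check_bf16_matches_fp32(fc2: torch.Tensor) -> None:
    # fc2 is assumed to be the torch.quint4x2 STBE output (scale/bias in front).
    dequant = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf
    fp32 = dequant(fc2.cpu(), bit_rate=4, output_dtype=0)  # 0 = FP32
    bf16 = dequant(fc2.cpu(), bit_rate=4, output_dtype=5)  # 5 = BF16 (assumed)
    # BF16 keeps roughly 8 bits of mantissa, so compare with a loose tolerance.
    torch.testing.assert_close(bf16.float(), fp32, rtol=1.6e-2, atol=1.0e-2)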
3 changes: 2 additions & 1 deletion fbgemm_gpu/src/quantize_ops/quantize_ops_meta.cpp
@@ -72,7 +72,8 @@ Tensor FloatToFP8RowwiseQuantized_meta(const Tensor& input, bool forward) {
Tensor fusednbitrowwise_to_float_or_half_meta(
const Tensor& input,
const int64_t bit_rate,
const int64_t output_dtype) {
const int64_t output_dtype,
[[maybe_unused]] const bool scale_bias_last) {
const at::SymIntArrayRef input_sizes = input.sym_sizes();
const at::SymInt nrows = input_sizes[0];
// Here we want the number of bytes in a row
6 changes: 4 additions & 2 deletions fbgemm_gpu/test/tbe/inference/common.py
@@ -351,8 +351,10 @@ def execute_nbit_forward_( # noqa C901
f = torch.cat(fs, dim=0).view(-1, D)

if fc2.dtype == torch.quint4x2:
fc2_float = torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloat(
fc2.cpu(), bit_rate=4
fc2_float = (
torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf(
fc2.cpu(), bit_rate=4, output_dtype=0
)
)
else:
fc2_float = fc2.float()
11 changes: 10 additions & 1 deletion fbgemm_gpu/test/tbe/inference/failures_dict_fast.json
@@ -7,7 +7,8 @@
"fbgemm::FloatToHFP8Quantized": {},
"fbgemm::Fused8BitRowwiseQuantizedToFloat": {},
"fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf": {},
"fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloat": {},
"fbgemm::FusedNBitRowwiseQuantizedSBHalfFrontToFloatOrHalf": {},
"fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf": {},
"fbgemm::HFP8QuantizedToFloat": {},
"fbgemm::asynchronous_complete_cumsum": {},
"fbgemm::bounds_check_indices": {},
@@ -44,9 +45,17 @@
"comment": "",
"status": "xsuccess"
},
"NBitFowardTest.test_faketensor__test_nbit_forward_cpu_gpu_dequantize_parity": {
"comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
"status": "xfail"
},
"NBitFowardTest.test_faketensor__test_nbit_forward_cpu_seq_int4": {
"comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
"status": "xfail"
},
"NBitFowardTest.test_schema__test_nbit_forward_cpu_gpu_dequantize_parity": {
"comment": "this operator outputs torch.quint4x2 tensors which is not compatible with generate_opcheck_tests",
"status": "xfail"
}
},
"fbgemm::int_nbit_split_embedding_uvm_caching_codegen_lookup_function": {
