[Kernel] Replaced blockReduce[...] functions with cub::BlockReduce (vllm-project#7233)

Co-authored-by: Michael Goin <[email protected]>
ProExpertProg and mgoin authored Aug 22, 2024
1 parent 3e051e6 commit 4d13eae
Showing 8 changed files with 237 additions and 116 deletions.
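For orientation: this commit replaces vLLM's hand-rolled blockReduceSum/blockReduceMax helpers from reduction_utils.cuh with CUB's cub::BlockReduce, which manages its own shared-memory temporary storage and returns the reduced value only on thread 0. The following minimal CUDA sketch is not part of the commit, and the kernel and variable names (scale_by_block_sum_kernel, s_sum, kBlockSize) are illustrative. It shows the pattern the diffs below adopt: accumulate a per-thread partial value, reduce it block-wide, then broadcast the result through a shared variable.

// Minimal sketch (not from this commit) of the cub::BlockReduce pattern that
// replaces the hand-rolled blockReduce* helpers. One block handles one row of
// `in` and rescales it by the row sum, mirroring how the RMSNorm kernels in
// the diffs broadcast s_variance.
#include <cub/cub.cuh>

template <int kBlockSize>
__global__ void scale_by_block_sum_kernel(const float* __restrict__ in,
                                          float* __restrict__ out, int n) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ float s_sum;

  // Each thread accumulates a strided partial sum over its block's row.
  float partial = 0.0f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    partial += in[blockIdx.x * n + i];
  }

  // blockDim.x is passed as the number of valid inputs because the launch may
  // use fewer threads than the kBlockSize template bound.
  float sum = BlockReduce(temp_storage).Reduce(partial, cub::Sum{}, blockDim.x);

  // The reduction result is only defined on thread 0, so broadcast it through
  // shared memory before every thread reads it.
  if (threadIdx.x == 0) {
    s_sum = sum;
  }
  __syncthreads();

  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    out[blockIdx.x * n + i] = in[blockIdx.x * n + i] / s_sum;
  }
}

A launch such as scale_by_block_sum_kernel<1024><<<num_rows, block_size>>>(in, out, n), with block_size <= 1024, matches the convention in the diffs below: the template argument is an upper bound on the block size, and blockDim.x carries the actual thread count into Reduce().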
4 changes: 2 additions & 2 deletions .buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.409
+    value: 0.419
   - name: "exact_match,flexible-extract"
-    value: 0.406
+    value: 0.416
 limit: 1000
 num_fewshot: 5
89 changes: 89 additions & 0 deletions benchmarks/kernels/benchmark_layernorm.py
@@ -0,0 +1,89 @@
import random
import time

import torch

from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         add_residual: bool,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    layer = RMSNorm(hidden_size).to(dtype=dtype)
    layer.weight.data.normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    residual = torch.randn_like(x) * scale if add_residual else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            layer(x, residual)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':
    parser = FlexibleArgumentParser(
        description="Benchmark the layernorm kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--add-residual", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         add_residual=args.add_residual,
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
103 changes: 103 additions & 0 deletions benchmarks/kernels/benchmark_quant.py
@@ -0,0 +1,103 @@
import random
import time

import torch

from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser


@torch.inference_mode()
def main(num_tokens: int,
         hidden_size: int,
         static_scale: bool,
         quant_dtype: torch.dtype,
         dtype: torch.dtype,
         seed: int = 0,
         do_profile: bool = False,
         num_warmup_iters: int = 5,
         num_iters: int = 100) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device("cuda")

    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        for _ in range(num_iters):
            if quant_dtype == torch.int8:
                ops.scaled_int8_quant(x, scale)
            else:
                ops.scaled_fp8_quant(x, scale)
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=num_warmup_iters, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=num_iters, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")


if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError(f"Unsupported dtype: {dt}")

    parser = FlexibleArgumentParser(
        description="Benchmark the quantization (fp8 or int8) kernel.")
    parser.add_argument("--num-tokens", type=int, default=4096)
    parser.add_argument("--hidden-size", type=int, default=8192)
    parser.add_argument("--static-scale", action="store_true")
    parser.add_argument("--quant-dtype",
                        type=str,
                        choices=["fp8", "int8"],
                        default="int8")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument("--num-warmup-iters", type=int, default=5)
    parser.add_argument("--num-iters",
                        type=int,
                        default=100,
                        help="Number of benchmark iterations. "
                        "If --profile is set, this number is ignored")

    args = parser.parse_args()
    print(args)

    main(num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         static_scale=args.static_scale,
         quant_dtype=to_torch_dtype(args.quant_dtype),
         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
         seed=args.seed,
         do_profile=args.profile,
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters)
33 changes: 19 additions & 14 deletions csrc/layernorm_kernels.cu
@@ -3,13 +3,16 @@
 #include <c10/cuda/CUDAGuard.h>
 
 #include "dispatch_utils.h"
-#include "reduction_utils.cuh"
 #ifndef USE_ROCM
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
 #else
 #include <hip/hip_bf16.h>
 #include <hip/hip_fp16.h>
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
 
 using __nv_bfloat16 = __hip_bfloat16;
 using __nv_bfloat162 = __hip_bfloat162;
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
     const float x = (float)input[blockIdx.x * hidden_size + idx];
     variance += x * x;
   }
-  variance = blockReduceSum<float>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
     variance += temp.sum_squares();
     residual_v[id] = temp;
   }
-  /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256) {
-    variance = blockReduceSum<float, 1024>(variance);
-  } else
-    variance = blockReduceSum<float, 256>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
     variance += x * x;
     residual[blockIdx.x * hidden_size + idx] = z;
   }
-  /* Keep the following if-else block in sync with the
-     calculation of max_block_size in fused_add_rms_norm */
-  if (num_tokens < 256) {
-    variance = blockReduceSum<float, 1024>(variance);
-  } else
-    variance = blockReduceSum<float, 256>(variance);
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+
   if (threadIdx.x == 0) {
     s_variance = rsqrtf(variance / hidden_size + epsilon);
   }
14 changes: 12 additions & 2 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -3,7 +3,14 @@
 #include <cmath>
 
 #include "../../dispatch_utils.h"
-#include "../../reduction_utils.cuh"
 
+#ifndef USE_ROCM
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif
+
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel(
     absmax_val = val > absmax_val ? val : absmax_val;
   }
 
-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float block_absmax_val;
   if (tid == 0) {
     block_absmax_val = block_absmax_val_maybe;
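The dynamic-quantization kernels touched above and in csrc/quantization/fp8/common.cu below follow the same structure with a max reduction: each thread tracks a local absolute maximum, cub::BlockReduce combines the partials with cub::Max{}, and thread 0 turns the block-wide absmax into a scale that the whole block then reads back. The sketch below is an illustrative reconstruction of that flow, not the commit's actual kernel; the names (dynamic_int8_quant_sketch, scale_out) and the epsilon guard are assumptions added for the example.

// Schematic per-row int8 dynamic quantization (illustrative only): reduce the
// absolute maximum across the block, derive a scale on thread 0, broadcast it,
// then quantize every element with that scale.
#include <cstdint>
#include <cub/cub.cuh>

template <int kBlockSize>
__global__ void dynamic_int8_quant_sketch(const float* __restrict__ in,
                                          int8_t* __restrict__ out,
                                          float* __restrict__ scale_out,
                                          int hidden_size) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  __shared__ float s_scale;

  const float* row = in + static_cast<int64_t>(blockIdx.x) * hidden_size;
  int8_t* out_row = out + static_cast<int64_t>(blockIdx.x) * hidden_size;

  // Per-thread absolute maximum over a strided slice of the row.
  float absmax = 0.0f;
  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    absmax = fmaxf(absmax, fabsf(row[i]));
  }

  float block_absmax =
      BlockReduce(temp_storage).Reduce(absmax, cub::Max{}, blockDim.x);

  if (threadIdx.x == 0) {
    // Guard against an all-zero row (epsilon chosen arbitrarily for the sketch).
    s_scale = fmaxf(block_absmax, 1e-10f) / 127.0f;
    scale_out[blockIdx.x] = s_scale;
  }
  __syncthreads();

  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    out_row[i] = static_cast<int8_t>(roundf(row[i] / s_scale));
  }
}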
13 changes: 11 additions & 2 deletions csrc/quantization/fp8/common.cu
@@ -7,7 +7,13 @@
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
 
-#include "../../reduction_utils.cuh"
+#ifndef USE_ROCM
+#include <cub/util_type.cuh>
+#include <cub/cub.cuh>
+#else
+#include <hipcub/util_type.hpp>
+#include <hipcub/hipcub.hpp>
+#endif
 
 #ifndef USE_ROCM
 using FP8_TYPE = c10::Float8_e4m3fn;
@@ -215,7 +221,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     }
   }
 
-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
     if (scale_ub) {