diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
index c457468902c98..0424586598391 100644
--- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
+++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.409
+    value: 0.419
   - name: "exact_match,flexible-extract"
-    value: 0.406
+    value: 0.416
 limit: 1000
 num_fewshot: 5
diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py
new file mode 100644
index 0000000000000..4947fda02e1cc
--- /dev/null
+++ b/benchmarks/kernels/benchmark_layernorm.py
@@ -0,0 +1,89 @@
+import random
+import time
+
+import torch
+
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
+
+
+@torch.inference_mode()
+def main(num_tokens: int,
+         hidden_size: int,
+         add_residual: bool,
+         dtype: torch.dtype,
+         seed: int = 0,
+         do_profile: bool = False,
+         num_warmup_iters: int = 5,
+         num_iters: int = 100) -> None:
+    random.seed(seed)
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device("cuda")
+
+    layer = RMSNorm(hidden_size).to(dtype=dtype)
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    x *= scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
+        torch.cuda.synchronize()
+        if profile:
+            torch.cuda.cudart().cudaProfilerStart()
+        start_time = time.perf_counter()
+
+        for _ in range(num_iters):
+            layer(x, residual)
+        torch.cuda.synchronize()
+
+        end_time = time.perf_counter()
+        if profile:
+            torch.cuda.cudart().cudaProfilerStop()
+        return (end_time - start_time) / num_iters
+
+    # Warmup.
+    print("Warming up...")
+    run_benchmark = run_cuda_benchmark
+    run_benchmark(num_iters=num_warmup_iters, profile=False)
+
+    # Benchmark.
+    if do_profile:
+        latency = run_benchmark(num_iters=1, profile=True)
+    else:
+        latency = run_benchmark(num_iters=num_iters, profile=False)
+    print(f"Kernel running time: {latency * 1000000:.3f} us")
+
+
+if __name__ == '__main__':
+    parser = FlexibleArgumentParser(
+        description="Benchmark the layernorm kernel.")
+    parser.add_argument("--num-tokens", type=int, default=4096)
+    parser.add_argument("--hidden-size", type=int, default=8192)
+    parser.add_argument("--add-residual", action="store_true")
+    parser.add_argument("--dtype",
+                        type=str,
+                        choices=["half", "bfloat16", "float"],
+                        default="half")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--profile", action="store_true")
+    parser.add_argument("--num-warmup-iters", type=int, default=5)
+    parser.add_argument("--num-iters",
+                        type=int,
+                        default=100,
+                        help="Number of benchmark iterations. "
" + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py new file mode 100644 index 0000000000000..4c1a7b26213a5 --- /dev/null +++ b/benchmarks/kernels/benchmark_quant.py @@ -0,0 +1,103 @@ +import random +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser + + +@torch.inference_mode() +def main(num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100) -> None: + random.seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device("cuda") + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if quant_dtype == torch.int8: + ops.scaled_int8_quant(x, scale) + else: + ops.scaled_fp8_quant(x, scale) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported dtype: {dt}") + + parser = FlexibleArgumentParser( + description="Benchmark the quantization (fp8 or int8) kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--static-scale", action="store_true") + parser.add_argument("--quant-dtype", + type=str, + choices=["fp8", "int8"], + default="int8") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. 
" + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index ca1c04bd880d9..7a7a25d2173d2 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -3,13 +3,16 @@ #include #include "dispatch_utils.h" -#include "reduction_utils.cuh" #ifndef USE_ROCM #include #include + #include + #include #else #include #include + #include + #include using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; @@ -31,7 +34,11 @@ __global__ void rms_norm_kernel( const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -228,12 +235,11 @@ fused_add_rms_norm_kernel( variance += temp.sum_squares(); residual_v[id] = temp; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -268,12 +274,11 @@ fused_add_rms_norm_kernel( variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aa9511daa2772..616fc149760e5 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -3,7 +3,14 @@ #include #include "../../dispatch_utils.h" -#include "../../reduction_utils.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM @@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel( absmax_val = val > absmax_val ? 
   }
 
-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float block_absmax_val;
   if (tid == 0) {
     block_absmax_val = block_absmax_val_maybe;
diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 3f77c76ae7ec4..7e23f92257769 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -7,7 +7,13 @@
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
-#include "../../reduction_utils.cuh"
+#ifndef USE_ROCM
+  #include <cub/util_type.cuh>
+  #include <cub/cub.cuh>
+#else
+  #include <hipcub/util_type.hpp>
+  #include <hipcub/hipcub.hpp>
+#endif
 
 #ifndef USE_ROCM
 using FP8_TYPE = c10::Float8_e4m3fn;
@@ -215,7 +221,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     }
   }
 
-  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  float const block_absmax_val_maybe =
+      BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
     if (scale_ub) {
diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh
deleted file mode 100644
index 08063356012b8..0000000000000
--- a/csrc/reduction_utils.cuh
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * Adapted from
- * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
- * Copyright (c) 2023, The vLLM team.
- * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include "cuda_compat.h"
-
-namespace vllm {
-
-namespace detail {
-
-template <typename T>
-__inline__ __device__ T _max(T a, T b) {
-  return max(a, b);
-}
-
-template <typename T>
-__inline__ __device__ T _sum(T a, T b) {
-  return a + b;
-}
-
-}  // namespace detail
-
-template <typename T>
-using ReduceFnType = T (*)(T, T);
-
-// Helper function to return the next largest power of 2
-static constexpr int _nextPow2(unsigned int num) {
-  if (num <= 1) return num;
-  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-template <typename T, int numLanes = WARP_SIZE>
-__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
-  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
-                "numLanes is not a positive power of 2!");
-  static_assert(numLanes <= WARP_SIZE);
-#pragma unroll
-  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
-    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
-
-  return val;
-}
-
-template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
-  static_assert(maxBlockSize <= 1024);
-  if constexpr (maxBlockSize > WARP_SIZE) {
-    val = warpReduce<T>(val, fn);
-    // Calculates max number of lanes that need to participate in the last
-    // warpReduce
-    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
-    static __shared__ T shared[maxActiveLanes];
-    int lane = threadIdx.x % WARP_SIZE;
-    int wid = threadIdx.x / WARP_SIZE;
-    if (lane == 0) shared[wid] = val;
-
-    __syncthreads();
-
-    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
-                                                        : (T)(0.0f);
-    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
-  } else {
-    // A single warpReduce is equal to blockReduce
-    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
-  }
-  return val;
-}
-
-template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduceMax(T val) {
-  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
-}
-
-template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduceSum(T val) {
-  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
-}
-
-}  // namespace vllm
diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py
index fcc444842213a..9c6364ecc6792 100644
--- a/tests/basic_correctness/test_chunked_prefill.py
+++ b/tests/basic_correctness/test_chunked_prefill.py
@@ -83,7 +83,7 @@ def test_models(
                          for m in E4M3_KV_MODELS])
 # Due to low-precision numerical divergence, we only test logprob of 4 tokens
 @pytest.mark.parametrize("max_tokens", [4])
-@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
 @pytest.mark.parametrize("enforce_eager", [False, True])
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.