[torch.compile] Add torch inductor pass for fusing silu_and_mul with subsequent scaled_fp8_quant operations #10867

Open · wants to merge 35 commits into base: main

Changes from 33 commits

Commits
2e0031a
init
SageMoore Dec 2, 2024
8a957c7
remove backend format changes
SageMoore Dec 3, 2024
2913716
format
SageMoore Dec 3, 2024
11c6fae
move activation_quant_kernels to the quantization dir
SageMoore Dec 3, 2024
2dfecb5
added replacement unit test
SageMoore Dec 4, 2024
702fa46
added kernel unit test
SageMoore Dec 5, 2024
583ff4c
misc cleanup
SageMoore Dec 6, 2024
e5680f7
move activation quant fusion to its own pass
SageMoore Dec 6, 2024
4b775c4
update test
SageMoore Dec 6, 2024
d5ff865
format
SageMoore Dec 6, 2024
c970dec
format
SageMoore Dec 6, 2024
596c445
format
SageMoore Dec 6, 2024
7ab3e18
format
SageMoore Dec 6, 2024
d347431
format
SageMoore Dec 6, 2024
553d99c
format
SageMoore Dec 6, 2024
774559d
format
SageMoore Dec 6, 2024
e2fda7f
format
SageMoore Dec 6, 2024
6915fa2
minor comment fix
SageMoore Dec 9, 2024
6d4b8d0
minor updates
SageMoore Dec 9, 2024
6b631b0
fix fix-functionalization
SageMoore Dec 12, 2024
5b78d80
add opcheck test for fused op
SageMoore Dec 13, 2024
391eea5
fix fix_functionalization tests
SageMoore Dec 13, 2024
546b411
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Dec 13, 2024
0d79c17
fix fix_functionalization again
SageMoore Dec 13, 2024
1041529
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Dec 13, 2024
3198f64
format
SageMoore Dec 13, 2024
58111a9
fixup includes
SageMoore Dec 14, 2024
9a18085
refactor math.hpp
SageMoore Dec 16, 2024
5ae5fe0
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Dec 17, 2024
e051b24
fix amd build
SageMoore Dec 18, 2024
bfdac35
Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/…
SageMoore Dec 19, 2024
8514b0e
review comments and format
SageMoore Dec 19, 2024
ec1290a
fix amd build
SageMoore Dec 19, 2024
008b725
review comments and format
SageMoore Dec 20, 2024
4a0ac7e
minor test fix
SageMoore Dec 20, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -198,6 +198,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/activation_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
25 changes: 24 additions & 1 deletion csrc/core/math.hpp
@@ -1,7 +1,30 @@
#pragma once

#include <climits>
#include <iostream>

inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}

// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b)
{
return a % b == 0 ? a : (a / b) * b;
}

// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b)
{
return a % b == 0 ? a : ((a / b) + 1) * b;
}
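
For reference, a minimal compile-time illustration of what the new helpers compute (not part of the diff; the include path below is an assumption):

// Illustration only: sanity checks for the helpers added to csrc/core/math.hpp.
#include "core/math.hpp"  // assumed include path

static_assert(div_ceil(10, 4) == 3);                      // (10 + 4 - 1) / 4
static_assert(round_to_next_multiple_of(10, 4) == 12);    // rounds 10 up to a multiple of 4
static_assert(round_to_previous_multiple_of(10, 4) == 8); // rounds 10 down to a multiple of 4
static_assert(round_to_next_multiple_of(12, 4) == 12);    // exact multiples are unchanged
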
3 changes: 3 additions & 0 deletions csrc/ops.h
@@ -86,6 +86,9 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& scale);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
120 changes: 120 additions & 0 deletions csrc/quantization/activation_kernels.cu
@@ -0,0 +1,120 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
#include "fp8/common.cuh"
#include "../core/math.hpp"
#include "../cuda_compat.h"
#include "../dispatch_utils.h"

namespace vllm {

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x)
return (T)(((float)x) / (1.0f + expf((float)-x)));
}

__device__ __forceinline__ FP8_TYPE
scaled_fp8_conversion(float const val, float const inverted_scale) {
float x = val * inverted_scale;
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<FP8_TYPE>(r);
}

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const float* scale, const int d) {
const int32_t blocks_per_token = gridDim.y;

const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t);

// We don't expect the hidden dimension to exceed 32 bits so int32 should
// be safe here.
const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token);
const int32_t elems_per_block =
round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load);
const int32_t block_start = blockIdx.y * elems_per_block;
int32_t block_end = block_start + elems_per_block;
block_end = block_end > d ? d : block_end;

// token_idx is 64 bit to prevent 32 bit overflow when the number of tokens
// is very large
const int64_t token_idx = blockIdx.x;
const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d;
const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d;
FP8_TYPE* __restrict__ out_ptr = out + token_idx * d;

// 128-bit vectorized code
const int32_t vec_loop_end =
round_to_previous_multiple_of(elems_per_128bit_load, block_end);
const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load;
const int32_t vec_start_idx = block_start / elems_per_128bit_load;

const int4* __restrict__ x_128bit_ptr = reinterpret_cast<const int4*>(x_ptr);
const int4* __restrict__ y_128bit_ptr = reinterpret_cast<const int4*>(y_ptr);
int2* __restrict__ out_128bit_ptr = reinterpret_cast<int2*>(out_ptr);

float inverted_scale = 1 / *scale;
#pragma unroll
for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx;
vec_idx += blockDim.x) {
const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]);
const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]);
using scalar_128bit_vec_t = std::array<scalar_t, elems_per_128bit_load>;
using scalar_64bit_vec_t = std::array<FP8_TYPE, elems_per_128bit_load>;

scalar_64bit_vec_t out_vec;
const auto x_vec = reinterpret_cast<scalar_128bit_vec_t const&>(x_128bit);
const auto y_vec = reinterpret_cast<scalar_128bit_vec_t const&>(y_128bit);

#pragma unroll
for (int i = 0; i < elems_per_128bit_load; i++) {
out_vec[i] =
scaled_fp8_conversion(ACT_FN(x_vec[i]) * y_vec[i], inverted_scale);
}

out_128bit_ptr[vec_idx] = reinterpret_cast<const int2&>(out_vec);
}

// Scalar cleanup code
if (block_end > vec_loop_end) {
for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end;
idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&x_ptr[idx]);
const scalar_t y = VLLM_LDG(&y_ptr[idx]);
out_ptr[idx] = scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale);
}
}
}
} // namespace vllm

// Launch activation, gating, and quantize kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
Review comment (Contributor): Is there a reason this needs a macro?

Reply (Contributor Author): I just copied what the existing act_and_mul kernel does. This allows us to just drop in kernels for the other activation functions; a sketch of that reuse follows this file's diff. I'm in favor of keeping it.

int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \
dim3 block(std::min(d, 512)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_quant_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(), \
input.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), d); \
});

void silu_and_mul_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
torch::Tensor& scale) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(input.dtype() == torch::kFloat16 ||
input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.size(-1) % 2 == 0);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
}
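
On the macro question in the review thread above: a minimal sketch (not part of this PR) of how LAUNCH_ACTIVATION_GATE_KERNEL could be reused for another activation. The gelu device function is an assumption here; it would need to be defined or shared in this file, analogous to the one used by the unquantized gelu_and_mul kernel.

// Sketch only, not part of this PR: a hypothetical fused GELU variant.
// Assumes a templated vllm::gelu_kernel device function is made available
// in this translation unit.
void gelu_and_mul_quant(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input,  // [..., 2 * d]
                        torch::Tensor& scale) {
  TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
              input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.size(-1) % 2 == 0);
  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
}
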
5 changes: 4 additions & 1 deletion csrc/torch_bindings.cpp
@@ -52,9 +52,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Activation ops
// Activation function used in SwiGLU.
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

ops.def(
"silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);
// Activation function used in GeGLU with `none` approximation.
ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
25 changes: 18 additions & 7 deletions tests/compile/test_functionalization.py
@@ -3,6 +3,7 @@

import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
@@ -15,18 +16,17 @@
OPS_IN_MODEL = [
torch.ops._C.rotary_embedding.default,
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.silu_and_mul.default,
]

RMS_OP = torch.ops._C.rms_norm.default

RMS_QUANT_OPS = {
SILU_MUL_OP = torch.ops._C.silu_and_mul.default

SILU_MUL_QUANT_OPS = {
"static_fp8": [
torch.ops._C.rms_norm_static_fp8_quant.default,
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
torch.ops._C.silu_and_mul_quant.default,
],
}

prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -51,8 +51,13 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
act_quant_fusion_pass = ActivationQuantFusionPass.instance(config)

passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
passes = [
reshape_pass,
fusion_pass,
act_quant_fusion_pass,
] if do_fusion else [reshape_pass]
func_pass = FixFunctionalizationPass(config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
@@ -75,6 +80,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)

gen_no_func = llm.generate(prompts, sampling_params)

for output_func, output_no_func in zip(gen_func, gen_no_func):
@@ -84,7 +90,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
ops = OPS_IN_MODEL + rms_ops
silu_mul_ops = SILU_MUL_QUANT_OPS[
"static_fp8"] if do_fusion and quant_key == kFp8StaticTensorSym else [
SILU_MUL_OP
]

ops = OPS_IN_MODEL + rms_ops + silu_mul_ops

for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
73 changes: 73 additions & 0 deletions tests/compile/test_silu_mul_quant_fusion.py
@@ -0,0 +1,73 @@
import pytest
import torch

import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fusion import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.activation import SiluAndMul

from .backend import TestBackend


class TestModel(torch.nn.Module):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.silu_and_mul = SiluAndMul()
self.scale = torch.rand(1, dtype=torch.float32)

def forward(self, x):
y = self.silu_and_mul(x)
x2 = scaled_fp8_quant(y, self.scale)
return x2


@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)

# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = ActivationQuantFusionPass.instance(config)

backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel()

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

result = model(x)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)

# Check that it gives the same answer
torch.testing.assert_close(result[0].to(dtype=torch.float16),
result2[0].to(dtype=torch.float16),
atol=1e-3,
rtol=1e-3)

# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes

silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
find_auto_fn(pre_nodes, fp8_quant)

# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, silu_and_mul_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
68 changes: 68 additions & 0 deletions tests/kernels/test_fused_quant_activation.py
@@ -0,0 +1,68 @@
import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.activation import SiluAndMul

DTYPES = [torch.bfloat16, torch.float16]
QUANT_DTYPES = [torch.float8_e4m3fn]
NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing
HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
silu_and_mul_out = silu_and_mul.forward_native(x)
out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale)
return out


def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
out_shape = (x.shape[0], x.shape[1] // 2)
out = torch.empty(out_shape,
dtype=torch.float8_e4m3fn,
device=x.device)
torch.ops._C.silu_and_mul_quant(out, x, scale)
return out


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)

layer = SiluAndMul()

# Make inputs
scale = (torch.randn((1), device=device, dtype=torch.float32))
x = torch.randn(num_tokens, hidden_size, dtype=dtype)

ref_out = ref_impl(layer, x, scale)
ops_out = ops_impl(x, scale)

assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
assert ref_out.shape == ops_out.shape
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale))