2 changed files with 121 additions and 0 deletions
@@ -0,0 +1,33 @@
from copy import deepcopy
from typing import Callable

import torch


class TestBackend():
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.
    """

    def __init__(self, *args: Callable[[torch.fx.Graph], None]):
        self.custom_passes = args
        from torch._inductor import config
        self.current_config = config.shallow_copy_dict()
        self.current_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: torch.fx.GraphModule, example_inputs):
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.current_config)

    def post_pass(self, graph: torch.fx.Graph):
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph
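Not part of the commit, but a rough sketch of how this backend is meant to be used: pass one or more custom FX passes to the constructor, hand the instance to torch.compile as the backend, and read back the saved graphs afterwards. The no-op pass and toy model below are assumptions made for the illustration.

# Illustrative usage sketch (not in this commit): compile a toy model with
# TestBackend and a no-op custom pass, then inspect the saved graphs.
import torch

def noop_pass(graph: torch.fx.Graph) -> None:
    # Each custom pass receives the post-grad FX graph and may mutate it in place.
    pass

backend = TestBackend(noop_pass)
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(4, 8))

# graph_pre_pass / graph_post_pass hold deep copies taken before and after
# the custom passes ran; final_graph tracks the live graph object.
print(backend.graph_pre_pass)
print(backend.graph_post_pass)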
@@ -0,0 +1,88 @@
import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE

from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                     find_auto_fn_maybe)
from vllm.model_executor.layers.layernorm import RMSNorm

from .backend import TestBackend


class TestModel(torch.nn.Module):

    def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

    def forward(self, x):
        resid = torch.relu(x)
        y = self.norm[0](x)
        yq, s0 = scaled_fp8_quant(y, self.scale[0])
        x2 = cutlass_scaled_mm(yq,
                               self.w[0],
                               s0,
                               self.scale[1],
                               out_dtype=x.dtype)
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)
        yq2, s2 = scaled_fp8_quant(y2, self.scale[2])
        x3 = cutlass_scaled_mm(yq2,
                               self.w[1],
                               s2,
                               self.scale[3],
                               out_dtype=x.dtype)
        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3


# Init does pattern registration, which can only happen once
fusion_pass = FusionPass()


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    if eps != 1e-5:
        pytest.skip("Only test eps=1e-5 for now")

    backend = TestBackend(fusion_pass)
    model = TestModel(hidden_size, eps)

    x = torch.rand(num_tokens, hidden_size)
    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer
    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
    add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
    find_auto_fn(pre_nodes, fp8_quant)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, rms_quant)
    find_auto_fn(post_nodes, add_rms_quant)
    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
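The assertions above rely on find_auto_fn and find_auto_fn_maybe from vllm.compilation.fusion. As a hedged sketch of what such a lookup presumably does (the real helpers live in vLLM and may differ in detail): custom mutating ops show up in the post-grad FX graph wrapped in auto_functionalized higher-order calls, and the lookup scans the node list for a wrapper around a given op.

# Hypothetical sketch of a find_auto_fn_maybe-style lookup, for illustration
# only; the actual helpers are defined in vllm.compilation.fusion.
from torch._higher_order_ops.auto_functionalize import auto_functionalized

def find_auto_fn_maybe_sketch(nodes, op):
    # args[0] of an auto_functionalized call is the wrapped custom op.
    for node in nodes:
        if (node.op == "call_function" and node.target is auto_functionalized
                and node.args[0] is op):
            return node
    return None

def find_auto_fn_sketch(nodes, op):
    # Like the maybe-variant, but require that the op is present in the graph.
    node = find_auto_fn_maybe_sketch(nodes, op)
    assert node is not None, f"{op} not found in graph"
    return node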