diff --git a/tests/compile/backend.py b/tests/compile/backend.py new file mode 100644 index 0000000000000..c06c15bb17965 --- /dev/null +++ b/tests/compile/backend.py @@ -0,0 +1,33 @@ +from copy import deepcopy +from typing import Callable + +import torch + + +class TestBackend(): + """ + This class provides a simple Inductor backend that can be used for testing. + It takes a list of custom passes and runs them after Inductor's passes. + It also saves the graph before and after the custom passes for inspection. + """ + + def __init__(self, *args: Callable[[torch.fx.Graph], None]): + self.custom_passes = args + from torch._inductor import config + self.current_config = config.shallow_copy_dict() + self.current_config['post_grad_custom_post_pass'] = self.post_pass + + def __call__(self, graph: torch.fx.GraphModule, example_inputs): + from torch._inductor.compile_fx import compile_fx + return compile_fx(graph, + example_inputs, + config_patches=self.current_config) + + def post_pass(self, graph: torch.fx.Graph): + self.graph_pre_pass = deepcopy(graph) + for pass_ in self.custom_passes: + pass_(graph) + + self.graph_post_pass = deepcopy(graph) + # assign by reference, will reflect the final state of the graph + self.final_graph = graph diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py new file mode 100644 index 0000000000000..81cd66795d310 --- /dev/null +++ b/tests/compile/test_fusion.py @@ -0,0 +1,88 @@ +import pytest +import torch +from compressed_tensors.quantization import FP8_DTYPE + +from vllm._custom_ops import cutlass_scaled_mm, scaled_fp8_quant +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.model_executor.layers.layernorm import RMSNorm + +from .backend import TestBackend + + +class TestModel(torch.nn.Module): + + def __init__(self, hidden_size: int, eps: float, *args, **kwargs): + super().__init__(*args, **kwargs) + self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)] + self.w = [ + torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() + for _ in range(2) + ] + + def forward(self, x): + resid = torch.relu(x) + y = self.norm[0](x) + yq, s0 = scaled_fp8_quant(y, self.scale[0]) + x2 = cutlass_scaled_mm(yq, + self.w[0], + s0, + self.scale[1], + out_dtype=x.dtype) + # make sure resid is used for replacement to work + y2, resid = self.norm[1](x2, resid) + yq2, s2 = scaled_fp8_quant(y2, self.scale[2]) + x3 = cutlass_scaled_mm(yq2, + self.w[1], + s2, + self.scale[3], + out_dtype=x.dtype) + y3, resid = self.norm[2](x3, resid) # use resid here + return y3 + + +# Init does pattern registration, which can only happen once +fusion_pass = FusionPass() + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) +@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) +@pytest.mark.parametrize("eps", [1e-5, 1e-6]) +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): + torch.set_default_device("cuda") + torch.set_default_dtype(torch.float16) + + if eps != 1e-5: + pytest.skip("Only test eps=1e-5 for now") + + backend = TestBackend(fusion_pass) + model = TestModel(hidden_size, eps) + + x = torch.rand(num_tokens, hidden_size) + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + # Check that it gives the same answer + torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3) + + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes + + rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default + add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + + # In pre-nodes, fp8 quant should be present and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, rms_quant) is None + assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None + find_auto_fn(pre_nodes, fp8_quant) + + # In post-nodes, fused kernels should be present and fp8 quant should not + find_auto_fn(post_nodes, rms_quant) + find_auto_fn(post_nodes, add_rms_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None