[torch.compile] Inductor code caching fix (vllm-project#10273)
Signed-off-by: luka <[email protected]>
Signed-off-by: Luka Govedic <[email protected]>
ProExpertProg authored Nov 21, 2024
1 parent 9d82717 commit 8b0fe06
Showing 14 changed files with 604 additions and 288 deletions.
16 changes: 10 additions & 6 deletions tests/compile/backend.py
@@ -1,7 +1,9 @@
 from copy import deepcopy
-from typing import Callable
+from typing import Callable, Union

 import torch
+from torch import fx

+from vllm.compilation.inductor_pass import InductorPass


 class TestBackend:
@@ -11,19 +13,21 @@ class TestBackend:
    It also saves the graph before and after the custom passes for inspection.
    """

-    def __init__(self, *args: Callable[[torch.fx.Graph], None]):
-        self.custom_passes = args
+    def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
+                                                              None]]):
+        self.custom_passes = list(passes)
        from torch._inductor import config
        self.current_config = config.shallow_copy_dict()
+        self.current_config['force_disable_caches'] = True
        self.current_config['post_grad_custom_post_pass'] = self.post_pass

-    def __call__(self, graph: torch.fx.GraphModule, example_inputs):
+    def __call__(self, graph: fx.GraphModule, example_inputs):
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.current_config)

-    def post_pass(self, graph: torch.fx.Graph):
+    def post_pass(self, graph: fx.Graph):
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)
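
Note: the tests below drive this backend through torch.compile. A minimal usage sketch (illustrative only; the no-op pass, toy model, and import path are assumptions, not part of this change):

    # Sketch: TestBackend routes custom passes into Inductor via the
    # 'post_grad_custom_post_pass' config hook shown above.
    import torch
    from torch import fx

    from tests.compile.backend import TestBackend  # import path assumed

    def noop_pass(graph: fx.Graph):
        # Stand-in for RedundantReshapesPass / FusionPass / FixFunctionalizationPass.
        pass

    backend = TestBackend(noop_pass)
    model = torch.nn.Linear(16, 16).cuda()
    compiled = torch.compile(model, fullgraph=True, backend=backend)
    compiled(torch.randn(8, 16, device="cuda"))
    # backend.graph_pre_pass / backend.graph_post_pass can then be inspected.
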
95 changes: 95 additions & 0 deletions tests/compile/test_functionalization.py
@@ -0,0 +1,95 @@
import pytest
import torch

import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                     find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.vllm_inductor_pass import is_func
from vllm.config import CompilationConfig

from .backend import TestBackend

OPS_IN_MODEL = [
    torch.ops._C.rotary_embedding.default,
    torch.ops._C.fused_add_rms_norm.default,
    torch.ops._C.silu_and_mul.default,
]

RMS_OP = torch.ops._C.rms_norm.default

RMS_QUANT_OPS = {
    "static_fp8": [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    ],
}

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]


@pytest.mark.parametrize("model",
                         ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fix_functionalization(model: str, do_fusion: bool):
    torch.set_default_device("cuda")

    config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
                                          enable_reshape=True)
    reshape_pass = RedundantReshapesPass(config)
    fusion_pass = FusionPass.instance(config)

    passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
    func_pass = FixFunctionalizationPass(config)
    backend_func = TestBackend(*passes, func_pass)
    backend_no_func = TestBackend(*passes)

    # instantiate a full engine and manually compile the model 2x
    # (with and without FixFunctionalizationPass)
    llm = LLM(model=model, enforce_eager=True)
    model_runner = llm.llm_engine.model_executor.driver_worker.model_runner
    orig_model = model_runner.model
    # TODO mark inputs dynamic? (currently torch.compile is triggered 4x)
    # Can only do that by using the decorator but then we'd have to instantiate
    # 2 LLM instances.

    sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_func)
    gen_func = llm.generate(prompts, sampling_params)

    model_runner.model = torch.compile(orig_model,
                                       fullgraph=True,
                                       backend=backend_no_func)
    gen_no_func = llm.generate(prompts, sampling_params)

    for output_func, output_no_func in zip(gen_func, gen_no_func):
        assert output_func.outputs[0].text == output_no_func.outputs[0].text

    # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
    # and replaced by fused quantized ops in RMS_QUANT_OPS.
    ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
                          if do_fusion else [RMS_OP])

    for op in ops:
        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes,
                                  op) is None  # noqa: E501

    # make sure the ops were all de-functionalized
    found = dict()
    for node in backend_func.graph_post_pass.nodes:
        for op in ops:
            if is_func(node, op):
                found[op] = True
    assert all(found[op] for op in ops)
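
Note: a rough sketch of the graph helpers used above (is_func, find_auto_fn_maybe); the bodies below are inferred from their usage here and are not the actual vLLM implementations:

    # Assumed helper semantics (sketch, not vLLM source).
    from torch import fx
    from torch._higher_order_ops.auto_functionalize import auto_functionalized

    def is_func(node: fx.Node, target) -> bool:
        # A direct (de-functionalized) call to the op.
        return node.op == "call_function" and node.target == target

    def find_auto_fn_maybe(nodes, op):
        # An op still wrapped in the auto_functionalized higher-order call,
        # i.e. not yet cleaned up by FixFunctionalizationPass.
        for node in nodes:
            if is_func(node, auto_functionalized) and node.args[0] == op:
                return node
        return None
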
11 changes: 5 additions & 6 deletions tests/compile/test_fusion.py
@@ -38,12 +38,6 @@ def forward(self, x):
        return y3


-# Init does pattern registration, which can only happen once
-config = CompilationConfig(enable_fusion=True)
-reshape_pass = RedundantReshapesPass(config)
-fusion_pass = FusionPass.instance(config)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
        pytest.skip("Only test eps=1e-5 for now")

    # Reshape pass is needed for the fusion pass to work
+    config = CompilationConfig.PassConfig(enable_fusion=True,
+                                          enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = FusionPass.instance(config)
+
    backend = TestBackend(reshape_pass, fusion_pass)
    model = TestModel(hidden_size, eps)

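
Note: for intuition, the fusion pass collapses an RMSNorm followed by a static FP8 quantization into one op. A reference-only sketch of the unfused pattern (assumed semantics; not the _C CUDA kernels being matched):

    # Reference semantics of the pattern being fused (assumed, illustrative).
    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float):
        variance = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

    def static_fp8_quant(x: torch.Tensor, scale: torch.Tensor):
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (x.float() / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

    # Unfused graph: rms_norm -> static_fp8_quant (two ops).
    # Fused graph:   a single call like torch.ops._C.rms_norm_static_fp8_quant.
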
35 changes: 35 additions & 0 deletions tests/compile/test_pass_manager.py
@@ -0,0 +1,35 @@
import pickle

import pytest
import torch
from torch._inductor.codecache import BypassFxGraphCache

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import (CallableInductorPass,
                                            as_inductor_pass)
from vllm.compilation.pass_manager import PostGradPassManager


def simple_callable(graph: torch.fx.Graph):
    pass


@as_inductor_pass(files=(__file__, ))
def callable_decorated(graph: torch.fx.Graph):
    pass


@pytest.mark.parametrize(
    "works, callable",
    [(False, simple_callable), (True, callable_decorated),
     (True, CallableInductorPass(simple_callable, "simple_callable"))])
def test_pass_manager(works: bool, callable):
    config = CompilationConfig().pass_config
    pass_manager = PostGradPassManager([callable])
    pass_manager.configure(config)  # Adds default passes

    if works:
        pickle.dumps(pass_manager)
    else:
        with pytest.raises(BypassFxGraphCache):
            pickle.dumps(pass_manager)
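
Note: this test captures the caching contract the commit introduces: Inductor hashes the post-grad pass configuration, so custom passes must be picklable with a stable identity; a bare callable makes the pass manager raise BypassFxGraphCache, which disables FX graph caching. A short sketch of the two cache-friendly ways to register a pass, mirroring the parametrized cases above (the file-hash detail is an assumption about how as_inductor_pass derives the identity):

    import torch

    from vllm.compilation.inductor_pass import (CallableInductorPass,
                                                as_inductor_pass)

    # 1) Decorator form: identity is derived from the listed source files
    #    (assumed to be hashed into the pass uuid).
    @as_inductor_pass(files=(__file__, ))
    def my_pass(graph: torch.fx.Graph):
        pass

    # 2) Explicit wrapper with a caller-chosen identity string.
    def my_other_pass(graph: torch.fx.Graph):
        pass

    wrapped = CallableInductorPass(my_other_pass, "my_other_pass")
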