[torch.compile] use interpreter with stable api from pytorch (#9889)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored Nov 1, 2024
1 parent 4581d2c commit aff1fd8
Showing 1 changed file with 89 additions and 76 deletions.
vllm/compilation/backends.py: 89 additions, 76 deletions
@@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule,
     return split_gm, outputs
 
 
+# we share the global graph pool among all the backends
+global_graph_pool = None
+
+
+class PiecewiseCompileInterpreter(torch.fx.Interpreter):
+    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
+    It runs the given graph with fake inputs, and compile some
+    submodules specified by `compile_submod_names` with the given
+    compilation configs.
+    """
+
+    def __init__(self, module: torch.fx.GraphModule,
+                 compile_submod_names: List[str],
+                 compilation_configs: CompilationConfig, graph_pool):
+        super().__init__(module)
+        from torch._guards import detect_fake_mode
+        self.fake_mode = detect_fake_mode()
+        self.compile_submod_names = compile_submod_names
+        self.compilation_configs = compilation_configs
+        self.graph_pool = graph_pool
+        self.have_seen_first_graph = False
+
+    def run(self, *args):
+        fake_args = [
+            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+            for t in args
+        ]
+        return super().run(*fake_args)
+
+    def call_module(self, target: torch.fx.node.Target,
+                    args: Tuple[torch.fx.node.Argument,
+                                ...], kwargs: Dict[str, Any]) -> Any:
+        assert isinstance(target, str)
+        output = super().call_module(target, args, kwargs)
+
+        if target in self.compile_submod_names:
+            submod = self.fetch_attr(target)
+            sym_shape_indices = [
+                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
+            ]
+            compiled_graph_for_general_shape = wrap_inductor(
+                submod,
+                args,
+                self.compilation_configs.inductor_compile_config,
+                runtime_shape=None,
+                do_logging=not self.have_seen_first_graph,
+                use_inductor=self.compilation_configs.use_inductor)
+
+            self.module.__dict__[target] = PiecewiseBackend(
+                submod, self.compilation_configs, self.graph_pool,
+                not self.have_seen_first_graph, sym_shape_indices,
+                compiled_graph_for_general_shape)
+
+            self.have_seen_first_graph = True
+            compilation_counter.num_piecewise_capturable_graphs_seen += 1
+
+        return output
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with VLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
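
The heart of this change: instead of a hand-rolled graph walk, the pass subclasses `torch.fx.Interpreter` (a stable public API) and overrides `call_module`, which fires once per submodule call; `run` first maps real tensors to fake ones via `detect_fake_mode()` so no real kernels execute during compilation. Below is a minimal standalone sketch of the same hook, with a toy module of our own, not vLLM code:

# Illustrative sketch: intercept submodule calls by subclassing
# torch.fx.Interpreter, the same hook PiecewiseCompileInterpreter uses.
import torch
from torch import fx


class SubmoduleLogger(fx.Interpreter):
    """Runs a GraphModule and reports every submodule invocation; vLLM's
    version instead compiles the submodule and installs the compiled
    callable into module.__dict__."""

    def call_module(self, target, args, kwargs):
        output = super().call_module(target, args, kwargs)
        print(f"visited submodule {target!r} with {len(args)} args")
        return output


class Toy(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.lin(x))


gm = fx.symbolic_trace(Toy())
SubmoduleLogger(gm).run(torch.randn(2, 4))  # prints: visited submodule 'lin' ...
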
@@ -263,8 +322,14 @@ class VllmBackend:
     returned_callable: Callable
 
     def __init__(self, ):
-        # every instance of VllmBackend has its own graph pool
-        self.graph_pool = torch.cuda.graph_pool_handle()
+        global global_graph_pool
+        if global_graph_pool is None:
+            global_graph_pool = torch.cuda.graph_pool_handle()
+
+        # TODO: in the future, if we want to use multiple
+        # streams, it might not be safe to share a global pool.
+        # only investigate this when we use multiple streams
+        self.graph_pool = global_graph_pool
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
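
Sharing one `torch.cuda.graph_pool_handle()` lets all captured CUDA graphs draw from a single memory pool instead of one pool per backend instance. A minimal sketch of pool sharing follows; it is illustrative only (real capture code warms kernels up first, and graphs sharing a pool should be replayed in capture order):

# Illustrative: two CUDA graph captures sharing one memory pool.
import torch

if torch.cuda.is_available():
    pool = torch.cuda.graph_pool_handle()  # one shared allocator pool

    static_a = torch.randn(8, device="cuda")
    static_b = torch.randn(8, device="cuda")

    g1 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):  # capture into the shared pool
        out_a = static_a * 2

    g2 = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g2, pool=pool):  # second capture reuses the pool
        out_b = static_b + 1

    g1.replay()  # replay in the same order as captured
    g2.replay()
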
@@ -286,55 +351,26 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_configs.non_cudagraph_ops)
 
-        returned_callable: Callable  # type: ignore
+        from torch._dynamo.utils import lazy_format_graph_code
+        logger.debug("%s",
+                     lazy_format_graph_code("stiching module", self.split_gm))
 
-        if len(self.piecewise_graphs) == 0:
-            compilation_counter.num_piecewise_graphs_seen += 1
-            compilation_counter.num_piecewise_capturable_graphs_seen += 1
-            returned_callable = PiecewiseBackend(graph,
-                                                 self.compilation_configs,
-                                                 self.graph_pool,
-                                                 is_first_graph=True)
-        else:
-            from torch._dynamo.utils import lazy_format_graph_code
-            logger.debug(
-                "%s", lazy_format_graph_code("stiching module", self.split_gm))
-
-            is_first_graph = True
-
-            for item in self.piecewise_graphs:
-                compilation_counter.num_piecewise_graphs_seen += 1
-                compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph  # noqa
-                if not item.is_splitting_graph:
-                    # cannot setattr to a module, so we need to set
-                    # the attribute in the __dict__
-                    self.split_gm.__dict__[
-                        item.submod_name] = PiecewiseBackend(
-                            item.graph, self.compilation_configs,
-                            self.graph_pool, is_first_graph)
-                    is_first_graph = False
-            returned_callable = self.split_gm
-
-        self.returned_callable = returned_callable
-        # trigger the first compilation
-        # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206  # noqa
-        # to turn the inputs into fake tensors
-        import torch._guards
-        from torch._guards import detect_fake_mode
-        fake_mode = detect_fake_mode(example_inputs)
-        fake_args = []
-        for arg in example_inputs:
-            if isinstance(arg, torch.Tensor) and not isinstance(
-                    arg, torch._subclasses.FakeTensor):
-                fake_args.append(
-                    torch._dynamo.utils.to_fake_tensor(arg, fake_mode))
-            else:
-                fake_args.append(arg)
-        self.returned_callable(*fake_args)
+        compilation_counter.num_piecewise_graphs_seen += len(
+            self.piecewise_graphs)
+        submod_names_to_compile = [
+            item.submod_name for item in self.piecewise_graphs
+            if not item.is_splitting_graph
+        ]
+
+        # propagate the split graph to the piecewise backend,
+        # compile submodules with symbolic shapes
+        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
+                                    self.compilation_configs,
+                                    self.graph_pool).run(*example_inputs)
 
         self._called = True
 
-        return self.returned_callable
+        return self.split_gm
 
 
 @dataclasses.dataclass
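
For context, `split_graph` (partially visible at the top of the diff) partitions the traced module around the ops in `non_cudagraph_ops`, producing the `split_gm` whose non-splitting submodules the interpreter then compiles. The standard torch.fx utility for this kind of partitioning is `torch.fx.passes.split_module.split_module`; whether vLLM builds on it is not visible in this diff, so treat the following as a pattern sketch with a callback of our own design:

# Illustrative: split a GraphModule so every occurrence of a "splitting op"
# lands in its own partition, isolating it from the capturable pieces.
from typing import List

from torch import fx
from torch.fx.passes.split_module import split_module


def split_on_ops(gm: fx.GraphModule, split_ops: List[str]) -> fx.GraphModule:
    part = 0

    def assign_partition(node: fx.Node) -> int:
        nonlocal part
        if node.op == "call_function" and str(node.target) in split_ops:
            part += 1
            this_part = part  # the splitting op gets its own partition
            part += 1         # nodes after it start a fresh partition
            return this_part
        return part

    # the root-module argument is documented as unused by split_module
    return split_module(gm, None, assign_partition)


# e.g. split_on_ops(gm, ["my_namespace.attention"])  # hypothetical op name

Note also the backend contract visible at the end of the hunk: a custom `torch.compile` backend receives the traced `fx.GraphModule` plus `example_inputs` and must return a callable; here that is the stitched `split_gm` itself, whose submodules have already been replaced by `PiecewiseBackend` instances.
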
@@ -352,11 +388,10 @@ class ConcreteSizeEntry:
 
 
 class PiecewiseBackend:
 
-    def __init__(self,
-                 graph: fx.GraphModule,
-                 compilation_configs: CompilationConfig,
-                 graph_pool: Any,
-                 is_first_graph: bool = False):
+    def __init__(self, graph: fx.GraphModule,
+                 compilation_configs: CompilationConfig, graph_pool: Any,
+                 is_first_graph: bool, sym_shape_indices: List[int],
+                 compiled_graph_for_general_shape: Callable):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
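
The new signature receives `sym_shape_indices`, the positions of the `torch.SymInt` arguments computed in `call_module` above. Under dynamic shapes, Dynamo passes symbolic sizes to a backend as extra scalar inputs alongside the tensors; a small probe backend (names are ours, purely illustrative) shows where they appear:

# Illustrative: observe the SymInt arguments a backend receives when a
# dimension is marked dynamic.
import torch
from torch import fx


def probe_backend(gm: fx.GraphModule, example_inputs):
    sym_shape_indices = [
        i for i, x in enumerate(example_inputs)
        if isinstance(x, torch.SymInt)
    ]
    print("SymInt argument indices:", sym_shape_indices)
    return gm.forward


@torch.compile(backend=probe_backend)
def f(x):
    return x * 2


x = torch.randn(8, 4)
torch._dynamo.mark_dynamic(x, 0)  # make the batch dimension symbolic
f(x)  # prints e.g. "SymInt argument indices: [0]"
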
@@ -381,12 +416,11 @@ def __init__(self,
             self.compilation_configs.capture_sizes
         ) if self.compilation_configs.use_cudagraph else set()
 
-        self.compile_finished = False
         self.first_run_finished = False
 
-        self.compiled_graph_for_general_shape: Callable = None  # type: ignore
+        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa
 
-        self.sym_shape_indices: List[int] = []
+        self.sym_shape_indices = sym_shape_indices
 
         # the entries for different shapes that we need to either
         # compile or capture cudagraph
@@ -399,27 +433,6 @@
         )
 
     def __call__(self, *args) -> Any:
-
-        if not self.compile_finished:
-            self.compile_finished = True
-
-            # this is the first compilation, we will compile a graph with
-            # dynamic shape, as the caller will mark first dimension as dynamic
-
-            self.sym_shape_indices = [
-                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
-            ]
-
-            self.compiled_graph_for_general_shape = wrap_inductor(
-                self.graph,
-                args,
-                self.compilation_configs.inductor_compile_config,
-                runtime_shape=None,
-                do_logging=self.is_first_graph,
-                use_inductor=self.compilation_configs.use_inductor)
-
-            return self.graph(*args)
-
         if not self.first_run_finished:
             self.first_run_finished = True
             return self.compiled_graph_for_general_shape(*args)
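
The remainder of `__call__` is collapsed in this view; judging from the visible comments and the `ConcreteSizeEntry` dataclass, it dispatches on the concrete value of the dynamic dimension, keeping one entry per shape to compile or capture. A simplified standalone sketch of that shape-keyed dispatch pattern (field and helper names are ours; the real entries also track cudagraph capture):

# Illustrative: dispatch between a general compiled graph and per-size
# specializations, mirroring PiecewiseBackend's structure.
from dataclasses import dataclass
from typing import Callable, Dict, Optional, Set

import torch


@dataclass
class SizeEntry:
    runnable: Optional[Callable] = None  # per-size compiled/captured graph


def make_dispatcher(general: Callable, special_sizes: Set[int]) -> Callable:
    entries: Dict[int, SizeEntry] = {s: SizeEntry() for s in special_sizes}

    def dispatch(x: torch.Tensor):
        entry = entries.get(x.shape[0])  # the dynamic dim is the runtime shape
        if entry is None:
            return general(x)  # size we never specialize: general graph
        if entry.runnable is None:
            # first hit: compile or capture a size-specialized graph here;
            # this sketch just falls back to the general one
            entry.runnable = general
        return entry.runnable(x)

    return dispatch


run = make_dispatcher(lambda x: x * 2, special_sizes={1, 2, 4, 8})
print(run(torch.randn(8, 4)).shape)
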