cleanups
Signed-off-by: Bill Nell <[email protected]>
bnellnm committed Nov 9, 2024
1 parent 6d20979 commit 1683f80
Showing 3 changed files with 59 additions and 67 deletions.
4 changes: 2 additions & 2 deletions vllm/compilation/backends.py
@@ -215,7 +215,7 @@ def fix_functionalization(graph: fx.Graph):

def wrap_inductor(graph,
example_inputs,
additional_inductor_config=None,
additional_inductor_config: Optional[Dict] = None,
do_logging=False,
runtime_shape: Optional[int] = None,
use_inductor: bool = True):
@@ -233,7 +233,7 @@ def wrap_inductor(graph,
from torch._inductor import config

torch._inductor.config._micro_pipeline_tp = True
# Set to False to avoid infinite recursion logging
# Set to False to avoid infinite recursion logging?
torch._inductor.config.implicit_fallbacks = True

current_config = config.shallow_copy_dict()
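
For context, a minimal sketch (not the exact vLLM code) of how the now-typed Optional[Dict] argument can be folded into the copied Inductor config before compilation; the apply_additional_config helper and the "max_autotune" key are illustrative assumptions:

from typing import Any, Dict, Optional

def apply_additional_config(current_config: Dict[str, Any],
                            additional_inductor_config: Optional[Dict] = None
                            ) -> Dict[str, Any]:
    # Hypothetical helper: fold caller-supplied Inductor options into the
    # dict obtained from config.shallow_copy_dict(); caller keys win.
    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)
    return current_config

# e.g. apply_additional_config(config.shallow_copy_dict(), {"max_autotune": True})
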
119 changes: 56 additions & 63 deletions vllm/compilation/collective_fusion.py
@@ -1,12 +1,11 @@
import operator
from typing import List, Optional, Tuple
from typing import Callable, List, Optional, Tuple

import torch
import torch.fx as fx
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)

import vllm._custom_ops as ops
import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
@@ -33,31 +32,31 @@
FLUX_TILE_SIZE: int = 128


# TODO: is this right?
# TODO: is this ok?
def get_world_name() -> str:
return torch.distributed.group.WORLD.group_name


# Note: this heuristic is unique to flux
def should_slice(shape) -> bool:
def should_slice(shape: torch.Size) -> bool:
n_slices = get_tensor_model_parallel_world_size()
return (shape[0] % (FLUX_TILE_SIZE * n_slices) == 0
and shape[0] >= FLUX_TILE_SIZE * n_slices)

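A small standalone sketch of the heuristic above, with the world size passed in explicitly so it runs without initializing vLLM's distributed state (shapes and world size are illustrative):

FLUX_TILE_SIZE = 128

def should_slice_standalone(dim0: int, n_slices: int) -> bool:
    # Mirrors should_slice(): the leading dimension must be a (nonzero)
    # multiple of FLUX_TILE_SIZE * n_slices.
    return (dim0 % (FLUX_TILE_SIZE * n_slices) == 0
            and dim0 >= FLUX_TILE_SIZE * n_slices)

# With n_slices = 2 the threshold is 256:
assert should_slice_standalone(512, 2)       # multiple of 256
assert not should_slice_standalone(384, 2)   # 384 % 256 != 0
assert not should_slice_standalone(128, 2)   # below 256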

def residual_slice_shape(residual, rank) -> int:
def residual_slice_shape(residual: torch.Tensor, rank: int) -> int:
n_slices = get_tensor_model_parallel_world_size()
chunk, rem = divmod(residual.shape[0], n_slices)
return chunk if rank < n_slices - 1 or rem == 0 else rem


def residual_slice_shape_fake(residual, rank) -> int:
def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int:
n_slices = get_tensor_model_parallel_world_size()
slices = torch.chunk(residual, n_slices, dim=0)
return slices[rank].shape[0]

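Under the should_slice() guarantee that the leading dimension divides evenly, the arithmetic version and the torch.chunk-based fake version above agree on the per-rank slice size; a small illustrative check (values are made up):

import torch

def slice_sizes(dim0: int, n_slices: int):
    # Arithmetic version, mirroring residual_slice_shape().
    chunk, rem = divmod(dim0, n_slices)
    arith = [chunk if rank < n_slices - 1 or rem == 0 else rem
             for rank in range(n_slices)]
    # torch.chunk version, mirroring residual_slice_shape_fake().
    fake = [c.shape[0]
            for c in torch.chunk(torch.empty(dim0, 8), n_slices, dim=0)]
    return arith, fake

print(slice_sizes(512, 2))  # ([256, 256], [256, 256])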

def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool):
def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool) -> Callable:

def match_gemm_rs_ag_gemm(
residual: torch.Tensor,
@@ -69,8 +68,8 @@ def match_gemm_rs_ag_gemm(
gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)

# It would be nice to do this instead of having two separate patterns
#all_reduce = tensor_model_parallel_all_reduce(mm_1)
# It would be nice to do this instead of having two separate patterns.
# all_reduce = tensor_model_parallel_all_reduce(mm_1)
if custom_ar:
all_reduce = torch.ops.vllm.outplace_all_reduce.default(
mm_1, tp_group_name)
@@ -98,10 +97,10 @@ def match_gemm_rs_ag_gemm(
return match_gemm_rs_ag_gemm


def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
gemm_1_weights: torch.Size, gemm_1_max_m: int,
gemm_2_type, gemm_2_weights: torch.Size,
gemm_2_max_m: int, tp_group_name: str):
def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype,
gemm_1_weights: torch.Size, gemm_2_type: torch.dtype,
gemm_2_weights: torch.Size,
tp_group_name: str) -> Callable:

group = get_group_from_group_name(tp_group_name)
device_group = group.device_group
@@ -111,7 +110,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
gemm_rs_op = flux.GemmRS(
device_group,
1, # One node
gemm_1_max_m, # M
max_m, # max M
gemm_1_weights[0], # N
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
@@ -126,7 +125,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
ag_gemm_op = flux.AGKernel(
device_group,
1, # One node
gemm_2_max_m, # M
max_m, # max M
gemm_2_weights[0], # N
gemm_2_weights[1], # K
# TODO: It would be nicer to modify flux to dispatch based on dtype
@@ -149,10 +148,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
gemm_1_str = str(gemm_1_type).removeprefix("torch.")
gemm_2_str = str(gemm_2_type).removeprefix("torch.")
group_str = tp_group_name.replace(":", "_")
name = (
f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_"
f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_"
f"{group_str}")
name = (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_weights[0]}_"
f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_"
f"{group_str}")
else:
world_group_name = get_world_name()

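For illustration, the flux branch above now bakes a single max_m into the registered op name (the name later looked up via getattr(torch.ops.vllm, name).default); with hypothetical values max_m=2048, float16 weights of shapes (4096, 4096) and (11008, 4096), and group "tp:0":

max_m = 2048                                  # hypothetical
gemm_1_str = gemm_2_str = "float16"           # hypothetical dtypes
gemm_1_weights, gemm_2_weights = (4096, 4096), (11008, 4096)  # hypothetical shapes
group_str = "tp:0".replace(":", "_")
name = (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_weights[0]}_"
        f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_"
        f"{group_str}")
print(name)  # gemm_rs_ag_gemm_2048_float16_4096_float16_11008_4096_tp_0
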
@@ -187,27 +185,31 @@ def gemm_rs_ag_gemm(
gemm_1_weights.transpose(1, 0))
reduced_output = tensor_model_parallel_all_reduce(output)

ops.fused_add_rms_norm(input=reduced_output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)
torch.ops._C.fused_add_rms_norm.default(input=reduced_output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)

mm_2 = torch.ops.aten.mm.default(reduced_output,
gemm_2_weights.transpose(1, 0))
return mm_2, my_residual, my_residual.clone()
else:
output = gemm_rs(gemm_1_activations, gemm_1_weights)

ops.fused_add_rms_norm(input=output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)
torch.ops._C.fused_add_rms_norm.default(input=output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)

residual_1 = residual if first_layer else old_my_residual
slice_scatter = torch.ops.aten.slice_scatter.default(
residual_1, my_residual, 0, 0, slice_shape)
split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape)
new_residual = split_2[0]
#if False:
#slice_scatter = torch.ops.aten.slice_scatter.default(
# residual_1, my_residual, 0, 0, slice_shape)
#split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape)
#new_residual = split_2[0]
#else:
slice_scatter = my_residual
new_residual = residual_1

mm_2 = ag_gemm(output, gemm_2_weights)

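The hunk above swaps the ops.fused_add_rms_norm Python wrapper for a direct call to the registered custom op, presumably so the call appears as torch.ops._C.fused_add_rms_norm.default in the traced graph; a minimal sketch of the call shape (tensor sizes are illustrative, and this assumes vLLM's _C extension is built and a CUDA device is available):

import torch

hidden = torch.randn(256, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# In-place fused residual-add + RMSNorm, invoked through the registered op
# exactly as in the replacement code above.
torch.ops._C.fused_add_rms_norm.default(input=hidden,
                                        residual=residual,
                                        weight=weight,
                                        epsilon=1e-05)
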
@@ -248,7 +250,7 @@ def gemm_rs_ag_gemm_fake(
return getattr(torch.ops.vllm, name).default


def get_match_final(tp_group_name: str, use_custom_ar: bool):
def get_match_final(tp_group_name: str, use_custom_ar: bool) -> Callable:

def match_final(
my_residual: torch.Tensor,
@@ -260,7 +262,7 @@ def match_final(
mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)

# TODO: it would be nice to be able to use the official api directly.
#all_reduce = tensor_model_parallel_all_reduce(mm_1)
# all_reduce = tensor_model_parallel_all_reduce(mm_1)
if use_custom_ar:
all_reduce = torch.ops.vllm.outplace_all_reduce.default(
mm_1, tp_group_name)
@@ -288,8 +290,8 @@ def match_final(
def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
gemm_1_activations: torch.Tensor,
rms_norm_weights: torch.Tensor) -> torch.Tensor:
permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254)
mm_1 = torch.ops.aten.mm.default(gemm_1_activations,
gemm_1_weights.transpose(1, 0))

reduced = tensor_model_parallel_all_reduce(mm_1)

@@ -298,10 +300,10 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
else:
wait_tensor = my_residual

ops.fused_add_rms_norm(input=reduced,
residual=wait_tensor,
weight=rms_norm_weights,
epsilon=1e-05)
torch.ops._C.fused_add_rms_norm.default(input=reduced,
residual=wait_tensor,
weight=rms_norm_weights,
epsilon=1e-05)

return reduced

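The hunk above replaces the captured aten.permute + mm pair with an equivalent transpose; for a 2-D weight the two forms compute the same product, e.g. (shapes are illustrative):

import torch

act = torch.randn(8, 16)   # activations (M, K), illustrative
w = torch.randn(32, 16)    # weights (N, K), illustrative

out_permute = torch.ops.aten.mm.default(
    act, torch.ops.aten.permute.default(w, [1, 0]))
out_transpose = torch.ops.aten.mm.default(act, w.transpose(1, 0))
assert torch.allclose(out_permute, out_transpose)
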
@@ -325,7 +327,7 @@ class CollectiveFusionPass(InductorPass):
_instance: 'Optional[CollectiveFusionPass]' = None

@classmethod
def instance(cls, config: CompilationConfig):
def instance(cls, config: CompilationConfig) -> "CollectiveFusionPass":
"""
Get the singleton instance of the CollectiveFusionPass.
If the instance exists, the config is updated but
@@ -358,8 +360,8 @@ def __init__(self, config):
final_inputs = [x, w, resid, resid_w]

# register multiple patterns for all group names.
max_gpus = 8 # TODO: get this officially
group_names = [f"tp:{rank}" for rank in range(max_gpus)]
world_size = get_tensor_model_parallel_world_size()
group_names = [f"tp:{rank}" for rank in range(world_size)]

for group_name in group_names:
for m in [
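
With the hardcoded max_gpus = 8 replaced by the actual tensor-parallel world size (hunk above), patterns are registered once per real group name; a standalone sketch of the name generation (the world size value is illustrative; in the pass it comes from get_tensor_model_parallel_world_size()):

world_size = 4  # illustrative
group_names = [f"tp:{rank}" for rank in range(world_size)]
print(group_names)  # ['tp:0', 'tp:1', 'tp:2', 'tp:3']
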
@@ -389,24 +391,15 @@ def record_match(self, match: Match) -> bool:
# Return False to prevent automatic replacement.
return False

def find_max_m(self, matches) -> Tuple[int, int]:
gemm_1_max_m = 0
gemm_2_max_m = 0
def find_max_m(self, matches: List[Match]) -> int:
max_m = 0
for m in matches:
#gemm_1 = m.kwargs["gemm_1_weights"].meta["val"]
#gemm_2 = m.kwargs["gemm_2_weights"].meta["val"]
#gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1])
#gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1])
gemm_1 = m.kwargs["residual"].meta["val"]
gemm_2 = m.kwargs["residual"].meta["val"]
gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1])
gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1])

assert gemm_1_max_m > 0
assert gemm_2_max_m > 0
return gemm_1_max_m, gemm_2_max_m

def process_matches(self, graph: fx.Graph):
residual = m.kwargs["residual"].meta["val"]
max_m = max(max_m, residual.shape[1])
assert max_m > 0
return max_m

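A standalone sketch of the consolidated helper above: both GEMMs now share one max-M value, taken as the maximum of residual.shape[1] over the recorded matches (the shape tuples below stand in for the real Match metadata):

from typing import List, Tuple

def find_max_m_standalone(residual_shapes: List[Tuple[int, int]]) -> int:
    # Mirrors the new find_max_m(): a single maximum over residual.shape[1],
    # replacing the old per-GEMM (gemm_1_max_m, gemm_2_max_m) pair.
    max_m = 0
    for shape in residual_shapes:
        max_m = max(max_m, shape[1])
    assert max_m > 0
    return max_m

print(find_max_m_standalone([(8, 1024), (8, 2048), (8, 512)]))  # 2048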
def process_matches(self, graph: fx.Graph) -> None:
nodes = list(graph.nodes)

def find_min_index(match: Match) -> int:
@@ -418,8 +411,8 @@ def find_min_index(match: Match) -> int:
res_replacements: List[fx.Node] = []
my_res_replacements: List[fx.Node] = []

gemm_1_max_m, gemm_2_max_m = self.find_max_m(matches)
logger.info("max m = %d, %d", gemm_1_max_m, gemm_2_max_m)
max_m = self.find_max_m(matches)
logger.info("max m = %d", max_m)

for match in matches:
last_node = last_node_in_match(match)
@@ -451,8 +444,8 @@ def find_min_index(match: Match) -> int:
tp_group_name = ar_node.args[1]

fused_gemm_func = get_gemm_rs_ag_gemm(
use_flux, gemm_1.dtype, gemm_1.shape, gemm_1_max_m,
gemm_2.dtype, gemm_2.shape, gemm_2_max_m, tp_group_name)
use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
gemm_2.shape, tp_group_name)

fused_node = graph.call_function(fused_gemm_func,
kwargs=kwargs)
3 changes: 1 addition & 2 deletions vllm/worker/model_runner.py
@@ -1135,9 +1135,8 @@ def load_model(self) -> None:

if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
and supports_dynamo():
from vllm.compilation.backends import wrap_inductor
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or wrap_inductor #"eager"
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
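The model_runner change above drops the wrap_inductor import and falls back to the "eager" backend when no plugin backend is registered; a minimal standalone sketch of that selection (the stub below stands in for vllm.plugins.get_torch_compile_backend):

import torch

def get_torch_compile_backend():
    # Stub for vllm.plugins.get_torch_compile_backend(); returns None when
    # no backend has been registered by a plugin.
    return None

model = torch.nn.Linear(16, 16)
backend = get_torch_compile_backend() or "eager"  # same fallback as the diff
compiled = torch.compile(model, fullgraph=False, backend=backend)
print(compiled(torch.randn(2, 16)).shape)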
