Skip to content

Commit

Permalink
fix some todos
Browse files Browse the repository at this point in the history
Signed-off-by: Bill Nell <[email protected]>
  • Loading branch information
bnellnm committed Nov 5, 2024
1 parent 0e56f64 commit 0a1f637
Showing 1 changed file with 51 additions and 52 deletions.
103 changes: 51 additions & 52 deletions vllm/compilation/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (
get_group_from_group_name, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_group_from_group_name, get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.utils import direct_register_custom_op

Expand All @@ -31,12 +30,12 @@
use_flux = False


# how to do this properly?
# TODO: is this right?
def get_world_name() -> str:
    """Return the ``group_name`` of ``torch.distributed.group.WORLD``.

    NOTE(review): whether the WORLD group is the right source for this
    name is still an open question in the file (see the TODO above) --
    confirm before relying on it.
    """
    world_group = torch.distributed.group.WORLD
    return world_group.group_name


# This check is a hack
# TODO: This check is a hack
def should_slice(shape) -> bool:
    """Tell whether a tensor with the given shape should be sliced.

    The answer is true only when the leading dimension ``shape[0]`` is a
    multiple of the tensor-model-parallel world size and is at least 128.

    Args:
        shape: An indexable shape whose first entry is the leading
            dimension.
    """
    world_size = get_tensor_model_parallel_world_size()
    leading_dim = shape[0]

    divides_evenly = leading_dim % world_size == 0
    if not divides_evenly:
        return False

    # The 128 lower bound is a heuristic (the file itself calls this
    # check a hack).
    return leading_dim >= 128
Expand All @@ -59,6 +58,7 @@ def match_gemm_rs_ag_gemm(
gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)

# It would be nice to do this instead of having two separate patterns
#all_reduce = tensor_model_parallel_all_reduce(mm_1)
if custom_ar:
all_reduce = torch.ops.vllm.outplace_all_reduce.default(
Expand Down Expand Up @@ -87,47 +87,20 @@ def match_gemm_rs_ag_gemm(
return match_gemm_rs_ag_gemm


def gemm_rs_ag_gemm_fake(
    residual: torch.Tensor,
    my_residual: torch.Tensor,
    gemm_1_weights: torch.Tensor,
    gemm_1_activations: torch.Tensor,
    rms_norm_weight: torch.Tensor,
    gemm_2_weights: torch.Tensor,
    first_layer: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for the fused gemm/reduce-scatter/all-gather op.

    No matrix multiply is performed: the first output is an uninitialised
    tensor with the expected result shape. The incoming ``my_residual``
    argument is ignored and recomputed from ``residual``.
    NOTE(review): presumably used as the fake implementation when the
    custom op is registered -- confirm against the registration site.

    Returns:
        A tuple ``(mm_res, my_residual, residual)``.
    """
    needs_slicing = first_layer and should_slice(gemm_1_activations.shape)

    if needs_slicing:
        per_rank_slices = slice_residual(residual)
        this_rank = get_tensor_model_parallel_rank()
        chunk_rows = per_rank_slices[this_rank].shape[0]
        # Split the residual into chunks of this rank's slice size and
        # keep the first chunk.
        residual_chunks = torch.ops.aten.split.Tensor(residual, chunk_rows)
        my_residual = residual_chunks[0]
    else:
        my_residual = residual

    # TODO: verify the type is always correct
    out_rows = gemm_1_activations.shape[0]
    out_cols = gemm_2_weights.shape[0]
    mm_res = torch.empty((out_rows, out_cols),
                         device=gemm_1_activations.device,
                         dtype=gemm_1_activations.dtype)

    return (mm_res, my_residual, residual)


# TODO: factor out groupnames, etc.
def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
gemm_1_weights: torch.Size, gemm_2_type,
gemm_2_weights: torch.Size, tp_group_name: str):

group = get_group_from_group_name(tp_group_name)
device_group = group.device_group
rank = group.rank_in_group

if use_flux:
device_group = get_group_from_group_name(tp_group_name).device_group
gemm_rs_op = flux.GemmRS(
device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
gemm_1_weights[0], # N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
Expand All @@ -144,7 +117,6 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type,
8192, # Max M. TODO: Pass in correctly.
gemm_2_weights[0], # N
gemm_2_weights[1], # K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
Expand Down Expand Up @@ -192,12 +164,11 @@ def gemm_rs_ag_gemm(

if first_layer and should_slice(residual.shape):
res_slices = slice_residual(residual)
# is this rank ok?
slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0]
slice_size = res_slices[rank].shape[0]
residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size)
my_residual = residual_chunk[0]
else:
my_residual = residual #.clone()
my_residual = residual
slice_size = residual.shape[0]

if not should_slice(residual.shape):
Expand Down Expand Up @@ -225,14 +196,37 @@ def gemm_rs_ag_gemm(
slice_scatter = torch.ops.aten.slice_scatter.default(
residual_1, my_residual, 0, 0, slice_size)
split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size)

# TODO: can we avoid clone here?
new_residual = split_2[0] #.clone()
new_residual = split_2[0]

mm_2 = ag_gemm(output, gemm_2_weights)

return mm_2[0], new_residual, slice_scatter

def gemm_rs_ag_gemm_fake(
    residual: torch.Tensor,
    my_residual: torch.Tensor,
    gemm_1_weights: torch.Tensor,
    gemm_1_activations: torch.Tensor,
    rms_norm_weight: torch.Tensor,
    gemm_2_weights: torch.Tensor,
    first_layer: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Shape-only stand-in for the fused gemm/reduce-scatter/all-gather op.

    No matrix multiply is performed: the first output is an uninitialised
    tensor with the expected result shape. The incoming ``my_residual``
    argument is ignored and recomputed from ``residual``. ``rank`` is a
    free variable here and must be supplied by the enclosing scope.
    NOTE(review): presumably used as the fake implementation when the
    custom op is registered -- confirm against the registration site.

    Returns:
        A tuple ``(mm_res, my_residual, residual)``.
    """
    needs_slicing = first_layer and should_slice(gemm_1_activations.shape)

    if needs_slicing:
        per_rank_slices = slice_residual(residual)
        chunk_rows = per_rank_slices[rank].shape[0]
        # Split the residual into chunks of this rank's slice size and
        # keep the first chunk.
        residual_chunks = torch.ops.aten.split.Tensor(residual, chunk_rows)
        my_residual = residual_chunks[0]
    else:
        my_residual = residual

    # TODO: verify the type is always correct
    out_rows = gemm_1_activations.shape[0]
    out_cols = gemm_2_weights.shape[0]
    mm_res = torch.empty((out_rows, out_cols),
                         device=gemm_1_activations.device,
                         dtype=gemm_1_activations.dtype)

    return (mm_res, my_residual, residual)

if not hasattr(torch.ops.vllm, name):
logger.info("registering torch.ops.vllm.%s", name)
direct_register_custom_op(name,
Expand All @@ -255,6 +249,7 @@ def match_final(
gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])
mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm)

# TODO: it would be nice to be able to use the official api directly.
#all_reduce = tensor_model_parallel_all_reduce(mm_1)
if use_custom_ar:
all_reduce = torch.ops.vllm.outplace_all_reduce.default(
Expand Down Expand Up @@ -322,6 +317,8 @@ def __init__(self):
self.final_pattern = PatternMatcherPass()
self.matches: List[Match] = []

# Run in fake mode so that we don't call real functions
# when tracing the patterns.
with torch._dynamo.utils.detect_fake_mode():
x = torch.empty([4, 4], device='cuda')
w = torch.empty([4, 4], device='cuda')
Expand Down Expand Up @@ -351,13 +348,9 @@ def __init__(self):
get_match_final(group_name, False),
get_match_final(group_name, True)
]:
register_replacement(
m,
torch.ops.vllm.gemm_ag_final,
#replace_final,
final_inputs,
fwd_only,
[self.final_pattern])
register_replacement(m, torch.ops.vllm.gemm_ag_final,
final_inputs, fwd_only,
[self.final_pattern])

def record_match(self, match: Match) -> bool:
# Hijack the extra_check to record the match and
Expand Down Expand Up @@ -394,6 +387,8 @@ def find_min_index(match: Match) -> int:
gemm_1 = kwargs["gemm_1_weights"].meta["val"]
gemm_2 = kwargs["gemm_2_weights"].meta["val"]

# Extract group_name from the matched code and use it to
# generate the proper replacement code.
ar_node = find_auto_fn(
match.nodes, torch.ops.vllm.inplace_all_reduce.default)
if ar_node is not None:
Expand All @@ -405,9 +400,13 @@ def find_min_index(match: Match) -> int:
assert ar_node is not None
tp_group_name = ar_node.args[1]

fused_node = graph.call_function(get_gemm_rs_ag_gemm(
use_flux, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
gemm_2.shape, tp_group_name),
fused_gemm_func = get_gemm_rs_ag_gemm(use_flux, gemm_1.dtype,
gemm_1.shape,
gemm_2.dtype,
gemm_2.shape,
tp_group_name)

fused_node = graph.call_function(fused_gemm_func,
kwargs=kwargs)

graph.inserting_after(fused_node)
Expand Down

0 comments on commit 0a1f637

Please sign in to comment.