add flux support
bnellnm committed Nov 1, 2024
1 parent 5b2c415 commit b3200f8
Showing 3 changed files with 37 additions and 92 deletions.
126 changes: 35 additions & 91 deletions vllm/compilation/collective_fusion.py
@@ -7,27 +7,27 @@
fwd_only, register_replacement)

import vllm.envs as envs

from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
last_node_in_match)
from vllm.distributed import tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (
get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group)
from vllm.logger import init_logger

logger = init_logger(__name__)

use_flux = False
if envs.VLLM_USE_FLUX:
try:
import flux
use_flux = True
print("USE FLUX")
logger.info("USING FLUX")
except ImportError:
use_flux = False


logger = init_logger(__name__)

# TODO: factor out somehow
TP_GROUP_NAME = "tp:0"

@@ -79,71 +79,6 @@ def match_gemm_rs_ag_gemm(
return mm_2, new_residual


@torch.library.custom_op("vllm::gemm_rs_ag_gemm_old", mutates_args=())
def gemm_rs_ag_gemm(
residual: torch.Tensor, old_my_residual: torch.Tensor,
gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor,
rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor,
first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

if should_slice(residual.shape) and first_layer:
res_slices = slice_residual(residual)
slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0]
residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size)
my_residual = residual_chunk[0]
else:
my_residual = residual
slice_size = residual.shape[0]

gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0])

if not should_slice(residual.shape):
output = torch.matmul(gemm_1_activations, gemm_1_w_perm)
reduced_output = tensor_model_parallel_all_reduce(output)

norm_res = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduced_output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)
normalized = norm_res[1]
new_residual = norm_res[2]

gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0])
mm_2 = torch.matmul(normalized, gemm_2_w_perm)
return mm_2, new_residual, new_residual.clone()
else:
group_name = get_world_name()
output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(
gemm_1_activations, gemm_1_w_perm, 'avg', 0, group_name)

norm_res = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)
normalized = norm_res[1]
new_residual = norm_res[2]

residual_1 = residual if first_layer else old_my_residual
slice_scatter = torch.ops.aten.slice_scatter.default(
residual_1, new_residual, 0, 0, slice_size)
split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size)
new_residual = split_2[0]

gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0])
fused_all_gather_matmul = (
torch.ops.symm_mem.fused_all_gather_matmul.default(
normalized, [gemm_2_w_perm], 0, group_name))
mm_2 = fused_all_gather_matmul[1]

# TODO: can we avoid clone here?
return mm_2[0], new_residual.clone(), slice_scatter


#@torch.library.register_fake("vllm::gemm_rs_ag_gemm")
def gemm_rs_ag_gemm_fake(
residual: torch.Tensor,
my_residual: torch.Tensor,
@@ -171,7 +106,8 @@ def gemm_rs_ag_gemm_fake(
return (mm_res, my_residual, residual)


def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weights: torch.Size):
def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size,
gemm_2_weights: torch.Size):

if use_flux:
gemm_rs_op = flux.GemmRS(
@@ -214,23 +150,27 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weigh
gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0)
ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt)

name = f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_{gemm_2_weights[0]}_{gemm_2_weights[1]}"
name = (f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_"
f"{gemm_2_weights[0]}_{gemm_2_weights[1]}")
else:
group_name = get_world_name()

gemm_rs = lambda act, wt: torch.ops.symm_mem.fused_matmul_reduce_scatter.default(
act, wt.transpose(1,0), 'avg', 0, group_name)
gemm_rs = lambda act, wt: \
torch.ops.symm_mem.fused_matmul_reduce_scatter.default(
act, wt.transpose(1, 0), 'avg', 0, group_name)

ag_gemm = lambda act, wt: torch.ops.symm_mem.fused_all_gather_matmul.default(
act, [wt.transpose(1,0)], 0, group_name)[1]
ag_gemm = lambda act, wt: \
torch.ops.symm_mem.fused_all_gather_matmul.default(
act, [wt.transpose(1, 0)], 0, group_name)[1]

name = "gemm_rs_ag_gemm"

def gemm_rs_ag_gemm(
residual: torch.Tensor, old_my_residual: torch.Tensor,
gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor,
rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor,
first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
first_layer: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

if should_slice(residual.shape) and first_layer:
res_slices = slice_residual(residual)
@@ -242,7 +182,8 @@ def gemm_rs_ag_gemm(
slice_size = residual.shape[0]

if not should_slice(residual.shape):
output = torch.matmul(gemm_1_activations, gemm_1_weights.transpose(1,0))
output = torch.matmul(gemm_1_activations,
gemm_1_weights.transpose(1, 0))
reduced_output = tensor_model_parallel_all_reduce(output)

norm_res = torch.ops.higher_order.auto_functionalized(
@@ -254,10 +195,9 @@ def gemm_rs_ag_gemm(
normalized = norm_res[1]
new_residual = norm_res[2]

mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1,0))
mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1, 0))
return mm_2, new_residual, new_residual.clone()
else:
group_name = get_world_name()
output = gemm_rs(gemm_1_activations, gemm_1_weights)

norm_res = torch.ops.higher_order.auto_functionalized(
@@ -282,7 +222,9 @@ def gemm_rs_ag_gemm(

if not hasattr(torch.ops.vllm, name):
logger.info("registering torch.ops.vllm.%s", name)
torch.library.custom_op(f"vllm::{name}", gemm_rs_ag_gemm, mutates_args=())
torch.library.custom_op(f"vllm::{name}",
gemm_rs_ag_gemm,
mutates_args=())
torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake)
assert getattr(torch.ops.vllm, name)

@@ -311,6 +253,7 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,

return normalized


# Register this as a custom op since all reduce cannot be torch.compiled.
#@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=())
def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
@@ -329,8 +272,9 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
if True:
group_name = get_world_name()
world_size = get_tensor_model_parallel_world_size()
all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default(
my_residual, world_size, group_name)
all_gather = (
torch.ops._c10d_functional.all_gather_into_tensor.default(
my_residual, world_size, group_name))
wait_tensor = torch.ops._c10d_functional.wait_tensor.default(
all_gather)
else:
@@ -352,9 +296,9 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
gemm_1_activations: torch.Tensor,
rms_norm_weights: torch.Tensor) -> torch.Tensor:
return torch.empty([gemm_1_activations.shape[0],
my_residual.shape[1]],
dtype=my_residual.dtype, device=my_residual.device)
return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]],
dtype=my_residual.dtype,
device=my_residual.device)


class CollectiveFusionPass(InductorPass):
@@ -421,9 +365,9 @@ def find_min_index(match: Match) -> int:
gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape
gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape

fused_node = graph.call_function(
get_gemm_rs_ag_gemm(use_flux, gemm_1_w, gemm_2_w),
kwargs=kwargs)
fused_node = graph.call_function(get_gemm_rs_ag_gemm(
use_flux, gemm_1_w, gemm_2_w),
kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem,
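Note: the hunks above register the fused replacement dynamically — the op name encodes the GEMM weight shapes (gemm_rs_ag_gemm_<...>), because the flux GemmRS/AGKernel objects are built per weight shape, and a matching fake implementation is registered so shape propagation works while tracing. The following is a minimal, self-contained sketch of that registration pattern, not part of this commit; the mylib::scaled_matmul op and its fake are hypothetical, and it assumes PyTorch 2.4+ where torch.library.custom_op and torch.library.register_fake are available.

import torch


def scaled_matmul(a: torch.Tensor, b: torch.Tensor,
                  scale: float) -> torch.Tensor:
    # Eager implementation: runs on real tensors.
    return scale * torch.matmul(a, b)


def scaled_matmul_fake(a: torch.Tensor, b: torch.Tensor,
                       scale: float) -> torch.Tensor:
    # Fake (meta) implementation: only describes the output shape/dtype,
    # which is all the compiler needs during tracing.
    return torch.empty((a.shape[0], b.shape[1]),
                       dtype=a.dtype,
                       device=a.device)


name = "scaled_matmul"
if not hasattr(torch.ops.mylib, name):
    # Same idiom as the diff: register once, keyed by op name.
    torch.library.custom_op(f"mylib::{name}", scaled_matmul, mutates_args=())
    torch.library.register_fake(f"mylib::{name}", scaled_matmul_fake)

out = torch.ops.mylib.scaled_matmul(torch.randn(4, 8), torch.randn(8, 16), 2.0)
print(out.shape)  # torch.Size([4, 16])

Because the registered op is looked up via torch.ops, the fusion pass can splice it into the FX graph with graph.call_function as a single opaque node rather than inlining the Python closure.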
1 change: 1 addition & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -292,6 +292,7 @@ def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))


__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
"ncclComm_t", "cudaStream_t", "buffer_type"
2 changes: 1 addition & 1 deletion vllm/envs.py
@@ -467,7 +467,7 @@ def get_default_config_root():
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),

# If set, try to use the flux fused collective comminucation gemm kernels
# If set, try to use the flux fused collective communication gemm kernels
"VLLM_USE_FLUX":
lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))),
}
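For reference, setting VLLM_USE_FLUX=1 in the environment opts into the flux path; the flag is read as bool(int(os.getenv("VLLM_USE_FLUX", "0"))). Below is a minimal sketch, not part of this commit, mirroring the gating block added at the top of collective_fusion.py — flux is used only when the flag is set and the flux package actually imports:

import os

use_flux = False
if bool(int(os.getenv("VLLM_USE_FLUX", "0"))):
    try:
        import flux  # noqa: F401  (requires the flux package to be installed)
        use_flux = True
    except ImportError:
        # Fall back to the torch symm_mem fused ops when flux is unavailable.
        use_flux = False

print("flux fused collective-communication GEMM kernels:", use_flux)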