add flux support
Signed-off-by: Bill Nell <[email protected]>
bnellnm committed Nov 1, 2024
1 parent 79ac0d1 commit 31be8c7
Showing 5 changed files with 214 additions and 18 deletions.
173 changes: 161 additions & 12 deletions vllm/compilation/collective_fusion.py
@@ -6,14 +6,26 @@
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)

import vllm.envs as envs

from vllm.compilation.inductor_pass import InductorPass
from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
last_node_in_match)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed import tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.logger import init_logger

use_flux = False
if envs.VLLM_USE_FLUX:
try:
import flux
use_flux = True
print("USE FLUX")
except ImportError:
use_flux = False


logger = init_logger(__name__)

# TODO: factor out somehow
@@ -67,7 +79,7 @@ def match_gemm_rs_ag_gemm(
return mm_2, new_residual


@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=())
@torch.library.custom_op("vllm::gemm_rs_ag_gemm_old", mutates_args=())
def gemm_rs_ag_gemm(
residual: torch.Tensor, old_my_residual: torch.Tensor,
gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor,
@@ -131,7 +143,7 @@ def gemm_rs_ag_gemm(
return mm_2[0], new_residual.clone(), slice_scatter


@torch.library.register_fake("vllm::gemm_rs_ag_gemm")
#@torch.library.register_fake("vllm::gemm_rs_ag_gemm")
def gemm_rs_ag_gemm_fake(
residual: torch.Tensor,
my_residual: torch.Tensor,
@@ -159,6 +171,124 @@ def gemm_rs_ag_gemm_fake(
return (mm_res, my_residual, residual)


def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weights: torch.Size):

if use_flux:
gemm_rs_op = flux.GemmRS(
get_tp_group().device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
gemm_1_weights[0], # N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction=False,
)

ag_gemm_op = flux.AGKernel(
get_tp_group().device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
gemm_2_weights[0], # N
gemm_2_weights[1], # K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
torch.float16,
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
local_copy=False,
)

gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0)
ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt)

name = f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_{gemm_2_weights[0]}_{gemm_2_weights[1]}"
else:
group_name = get_world_name()

gemm_rs = lambda act, wt: torch.ops.symm_mem.fused_matmul_reduce_scatter.default(
act, wt.transpose(1,0), 'avg', 0, group_name)

ag_gemm = lambda act, wt: torch.ops.symm_mem.fused_all_gather_matmul.default(
act, [wt.transpose(1,0)], 0, group_name)[1]

name = "gemm_rs_ag_gemm"

def gemm_rs_ag_gemm(
residual: torch.Tensor, old_my_residual: torch.Tensor,
gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor,
rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor,
first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

if should_slice(residual.shape) and first_layer:
res_slices = slice_residual(residual)
slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0]
residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size)
my_residual = residual_chunk[0]
else:
my_residual = residual
slice_size = residual.shape[0]

if not should_slice(residual.shape):
output = torch.matmul(gemm_1_activations, gemm_1_weights.transpose(1,0))
reduced_output = tensor_model_parallel_all_reduce(output)

norm_res = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=reduced_output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)
normalized = norm_res[1]
new_residual = norm_res[2]

mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1,0))
return mm_2, new_residual, new_residual.clone()
else:
group_name = get_world_name()
output = gemm_rs(gemm_1_activations, gemm_1_weights)

norm_res = torch.ops.higher_order.auto_functionalized(
torch.ops._C.fused_add_rms_norm.default,
input=output,
residual=my_residual,
weight=rms_norm_weight,
epsilon=1e-05)
normalized = norm_res[1]
new_residual = norm_res[2]

residual_1 = residual if first_layer else old_my_residual
slice_scatter = torch.ops.aten.slice_scatter.default(
residual_1, new_residual, 0, 0, slice_size)
split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size)
new_residual = split_2[0]

mm_2 = ag_gemm(normalized, gemm_2_weights)

# TODO: can we avoid clone here?
return mm_2[0], new_residual.clone(), slice_scatter

if not hasattr(torch.ops.vllm, name):
logger.info("registering torch.ops.vllm.%s", name)
torch.library.custom_op(f"vllm::{name}", gemm_rs_ag_gemm, mutates_args=())
torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake)
assert getattr(torch.ops.vllm, name)

return getattr(torch.ops.vllm, name).default


def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
gemm_1_activations: torch.Tensor,
rms_norm_weights: torch.Tensor) -> torch.Tensor:
@@ -181,7 +311,8 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,

return normalized


# Register this as a custom op since all reduce cannot be torch.compiled.
#@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=())
def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
gemm_1_activations: torch.Tensor,
rms_norm_weights: torch.Tensor) -> torch.Tensor:
@@ -195,12 +326,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,

# is this the right thing to call it on?
if should_slice(gemm_1_activations.shape):
group_name = get_world_name()
world_size = get_tensor_model_parallel_world_size()
all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default(
my_residual, world_size, group_name)
wait_tensor = torch.ops._c10d_functional.wait_tensor.default(
all_gather)
if True:
group_name = get_world_name()
world_size = get_tensor_model_parallel_world_size()
all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default(
my_residual, world_size, group_name)
wait_tensor = torch.ops._c10d_functional.wait_tensor.default(
all_gather)
else:
wait_tensor = tensor_model_parallel_all_gather(my_residual)
else:
wait_tensor = my_residual

@@ -214,6 +348,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
return norm_res[1]


#@torch.library.register_fake("vllm::gemm_ag_final")
def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor,
gemm_1_activations: torch.Tensor,
rms_norm_weights: torch.Tensor) -> torch.Tensor:
return torch.empty([gemm_1_activations.shape[0],
my_residual.shape[1]],
dtype=my_residual.dtype, device=my_residual.device)


class CollectiveFusionPass(InductorPass):

def __init__(self):
@@ -273,8 +416,14 @@ def find_min_index(match: Match) -> int:
res_replacements) > 0 else match.kwargs["residual"]
kwargs["old_my_residual"] = my_res_replacements[-1] if len(
my_res_replacements) > 0 else match.kwargs["residual"]

# TODO: use get
gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape
gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape

fused_node = graph.call_function(
torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs)
get_gemm_rs_ag_gemm(use_flux, gemm_1_w, gemm_2_w),
kwargs=kwargs)

graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem,
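For context, a minimal sketch (not part of this diff) of how the fusion pass consumes the new factory: the weight shapes read from the matched graph nodes select either the flux-backed kernels or the symm_mem fallback, and the returned handle is a per-shape torch.ops.vllm overload. The shapes below are hypothetical, and an initialized tensor-parallel group (plus flux, if use_flux=True) is assumed.

import torch
from vllm.compilation.collective_fusion import get_gemm_rs_ag_gemm

# Hypothetical [N, K] weight shapes, as read from the matched nodes' meta["val"].
gemm_1_w = torch.Size([4096, 4096])
gemm_2_w = torch.Size([4096, 4096])

# With flux this resolves to torch.ops.vllm.gemm_rs_ag_gemm_4096_4096_4096;
# without flux it resolves to torch.ops.vllm.gemm_rs_ag_gemm.
fused_op = get_gemm_rs_ag_gemm(use_flux=True, gemm_1_weights=gemm_1_w,
                               gemm_2_weights=gemm_2_w)

# The op keeps the signature of the pattern it replaces:
# (residual, old_my_residual, gemm_1_weights, gemm_1_activations,
#  rms_norm_weight, gemm_2_weights, first_layer)
#   -> (mm_2, new_residual, slice_scatter)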
20 changes: 20 additions & 0 deletions vllm/distributed/device_communicators/pynccl.py
@@ -121,6 +121,26 @@ def all_reduce(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def all_gather(self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None):
if self.disabled:
return
# nccl communicator created on a specific device
# will only work on tensors on the same device
# otherwise it will cause "illegal memory access"
assert input_tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {input_tensor.device}")
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()), input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm,
cudaStream_t(stream.cuda_stream))

def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
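A short usage sketch for the new all_gather method (not part of the diff; it assumes a PyNcclCommunicator has already been constructed for the current group and CUDA device). Following NCCL all-gather semantics, the output tensor must be pre-allocated to hold world_size copies of the input.

import torch

# comm: an already-initialized PyNcclCommunicator (hypothetical setup,
# unchanged by this diff); device and world size come from the communicator.
inp = torch.randn(1024, device=comm.device)
out = torch.empty(comm.world_size * inp.numel(),
                  dtype=inp.dtype, device=comm.device)
comm.all_gather(out, inp)  # out now holds every rank's input, ordered by rank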
20 changes: 20 additions & 0 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -151,6 +151,17 @@ class NCCLLibrary:
ncclRedOp_t, ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclAllGather(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclComm_t comm,
# cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function("ncclAllGather", ncclResult_t, [
buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
ncclComm_t, cudaStream_t
]),

# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
@@ -271,6 +282,15 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
count: int, datatype: int, comm: ncclComm_t,
stream: cudaStream_t) -> None:
# `datatype` should actually be `ncclDataType_t`,
# which is an alias of `ctypes.c_int`.
# When we pass an int to the function, ctypes converts it
# to `ctypes.c_int` automatically.
self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
datatype, comm, stream))

__all__ = [
"NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
14 changes: 8 additions & 6 deletions vllm/distributed/parallel_state.py
@@ -30,12 +30,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

try:
import flux
has_flux = True
except ImportError:
has_flux = False

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup, _symmetric_memory
@@ -45,6 +39,14 @@
from vllm.platforms import current_platform
from vllm.utils import supports_custom_op

has_flux = False
if envs.VLLM_USE_FLUX:
try:
import flux
has_flux = True
except ImportError:
has_flux = False


@dataclass
class GraphCaptureContext:
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -70,6 +70,7 @@
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_USE_FLUX: bool = False


def get_default_cache_root():
@@ -465,6 +466,10 @@ def get_default_config_root():
# If set, use the V1 code path.
"VLLM_USE_V1":
lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),

# If set, try to use the flux fused collective communication gemm kernels
"VLLM_USE_FLUX":
lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))),
}

# end-env-vars-definition
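A minimal sketch of turning the new path on (assumption: flux must be importable, otherwise the import guards above silently fall back to use_flux = False / has_flux = False). Because those guards run at module import time, the variable has to be set before the vllm modules that check it are imported, e.g. VLLM_USE_FLUX=1 in the launching shell, or:

import os
os.environ["VLLM_USE_FLUX"] = "1"  # must be set before vllm modules are imported

import vllm.envs as envs
assert envs.VLLM_USE_FLUX  # the flux import guards above key off this flag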
