
[Kernel] Prototype integration of bytedance/flux kernels #5917

Closed · wants to merge 11 commits
5 changes: 3 additions & 2 deletions examples/offline_inference.py
@@ -1,4 +1,5 @@
from vllm import LLM, SamplingParams
import torch

# Sample prompts.
prompts = [
@@ -8,10 +9,10 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams(temperature=0.0, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="meta-llama/Meta-Llama-3-8b", tensor_parallel_size=2, enforce_eager=True, dtype=torch.float16)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
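The rest of the example file is collapsed in the diff above. For context, vLLM's standard offline example finishes by printing each prompt and its completion, along these lines (assuming the usual `RequestOutput` fields):

```python
# Print the generated text for each prompt.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```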
17 changes: 17 additions & 0 deletions flux_env.sh
@@ -0,0 +1,17 @@
# Point to the directory containing the flux .so files:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/nm-vllm/flux_experiment/lib

export NVSHMEM_BOOTSTRAP_MPI_PLUGIN=nvshmem_bootstrap_torch.so

# Env variables for symmetric heap allocation.
# These are needed to support CUDA_VISIBLE_DEVICES.
# 8 GiB is big enough for Llama 3 8B, but it should be sized for the model in use.
export NVSHMEM_SYMMETRIC_SIZE=$((8*1024**3))
export NVSHMEM_DISABLE_CUDA_VMM=1 # moved from the C++ code into this script

# Not sure if these are needed
export CUDA_DEVICE_MAX_CONNECTIONS=1
export BYTED_TORCH_BYTECCL=O0
export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:=23}
export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:=3}
export NVSHMEM_IB_GID_INDEX=3
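The same settings can also be applied from Python before vLLM and flux are initialized, which can be convenient in a launcher script. This is a sketch rather than part of the change, and note that `LD_LIBRARY_PATH` still has to be exported before the interpreter starts, since the dynamic loader only reads it at process startup:

```python
import os

# Mirror flux_env.sh; set these before flux / NVSHMEM are initialized.
os.environ["NVSHMEM_BOOTSTRAP_MPI_PLUGIN"] = "nvshmem_bootstrap_torch.so"
os.environ["NVSHMEM_SYMMETRIC_SIZE"] = str(8 * 1024**3)  # 8 GiB symmetric heap
os.environ["NVSHMEM_DISABLE_CUDA_VMM"] = "1"
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
os.environ["BYTED_TORCH_BYTECCL"] = "O0"
os.environ.setdefault("NCCL_IB_TIMEOUT", "23")
os.environ.setdefault("NCCL_IB_GID_INDEX", "3")
os.environ["NVSHMEM_IB_GID_INDEX"] = "3"
```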
5 changes: 5 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -29,6 +29,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import flux
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
@@ -200,6 +201,10 @@ def __init__(
self.use_custom_allreduce = use_custom_allreduce
self.use_tpu_communicator = use_tpu_communicator

# Initialize pynvshmem
if torch.distributed.get_world_size(self.device_group) > 1:
    flux.init_flux_shm(self.device_group)

# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce)
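Because the diff adds `flux` as an unconditional import, installations without the package will fail at import time. A guarded import along the following lines (a hypothetical sketch, not part of this diff) would keep vLLM importable when flux is absent and only enable the fused path when it is available:

```python
# Hypothetical sketch: treat flux as an optional dependency.
try:
    import flux
    HAS_FLUX = True
except ImportError:
    flux = None
    HAS_FLUX = False

# Inside the __init__ shown above, the call would then be guarded as:
# if HAS_FLUX and torch.distributed.get_world_size(self.device_group) > 1:
#     flux.init_flux_shm(self.device_group)
```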
152 changes: 139 additions & 13 deletions vllm/model_executor/layers/linear.py
@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import flux
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -10,6 +11,7 @@
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@@ -135,6 +137,104 @@ def apply(self,
return F.linear(x, layer.weight, bias)


class GemmRS(LinearMethodBase):
    """Fused Gemm-ReduceScatter without quantization."""

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        self.gemm_rs_op = flux.GemmRS(
            get_tp_group().device_group,
            1,  # One node
            8192,  # Max M. TODO: Pass in correctly.
            output_size,  # N
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max M.
            torch.float16,
            # Note: transpose_weight=False means that B is transposed
            transpose_weight=False,
            # Note: bfloat16 requires fuse_reduction=False.
            fuse_reduction=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert bias is None

        output = self.gemm_rs_op.forward(x, layer.weight)
        output = output.squeeze(0)

        return output


class AGCook(LinearMethodBase):
    """Fused AllGather-Gemm without quantization."""

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        self.ag_gemm_op = flux.AGKernel(
            get_tp_group().device_group,
            1,  # One node
            8192,  # Max M. TODO: Pass in correctly.
            weight.shape[0],  # N
            weight.shape[1],  # K
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on dtype
            # at run time, but I don't know what the downside would be.
            # Similar comment for max M.
            torch.float16,
            torch.float16,
            # Note: transpose_weight=False means that B is transposed
            transpose_weight=False,
            # Note: if local_copy=True, I hit the following runtime error:
            # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
            #     Check failed: 33554432((input.numel() * input.element_size()))
            #                   == 139836453421056((this->chunk_size))
            local_copy=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert bias is None

        output = self.ag_gemm_op.forward(x, layer.weight)

        return output


class LinearBase(torch.nn.Module):
"""Base linear layer.

@@ -155,6 +255,8 @@ def __init__(
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
fuse_gemm_rs: bool = False,
fuse_ag_gemm: bool = False,
):
super().__init__()

@@ -165,9 +267,15 @@ def __init__(
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
    self.quant_method: Optional[
        QuantizeMethodBase] = UnquantizedLinearMethod()

if fuse_gemm_rs:
    assert (quant_config is None)
    self.quant_method: Optional[QuantizeMethodBase] = GemmRS()
elif fuse_ag_gemm:
    assert (quant_config is None)
    self.quant_method = AGCook()
elif quant_config is None:
    self.quant_method = UnquantizedLinearMethod()
else:
    self.quant_method = quant_config.get_quant_method(self,
                                                      prefix=prefix)
@@ -280,9 +388,15 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
prefix: str = "",
fuse_ag_gemm: bool = False):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
fuse_ag_gemm=fuse_ag_gemm)

self.gather_output = gather_output

@@ -413,7 +527,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -424,7 +539,8 @@
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)

def weight_loader(self,
param: Parameter,
@@ -654,7 +770,8 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
@@ -687,7 +804,8 @@ def __init__(self,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)

def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@@ -967,12 +1085,20 @@ def __init__(self,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
prefix: str = "",
fuse_gemm_rs: bool = False):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
fuse_gemm_rs=fuse_gemm_rs)

self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if fuse_gemm_rs:
    self.reduce_results = False

# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
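For context, here is a hedged sketch of how a model definition could opt into the fused paths using the new keyword arguments from this diff (`fuse_ag_gemm`, `fuse_gemm_rs`). The layer sizes are placeholders, the tensor-parallel groups are assumed to be initialized already, and bias is disabled because the fused `apply()` methods assert `bias is None`:

```python
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

# AllGather + GEMM fused into a column-parallel projection
# (placeholder sizes; bias-free, unquantized).
up_proj = ColumnParallelLinear(4096, 14336,
                               bias=False,
                               fuse_ag_gemm=True)

# GEMM + ReduceScatter fused into the matching row-parallel projection.
down_proj = RowParallelLinear(14336, 4096,
                              bias=False,
                              fuse_gemm_rs=True)
```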