diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 9b758fa2479f6..9f98a80bc7f80 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -1,4 +1,5 @@
 from vllm import LLM, SamplingParams
+import torch

 # Sample prompts.
 prompts = [
@@ -8,10 +9,10 @@
     "The future of AI is",
 ]
 # Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+sampling_params = SamplingParams(temperature=0.0, top_p=0.95)

 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="meta-llama/Meta-Llama-3-8b", tensor_parallel_size=2, enforce_eager=True, dtype=torch.float16)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
diff --git a/flux_env.sh b/flux_env.sh
new file mode 100644
index 0000000000000..8979ce0858d0c
--- /dev/null
+++ b/flux_env.sh
@@ -0,0 +1,17 @@
+# Point to the directory containing the flux .so files:
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/nm-vllm/flux_experiment/lib
+
+export NVSHMEM_BOOTSTRAP_MPI_PLUGIN=nvshmem_bootstrap_torch.so
+
+# Environment variables for symmetric heap allocation.
+# These are needed to support CUDA_VISIBLE_DEVICES.
+# 8 GiB is large enough for Llama 3 8B, but should be sized per model.
+export NVSHMEM_SYMMETRIC_SIZE=$((8*1024**3))
+export NVSHMEM_DISABLE_CUDA_VMM=1  # moved from the C++ side into this script
+
+# Not sure if these are needed.
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export BYTED_TORCH_BYTECCL=O0
+export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:=23}
+export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:=3}
+export NVSHMEM_IB_GID_INDEX=3
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 1c864bcd5d708..22041806cc12f 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -29,6 +29,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch

+import flux
 import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup
@@ -200,6 +201,10 @@ def __init__(
         self.use_custom_allreduce = use_custom_allreduce
         self.use_tpu_communicator = use_tpu_communicator

+        # Initialize flux shared memory (pynvshmem).
+        if torch.distributed.get_world_size(self.device_group) > 1:
+            flux.init_flux_shm(self.device_group)
+
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce)
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 568892778abe2..4ad93890b66ef 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple

+import flux
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
@@ -10,6 +11,7 @@
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -135,6 +137,104 @@ def apply(self,
         return F.linear(x, layer.weight, bias)


+class GemmRS(LinearMethodBase):
+    """Fused GEMM + ReduceScatter without quantization."""
+
+    def __init__(self, separate_bias_add: bool = False):
+        self.separate_bias_add = separate_bias_add
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=params_dtype),
+                           requires_grad=False)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
+
+        self.gemm_rs_op = flux.GemmRS(
+            get_tp_group().device_group,
+            1,  # One node
+            8192,  # Max M. TODO: Pass in correctly.
+            output_size,  # N
+            # TODO: Pass in input dtype correctly.
+            # TODO: It would be nicer to modify flux to dispatch based on dtype
+            # at run time, but I don't know what the downside would be.
+            # Similar comment for max m.
+            torch.float16,
+            # Note: transpose_weight=False means that B is transposed
+            transpose_weight=False,
+            # Note: bfloat16 requires fuse_reduction=False.
+            fuse_reduction=False,
+        )
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert bias is None
+
+        output = self.gemm_rs_op.forward(x, layer.weight)
+        output = output.squeeze(0)
+
+        return output
+
+
+class AGCook(LinearMethodBase):
+    """Fused AllGather + GEMM without quantization."""
+
+    def __init__(self, separate_bias_add: bool = False):
+        self.separate_bias_add = separate_bias_add
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=params_dtype),
+                           requires_grad=False)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
+
+        self.ag_gemm_op = flux.AGKernel(
+            get_tp_group().device_group,
+            1,  # One node
+            8192,  # Max M. TODO: Pass in correctly.
+            weight.shape[0],  # N
+            weight.shape[1],  # K
+            # TODO: Pass in input dtype correctly.
+            # TODO: It would be nicer to modify flux to dispatch based on dtype
+            # at run time, but I don't know what the downside would be.
+            # Similar comment for max m.
+            torch.float16,
+            torch.float16,
+            # Note: transpose_weight=False means that B is transposed
+            transpose_weight=False,
+            # Note: if local_copy=True, I hit the following runtime error:
+            #   /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
+            #   Check failed: 33554432((input.numel() * input.element_size()))
+            #   == 139836453421056((this->chunk_size))
+            local_copy=False,
+        )
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        assert bias is None
+
+        output = self.ag_gemm_op.forward(x, layer.weight)
+
+        return output
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
@@ -155,6 +255,8 @@ def __init__(
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        fuse_gemm_rs: bool = False,
+        fuse_ag_gemm: bool = False,
     ):
         super().__init__()

@@ -165,9 +267,15 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
-        if quant_config is None:
-            self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+
+        if fuse_gemm_rs:
+            assert (quant_config is None)
+            self.quant_method: Optional[QuantizeMethodBase] = GemmRS()
+        elif fuse_ag_gemm:
+            assert (quant_config is None)
+            self.quant_method = AGCook()
+        elif quant_config is None:
+            self.quant_method = UnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
@@ -280,9 +388,15 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  output_sizes: Optional[List[int]] = None,
-                 prefix: str = ""):
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+                 prefix: str = "",
+                 fuse_ag_gemm: bool = False):
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         fuse_ag_gemm=fuse_ag_gemm)

         self.gather_output = gather_output

@@ -413,7 +527,8 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 fuse_ag_gemm: bool = False):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -424,7 +539,8 @@ def __init__(self,
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         fuse_ag_gemm=fuse_ag_gemm)

     def weight_loader(self,
                       param: Parameter,
@@ -654,7 +770,8 @@ def __init__(self,
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 fuse_ag_gemm: bool = False):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -687,7 +804,8 @@ def __init__(self,
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         fuse_ag_gemm=fuse_ag_gemm)

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
@@ -967,12 +1085,20 @@ def __init__(self,
                  params_dtype: Optional[torch.dtype] = None,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+                 prefix: str = "",
+                 fuse_gemm_rs: bool = False):
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         fuse_gemm_rs=fuse_gemm_rs)
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
+        if fuse_gemm_rs:
+            self.reduce_results = False

         # Divide the weight matrix along the last dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 5ff31e3833ec9..4b56fc72d82ea 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -30,7 +30,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -65,6 +66,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         prefix: str = "",
+        last_layer: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -72,12 +74,16 @@ def __init__(
             output_sizes=[intermediate_size] * 2,
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            fuse_ag_gemm=True)
+
         self.down_proj = RowParallelLinear(input_size=intermediate_size,
                                            output_size=hidden_size,
                                            bias=bias,
                                            quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+                                           prefix=f"{prefix}.down_proj",
+                                           fuse_gemm_rs=(not last_layer))
+
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -98,6 +104,7 @@ def __init__(
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        first_layer: bool,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
@@ -139,6 +146,7 @@ def __init__(
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
+            fuse_ag_gemm=(not first_layer),
         )

         self.o_proj = RowParallelLinear(
@@ -147,6 +155,7 @@ def __init__(
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
+            fuse_gemm_rs=True,
         )

         is_neox_style = True
@@ -188,6 +197,11 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
+        # Hack: pass in whether this is the first/last layer
+        # so we know if we can rewrite AllReduce -> ReduceScatter + AllGather,
+        # and then propagate the AllGather to the next layer.
+        first_layer: bool,
+        last_layer: bool,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
@@ -212,6 +226,7 @@ def __init__(
             num_heads=config.num_attention_heads,
             num_kv_heads=getattr(config, "num_key_value_heads",
                                  config.num_attention_heads),
+            first_layer=first_layer,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
@@ -227,12 +242,16 @@ def __init__(
             quant_config=quant_config,
             bias=getattr(config, "mlp_bias", False),
             prefix=f"{prefix}.mlp",
+            last_layer=last_layer,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)

+        self.first_layer = first_layer
+        self.last_layer = last_layer
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -246,8 +265,18 @@ def forward(
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
+            assert (hidden_states.shape == residual.shape)
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
+
+        # Partition residual
+        if self.first_layer:
+            n_slices = get_tensor_model_parallel_world_size()
+            residual_slices = torch.chunk(residual, n_slices, dim=0)
+            my_residual = residual_slices[get_tensor_model_parallel_rank()]
+        else:
+            my_residual = residual
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -256,9 +285,17 @@ def forward(
         )

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(
-            hidden_states, residual)
+        assert (hidden_states.shape == my_residual.shape)
+        hidden_states, my_residual = self.post_attention_layernorm(
+            hidden_states, my_residual)
         hidden_states = self.mlp(hidden_states)
+
+        if self.last_layer:
+            residual = tensor_model_parallel_all_gather(my_residual, 0)
+        else:
+            residual = my_residual
+
+        assert (hidden_states.shape == residual.shape)
         return hidden_states, residual


@@ -291,10 +328,13 @@ def __init__(
             self.embed_tokens = PPMissingLayer()
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: LlamaDecoderLayer(config=config,
-                                             cache_config=cache_config,
-                                             quant_config=quant_config,
-                                             prefix=prefix),
+            lambda prefix, first_layer, last_layer: LlamaDecoderLayer(
+                config=config,
+                first_layer=first_layer,
+                last_layer=last_layer,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=prefix),
             prefix=f"{prefix}.layers")
         if get_pp_group().is_last_rank:
             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 8b80dda96db49..f7cd071a7e6fa 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -239,14 +239,33 @@ def make_layers(
     """Make a list of layers with the given layer function, taking
     pipeline parallelism into account.
""" + import inspect + from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.utils import get_pp_indices + start_layer, end_layer = get_pp_indices(num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) + + # Determine if layer_fn accepts first/last args by inspecting its signature + sig = inspect.signature(layer_fn) + has_firstlast_args = ('first_layer' + in sig.parameters) and ('last_layer' + in sig.parameters) + + def make_one_layer(idx, start_layer, end_layer): + if has_firstlast_args: + return maybe_offload_to_cpu( + layer_fn(prefix=f"{prefix}.{idx}", + first_layer=(idx == start_layer), + last_layer=(idx == end_layer - 1))) + else: + return maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + modules = torch.nn.ModuleList( [PPMissingLayer() for _ in range(start_layer)] + [ - maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + make_one_layer(idx, start_layer, end_layer) for idx in range(start_layer, end_layer) ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules