diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 123a77e14ad74..d7e640ce96995 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -1,4 +1,5 @@
 import pytest
+from transformers import AutoTokenizer
 
 from ..utils import RemoteOpenAIServer
 
@@ -12,6 +13,8 @@
     (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
 ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         "--dtype",
         "bfloat16",
         "--tensor-parallel-size",
-        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
+        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
         "--distributed-executor-backend",
         "mp",
     ]
@@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
 
+    prompt = "Hello, my name is"
+    token_ids = tokenizer(prompt)["input_ids"]
     results = []
-    for args in [pp_args, tp_args]:
+    for args in (pp_args, tp_args):
         with RemoteOpenAIServer(MODEL_NAME, args) as server:
             client = server.get_client()
 
@@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
 
             # test with text prompt
             completion = client.completions.create(model=MODEL_NAME,
-                                                   prompt="Hello, my name is",
+                                                   prompt=prompt,
                                                    max_tokens=5,
                                                    temperature=0.0)
 
@@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test using token IDs
             completion = client.completions.create(
                 model=MODEL_NAME,
-                prompt=[0, 0, 0, 0, 0],
+                prompt=token_ids,
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test simple list
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test streaming
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
                 stream=True,
diff --git a/vllm/config.py b/vllm/config.py
index de7bb3943a45f..a20e830955671 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -34,6 +34,7 @@
     "MistralForCausalLM",
     "Phi3ForCausalLM",
     "GPT2LMHeadModel",
+    "MixtralForCausalLM",
 ]
 
 
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index e739df87cf96a..28dbcb30bdf55 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -29,7 +29,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,6 +48,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers
 
 
 class MixtralMoE(nn.Module):
@@ -255,12 +256,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config,
-                                cache_config,
-                                quant_config=quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, lambda: MixtralDecoderLayer(
+                config, cache_config, quant_config=quant_config))
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
@@ -269,14 +269,25 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -347,7 +358,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -356,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: Optional[torch.Tensor],
@@ -392,6 +417,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -402,6 +431,9 @@
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -414,6 +446,9 @@
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     # Remapping the name of FP8 kv-scale.
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
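
A note on the pipeline-parallel mechanics above: make_layers hands each PP rank a contiguous [start_layer, end_layer) slice of the decoder stack, which is why the forward loop indexes kv_caches[i - self.start_layer] and why load_weights skips parameters that live on other ranks via is_pp_missing_parameter. The sketch below only illustrates that partitioning idea; it is not vLLM's actual make_layers implementation, and the helper name partition_layers as well as the even-split policy are assumptions.

# Illustrative sketch only (assumed helper name and even-split policy);
# vLLM's real make_layers / is_pp_missing_parameter differ in the details.
from typing import Callable, List, Optional, Tuple


def partition_layers(
    num_layers: int,
    pp_rank: int,
    pp_size: int,
    layer_fn: Callable[[], object],
) -> Tuple[int, int, List[Optional[object]]]:
    """Return (start_layer, end_layer, layers) for one pipeline rank.

    Layers outside [start_layer, end_layer) are left as None placeholders so
    absolute indexing (layers[i]) still works, while only the local slice is
    actually materialized; that is why kv_caches is indexed relative to
    start_layer in the model's forward pass.
    """
    per_rank, remainder = divmod(num_layers, pp_size)
    # Contiguous even split; earlier ranks absorb any remainder.
    start = pp_rank * per_rank + min(pp_rank, remainder)
    end = start + per_rank + (1 if pp_rank < remainder else 0)
    layers = [layer_fn() if start <= i < end else None
              for i in range(num_layers)]
    return start, end, layers


if __name__ == "__main__":
    # Example: 32 Mixtral decoder layers over 4 pipeline stages.
    for rank in range(4):
        s, e, _ = partition_layers(32, rank, 4, object)
        print(f"pp rank {rank} owns layers [{s}, {e})")
    # pp rank 0 owns layers [0, 8) ... pp rank 3 owns layers [24, 32)

Ranks other than the last return IntermediateTensors with "hidden_states" and "residual" instead of final hidden states, and make_empty_intermediate_tensors allocates matching zero buffers, presumably so the receiving rank has tensors of the right shape and dtype to fill.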