[Model] Pipeline parallel support for Mixtral (#6516)
comaniac authored Jul 18, 2024
1 parent b5241e4 commit b5af8c2
Showing 3 changed files with 60 additions and 19 deletions.
tests/distributed/test_pipeline_parallel.py (17 changes: 11 additions, 6 deletions)

@@ -1,4 +1,5 @@
 import pytest
+from transformers import AutoTokenizer
 
 from ..utils import RemoteOpenAIServer

@@ -12,6 +13,8 @@
     (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
 ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         "--dtype",
         "bfloat16",
         "--tensor-parallel-size",
-        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
+        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
         "--distributed-executor-backend",
         "mp",
     ]
@@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")
 
+    prompt = "Hello, my name is"
+    token_ids = tokenizer(prompt)["input_ids"]
     results = []
-    for args in [pp_args, tp_args]:
+    for args in (pp_args, tp_args):
         with RemoteOpenAIServer(MODEL_NAME, args) as server:
             client = server.get_client()
@@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
 
             # test with text prompt
             completion = client.completions.create(model=MODEL_NAME,
-                                                   prompt="Hello, my name is",
+                                                   prompt=prompt,
                                                    max_tokens=5,
                                                    temperature=0.0)

@@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test using token IDs
             completion = client.completions.create(
                 model=MODEL_NAME,
-                prompt=[0, 0, 0, 0, 0],
+                prompt=token_ids,
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test simple list
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test streaming
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
                 stream=True,
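Note on the test changes above: the test now derives token_ids from the same prompt via the model's own tokenizer, so the text-prompt and token-ID requests exercise equivalent inputs on both servers. The visible hunks end before the final comparison step; the following is only a sketch of what that step presumably looks like, given that results accumulates one batch of outputs per server (the exact assertion logic is not shown in this diff):

    # Hypothetical continuation: the first half of `results` comes from the
    # pipeline-parallel server, the second half from the tensor-parallel one,
    # and greedy decoding (temperature=0.0) should make them agree exactly.
    n = len(results) // 2
    for pp_output, tp_output in zip(results[:n], results[n:]):
        assert pp_output == tp_output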
vllm/config.py (1 change: 1 addition, 0 deletions)

@@ -34,6 +34,7 @@
     "MistralForCausalLM",
     "Phi3ForCausalLM",
     "GPT2LMHeadModel",
+    "MixtralForCausalLM",
 ]


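This one-line change registers Mixtral in the allow-list that vllm/config.py consults before enabling pipeline parallelism. A hedged sketch of how such a gate is typically checked; the list name _PP_SUPPORTED_MODELS and the error text are assumptions, since the hunk shows only the list entries:

    # Hypothetical sketch of the validation that consults this allow-list.
    def assert_pp_supported(hf_config, pipeline_parallel_size: int) -> None:
        architectures = getattr(hf_config, "architectures", [])
        if pipeline_parallel_size > 1 and not all(
                arch in _PP_SUPPORTED_MODELS for arch in architectures):
            raise NotImplementedError(
                "Pipeline parallelism is not yet supported for this model.")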
vllm/model_executor/models/mixtral.py (61 changes: 48 additions, 13 deletions)

@@ -29,7 +29,7 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -48,6 +48,7 @@
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers
 
 
 class MixtralMoE(nn.Module):
@@ -255,12 +256,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config,
-                                cache_config,
-                                quant_config=quant_config)
-            for _ in range(config.num_hidden_layers)
-        ])
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers, lambda: MixtralDecoderLayer(
+                config, cache_config, quant_config=quant_config))
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
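The constructor no longer instantiates every decoder layer. make_layers, imported from .utils above, partitions config.num_hidden_layers across the pipeline ranks, builds only this rank's slice, and returns the slice bounds together with the module list. A minimal sketch of that behavior, assuming an even split and placeholder modules for non-local layers (the real helper lives in vllm/model_executor/models/utils.py and may differ in detail):

    import torch.nn as nn

    def make_layers_sketch(num_layers: int, layer_fn, pp_rank: int,
                           pp_size: int):
        # Assume an even partition; the last rank absorbs any remainder.
        per_rank = num_layers // pp_size
        start = pp_rank * per_rank
        end = num_layers if pp_rank == pp_size - 1 else start + per_rank
        return start, end, nn.ModuleList(
            [nn.Identity() for _ in range(start)] +     # held by other ranks
            [layer_fn() for _ in range(start, end)] +   # this rank's layers
            [nn.Identity() for _ in range(end, num_layers)])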
@@ -269,14 +269,25 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
     ) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
-        residual = None
-        for i in range(len(self.layers)):
+        if get_pp_group().is_first_rank:
+            hidden_states = self.embed_tokens(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
-                                            kv_caches[i], attn_metadata,
-                                            residual)
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

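The rewritten forward is the core of the pipeline-parallel support: the first rank embeds the input tokens, every rank runs only its [start_layer, end_layer) slice (note that kv_caches is indexed relative to start_layer, since each rank allocates caches only for its own layers), and every rank except the last returns an IntermediateTensors bundle instead of a finished hidden state. A hedged sketch of how a per-stage driver might thread those bundles between ranks; recv_from_prev_rank and send_to_next_rank are hypothetical stand-ins for vLLM's distributed plumbing, which this diff does not show:

    # Hypothetical per-stage execution loop.
    intermediate = (None if get_pp_group().is_first_rank
                    else recv_from_prev_rank())
    output = model(input_ids, positions, kv_caches, attn_metadata,
                   intermediate_tensors=intermediate)
    if isinstance(output, IntermediateTensors):
        send_to_next_rank(output)  # hand activations to the next stage
    else:
        logits = model.compute_logits(output, sampling_metadata)  # last stage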
@@ -347,7 +358,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, intermediate_tensors)
         return hidden_states
 
     def compute_logits(self, hidden_states: torch.Tensor,
@@ -356,6 +367,20 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
     def sample(
         self,
         logits: Optional[torch.Tensor],
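make_empty_intermediate_tensors lets the runtime pre-allocate receive buffers of the right shape and dtype on non-first ranks before the upstream activations arrive; the two zero tensors mirror exactly the hidden_states/residual pair that forward ships downstream. A hedged usage sketch (the recv loop is an assumption about the runtime, not code from this diff):

    # Hypothetical receive path on a non-first pipeline rank.
    buffers = model.make_empty_intermediate_tensors(
        batch_size=num_input_tokens, dtype=torch.bfloat16, device="cuda")
    for key in ("hidden_states", "residual"):
        torch.distributed.recv(buffers[key], src=prev_rank)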
@@ -392,6 +417,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip layers on other devices.
+                if is_pp_missing_parameter(name, self):
+                    continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -402,6 +431,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param,
@@ -414,6 +446,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    # Skip layers on other devices.
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     # Remapping the name of FP8 kv-scale.
                     name = maybe_remap_kv_scale_name(name, params_dict)
                     if name is None:
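All three weight-loading paths (stacked projections, fused-MoE experts, and the fallback) gain the same guard: checkpoint parameters that belong to decoder layers outside this rank's [start_layer, end_layer) slice were never instantiated locally, so the loader skips them instead of hitting a KeyError in params_dict. A minimal sketch of what such a check can look like; the regex-based index extraction is an assumption, the real helper being is_pp_missing_parameter in vllm/model_executor/models/utils.py:

    import re

    def is_pp_missing_parameter_sketch(name: str, start_layer: int,
                                       end_layer: int) -> bool:
        # Assume per-layer parameters are named "model.layers.<idx>.<...>".
        match = re.search(r"\.layers\.(\d+)\.", name)
        if match is None:
            return False  # treat non-layer weights as local in this sketch
        return not start_layer <= int(match.group(1)) < end_layer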
