[Model] PP support for Mamba-like models #10992

Merged: 31 commits merged into main from the mamba_jamba_pp branch on Dec 11, 2024.
The diff below shows changes from 15 of the 31 commits.

Commits (31)
90d0ff3
Add PP support to Jamba
mzusman Dec 8, 2024
d30f741
GPU Runner adaptations
mzusman Dec 8, 2024
47e961b
Add tests and adaptations
mzusman Dec 8, 2024
e7e3ea4
Add PP to mamba
mzusman Dec 8, 2024
ebbf7e3
Introduce LayerBlockType
mzusman Dec 8, 2024
5514aa6
Fix jamba
mzusman Dec 8, 2024
f9fcff4
Fix MS for Mamba
mzusman Dec 8, 2024
2b5b525
Format
mzusman Dec 8, 2024
01c7d93
Format
mzusman Dec 8, 2024
50f2a8c
Format
mzusman Dec 8, 2024
314c9af
Remove comment
mzusman Dec 8, 2024
3347f3f
Add `is_pp_missing_parameter` to mamba weight loading
mzusman Dec 8, 2024
f21ff9a
Revert Mamba MS fix, will fix it in future PR
mzusman Dec 8, 2024
ff6e13e
Addressing review's comments
mzusman Dec 9, 2024
06ab64d
Add mamba/jamba PP support to supported_models page
mzusman Dec 9, 2024
46e3c25
Use only needed number layers per PP
mzusman Dec 9, 2024
f354468
Fix `get_num_layers_by_block_type` to always return num of layers
mzusman Dec 9, 2024
529ade6
Format
mzusman Dec 9, 2024
e7f6338
Add comment for workaround
mzusman Dec 9, 2024
6d863ae
Fix comment
mzusman Dec 9, 2024
d9ef20d
Add IsHybrid interface to return 0 layers on non-hybrid models asking
mzusman Dec 9, 2024
029d710
Typo
mzusman Dec 9, 2024
fbf4ccd
Merge remote-tracking branch 'github/main' into mamba_jamba_pp
mzusman Dec 9, 2024
3bc3823
Format
mzusman Dec 9, 2024
0f70131
Fix mamba logic
mzusman Dec 9, 2024
704f29a
Merge remote-tracking branch 'github/main' into mamba_jamba_pp
mzusman Dec 10, 2024
fcce742
Small fix to logic
mzusman Dec 10, 2024
e3d8343
FIx logic
mzusman Dec 10, 2024
6d6b3d4
Fixing according to comment
mzusman Dec 10, 2024
c034ffe
Format
mzusman Dec 10, 2024
75ff1ef
Fix comment
mzusman Dec 10, 2024
6 changes: 3 additions & 3 deletions docs/source/models/supported_models.rst
@@ -128,7 +128,7 @@ Text Generation
- FalconMamba
- :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc.
- ✅︎
-
- ✅︎
* - :code:`GemmaForCausalLM`
- Gemma
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
@@ -193,7 +193,7 @@ Text Generation
- Jamba
- :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc.
- ✅︎
-
- ✅︎
* - :code:`LlamaForCausalLM`
- Llama 3.1, Llama 3, Llama 2, LLaMA, Yi
- :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc.
@@ -203,7 +203,7 @@ Text Generation
- Mamba
- :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc.
-
-
- ✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
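With the PP column now checked for FalconMamba, Jamba and Mamba, these architectures can be sharded layer-wise across GPUs like the other decoder-only models on this page. A minimal usage sketch (illustration only, not part of this diff; `pipeline_parallel_size` is forwarded to the engine arguments, and the exact backend requirements depend on your vLLM version):

```python
# Hypothetical offline-inference sketch: run Jamba with 2 pipeline stages.
# Requires at least 2 GPUs and a vLLM build that includes this PR.
from vllm import LLM, SamplingParams

llm = LLM(
    model="ai21labs/Jamba-tiny-dev",
    pipeline_parallel_size=2,  # split the decoder layers across 2 GPUs
    tensor_parallel_size=1,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```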
8 changes: 5 additions & 3 deletions tests/distributed/test_pipeline_parallel.py
@@ -156,13 +156,13 @@ def iter_params(self, model_name: str):
# "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True),
# Uses Llama
# "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(),
"state-spaces/mamba-130m-hf": PPTestSettings.fast(),
"mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4),
"mosaicml/mpt-7b": PPTestSettings.fast(),
"nvidia/Minitron-8B-Base": PPTestSettings.fast(),
@@ -180,7 +180,7 @@ def iter_params(self, model_name: str):
"Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(),
"stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
"bigcode/starcoder2-3b": PPTestSettings.fast(),
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
"upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2)
# FIXME: Cannot load tokenizer in latest transformers version.
# Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
# "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
@@ -234,6 +234,8 @@ def iter_params(self, model_name: str):
"OpenGVLab/InternVL2-1B",
"microsoft/Phi-3-vision-128k-instruct",
"fixie-ai/ultravox-v0_3",
# [LANGUAGE GENERATION - HYBRID ARCH]
"ai21labs/Jamba-tiny-dev",
]


30 changes: 18 additions & 12 deletions vllm/config.py
@@ -27,8 +27,8 @@
ConfigFormat, get_config, get_hf_image_processor_config,
get_hf_text_config, get_pooling_config,
get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
print_warning_once, random_uuid,
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
get_cpu_memory, print_warning_once, random_uuid,
resolve_obj_by_qualname)

if TYPE_CHECKING:
@@ -680,26 +680,32 @@ def get_num_attention_heads(self,
num_heads = getattr(self.hf_text_config, "num_attention_heads", 0)
return num_heads // parallel_config.tensor_parallel_size

def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
def get_layers_start_end_indices(
self, parallel_config: "ParallelConfig") -> Tuple[int, int]:
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start
return start, end

def get_num_attention_layers(self,
parallel_config: "ParallelConfig") -> int:
if self.is_attention_free:
return 0
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
start, end = self.get_layers_start_end_indices(parallel_config)
return end - start

num_layers = self.get_num_layers(parallel_config)
def get_num_layers_by_block_type(
self,
parallel_config: "ParallelConfig",
block_type: LayerBlockType = LayerBlockType.attention,
) -> int:
start, end = self.get_layers_start_end_indices(parallel_config)

# Transformers supports layers_block_type @property
layers = getattr(self.hf_config, "layers_block_type",
["attention"] * num_layers)
return len([t for t in layers if t == "attention"])
layers_block_type_value = getattr(self.hf_config, "layers_block_type",
[block_type.value] * (end - start))
return sum(t == block_type.value
for t in layers_block_type_value[start:end])

def get_multimodal_config(self) -> "MultiModalConfig":
"""
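A small self-contained sketch of what the new `get_layers_start_end_indices` / `get_num_layers_by_block_type` pair computes. The even layer split below is a simplification (the real `get_pp_indices` also supports custom partitions), so treat it as an illustration rather than the actual helpers:

```python
from enum import Enum

class LayerBlockType(Enum):  # mirrors vllm.utils.LayerBlockType
    attention = "attention"
    mamba = "mamba"

def layers_start_end(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # Simplified even split; the last rank absorbs any remainder.
    per_rank = num_hidden_layers // pp_size
    start = pp_rank * per_rank
    end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end

def num_layers_by_block_type(layers_block_type, start, end,
                             block_type=LayerBlockType.attention):
    # Count only the layers of the requested type inside this rank's slice;
    # models without `layers_block_type` are treated as all-attention.
    return sum(t == block_type.value for t in layers_block_type[start:end])

# Example: an 8-layer Jamba-like stack (3 mamba layers per attention layer)
# split across 2 pipeline ranks.
layers = ["mamba", "mamba", "mamba", "attention",
          "mamba", "mamba", "mamba", "attention"]
start, end = layers_start_end(len(layers), pp_rank=0, pp_size=2)   # (0, 4)
assert num_layers_by_block_type(layers, start, end, LayerBlockType.mamba) == 3
assert num_layers_by_block_type(layers, start, end) == 1  # attention layers
```

This is what lets the Mamba cache manager and the KV cache allocation size themselves per rank instead of for the whole model.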
92 changes: 64 additions & 28 deletions vllm/model_executor/models/jamba.py
@@ -9,6 +9,7 @@
from vllm.attention.layer import Attention
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -25,9 +25,12 @@
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType

from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
from .interfaces import HasInnerState, SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -281,16 +281,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
org_num_embeddings=config.vocab_size,
)

decoder_layers = []
for i in range(config.num_hidden_layers):
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
decoder_layers.append(
layer_class(config,
layer_idx=i,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{i}"))
self.layers = nn.ModuleList(decoder_layers)
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
layer_class = ALL_DECODER_LAYER_TYPES[
config.layers_block_type[layer_idx]]
return layer_class(
config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=prefix,
)

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers")
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

self.final_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

@@ -304,26 +304,34 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
for i in range(len(self.layers)):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

kv_cache_index = 0
mamba_cache_index = 0
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
kv_cache = None
layer_mamba_cache_params = None
if isinstance(layer, JambaAttentionDecoderLayer):
kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
self.config.attn_layer_period]
kv_cache = kv_caches[kv_cache_index]
kv_cache_index += 1
if isinstance(layer, JambaMambaDecoderLayer):
current_state_layer = i - (1 +
(i - self.config.attn_layer_offset)
// self.config.attn_layer_period)
current_state_layer = mamba_cache_index
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
current_state_layer)
mamba_cache_index += 1

hidden_states, residual = layer(
positions=positions,
@@ -332,11 +352,16 @@ def forward(
attn_metadata=attn_metadata,
residual=residual,
mamba_cache_params=layer_mamba_cache_params)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
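One detail worth spelling out from the loop above: KV caches and Mamba caches are now assigned by counting within this rank's layers rather than by deriving positions from `attn_layer_offset` / `attn_layer_period`, because each pipeline rank only receives caches for the layers it owns. A toy illustration (hypothetical helper, not part of the diff):

```python
def assign_cache_slots(layer_types):
    # Walk this rank's layers in order and hand out per-type slot indices,
    # mirroring the kv_cache_index / mamba_cache_index counters above.
    slots, kv_i, mamba_i = [], 0, 0
    for t in layer_types:
        if t == "attention":
            slots.append(("kv", kv_i))
            kv_i += 1
        else:
            slots.append(("mamba", mamba_i))
            mamba_i += 1
    return slots

# A rank that owns [mamba, mamba, attention, mamba] gets one kv cache and
# three mamba cache entries, all indexed locally.
assert assign_cache_slots(["mamba", "mamba", "attention", "mamba"]) == [
    ("mamba", 0), ("mamba", 1), ("kv", 0), ("mamba", 2)]
```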


class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -368,6 +393,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

super().__init__()
self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
self.model = JambaModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@@ -390,6 +417,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config.vocab_size)
self.sampler = get_sampler()

self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

@@ -406,10 +436,8 @@ def forward(self,
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

layers_type = self.config.layers_block_type
num_mamba_layers = sum(
[layer_type == "mamba" for layer_type in layers_type])

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
*self._get_mamba_cache_shape())
@@ -423,7 +451,7 @@ def forward(self,
state_indices_tensor)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, mamba_cache_params,
inputs_embeds)
intermediate_tensors, inputs_embeds)
return hidden_states

def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
@@ -504,8 +532,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.

if name.endswith(".bias") and name not in params_dict:
continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -520,6 +552,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
if weight_name not in name:
continue

if is_pp_missing_parameter(name, self):
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
@@ -533,6 +567,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
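Finally, the `is_pp_missing_parameter` checks added to `load_weights` skip checkpoint tensors whose layers live on other ranks. Conceptually (an illustrative stand-in, not the actual `.utils` implementation, which inspects the model's placeholder layers):

```python
import re

def is_pp_missing_parameter_sketch(name: str, start_layer: int,
                                   end_layer: int) -> bool:
    # Weights that belong to a decoder layer outside this rank's
    # [start_layer, end_layer) slice should not be loaded here.
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
        return False  # embeddings / final norm / lm_head handled elsewhere
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)

# On a rank owning layers [0, 4):
assert not is_pp_missing_parameter_sketch(
    "model.layers.2.mamba.in_proj.weight", 0, 4)
assert is_pp_missing_parameter_sketch(
    "model.layers.6.self_attn.qkv_proj.weight", 0, 4)
```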