Skip to content

Commit

Permalink
[model][utils] add extract_layer_index utility function (#10599)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 24, 2024
1 parent eda2b35 commit c055747
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 51 deletions.
41 changes: 18 additions & 23 deletions vllm/model_executor/models/arctic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from vllm.transformers_utils.configs.arctic import ArcticConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand All @@ -44,15 +44,14 @@ class ArcticMLP(nn.Module):

def __init__(self,
config: ArcticConfig,
layer_id: int,
expert_id: int = -1,
is_residual_mlp: bool = False,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True):
reduce_results: bool = True,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.expert_id = expert_id
self.layer_id = layer_id

self.ffn_dim = config.intermediate_size if not is_residual_mlp \
else self.hidden_size
Expand Down Expand Up @@ -85,13 +84,14 @@ class ArcticMoE(nn.Module):

def __init__(self,
config: ArcticConfig,
layer_id: int,
tp_size: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True):
reduce_results: bool = True,
prefix: str = ""):
super().__init__()

layer_id = extract_layer_index(prefix)
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.hidden_size = config.hidden_size
self.num_experts = config.num_local_experts
Expand All @@ -109,15 +109,16 @@ def __init__(self,

if not self.is_moe_layer:
self.mlp = ArcticMLP(config,
layer_id=layer_id,
quant_config=quant_config,
reduce_results=reduce_results)
reduce_results=reduce_results,
prefix=f"{prefix}.mlp")
else:
self.gate = ReplicatedLinear(self.hidden_size,
self.num_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.gate")
if self.is_quant:
self.ws = DeepSpeedFPParameter(
torch.Size((self.num_experts, 2 * self.intermediate_size,
Expand Down Expand Up @@ -220,14 +221,12 @@ class ArcticAttention(nn.Module):
def __init__(
self,
config: ArcticConfig,
layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size

tp_size = get_tensor_model_parallel_world_size()
Expand Down Expand Up @@ -298,26 +297,25 @@ class ArcticDecoderLayer(nn.Module):
def __init__(
self,
config: ArcticConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
layer_idx = extract_layer_index(prefix)
is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
self.use_residual = config.use_residual and is_moe_layer
self.self_attn = ArcticAttention(config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = ArcticMoE(
config,
layer_id=layer_idx,
quant_config=quant_config,
reduce_results=(not self.use_residual))
reduce_results=(not self.use_residual),
prefix=f"{prefix}.block_sparse_moe",
)

self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
Expand All @@ -328,9 +326,9 @@ def __init__(
self.residual_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.residual_mlp = ArcticMLP(config,
layer_id=layer_idx,
is_residual_mlp=True,
reduce_results=False)
reduce_results=False,
prefix=f"{prefix}.residual_mlp")

def forward(
self,
Expand Down Expand Up @@ -384,11 +382,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
lambda prefix: ArcticDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand Down
19 changes: 11 additions & 8 deletions vllm/model_executor/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand All @@ -63,6 +63,7 @@ def __init__(
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
Expand Down Expand Up @@ -260,12 +262,12 @@ class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
Expand All @@ -285,13 +287,16 @@ def __init__(
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
self.mlp = DeepseekMoE(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
Expand Down Expand Up @@ -347,11 +352,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: DeepseekDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
Expand Down
15 changes: 5 additions & 10 deletions vllm/model_executor/models/gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand Down Expand Up @@ -85,7 +86,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class Gemma2Attention(nn.Module):

def __init__(self,
layer_idx: int,
config: Gemma2Config,
hidden_size: int,
num_heads: int,
Expand All @@ -98,7 +98,6 @@ def __init__(self,
attn_logits_soft_cap: Optional[float] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
Expand Down Expand Up @@ -145,6 +144,7 @@ def __init__(self,

# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix)
use_sliding_window = (layer_idx % 2 == 0 and
config.interleaved_sliding_window is not None)
sliding_window = config.interleaved_sliding_window if \
Expand Down Expand Up @@ -178,7 +178,6 @@ class Gemma2DecoderLayer(nn.Module):

def __init__(
self,
layer_idx: int,
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
Expand All @@ -187,7 +186,6 @@ def __init__(
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Gemma2Attention(
layer_idx=layer_idx,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
Expand Down Expand Up @@ -262,11 +260,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
config,
cache_config,
quant_config,
prefix=prefix),
lambda prefix: Gemma2DecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

Expand Down
8 changes: 2 additions & 6 deletions vllm/model_executor/models/olmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
Expand Down Expand Up @@ -264,11 +263,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
lambda prefix: OlmoeDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)

Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from vllm.utils import print_warning_once

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

Expand Down Expand Up @@ -244,7 +244,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
Expand All @@ -269,6 +268,7 @@ def __init__(

# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx not in mlp_only_layers) and (
Expand Down Expand Up @@ -337,8 +337,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen2MoeDecoderLayer(config=config,
layer_idx=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
Expand Down
21 changes: 21 additions & 0 deletions vllm/model_executor/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
return name if not prefix else f"{prefix}.{name}"


def extract_layer_index(layer_name: str) -> int:
"""
Extract the layer index from the module name.
Examples:
- "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError
"""
subnames = layer_name.split(".")
int_vals: List[int] = []
for subname in subnames:
try:
int_vals.append(int(subname))
except ValueError:
continue
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]

0 comments on commit c055747

Please sign in to comment.