[model][utils] add extract_layer_index utility function #10599

Merged 4 commits on Nov 24, 2024
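Across all five model files, the diff applies a single pattern: decoder layers no longer receive an explicit layer_idx / layer_id constructor argument; instead each module recovers its own index from its prefix string via the new extract_layer_index helper added to vllm/model_executor/models/utils.py. A minimal before/after sketch of that pattern, using a placeholder DecoderLayer class rather than code copied from the diff:

    # Before: the caller parses the layer index out of the prefix and passes it down.
    self.start_layer, self.end_layer, self.layers = make_layers(
        config.num_hidden_layers,
        lambda prefix: DecoderLayer(config, int(prefix.split(".")[-1]), prefix=prefix),
        prefix=f"{prefix}.layers")

    # After: the layer derives its own index from the prefix it is given.
    class DecoderLayer(nn.Module):
        def __init__(self, config, prefix: str = ""):
            super().__init__()
            layer_idx = extract_layer_index(prefix)  # e.g. "model.layers.3" -> 3

Deriving the index from the prefix keeps constructor signatures uniform across models, which is what lets make_layers build every layer from a single prefix-taking lambda.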
41 changes: 18 additions & 23 deletions vllm/model_executor/models/arctic.py
@@ -33,7 +33,7 @@
from vllm.transformers_utils.configs.arctic import ArcticConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

@@ -44,15 +44,14 @@ class ArcticMLP(nn.Module):

def __init__(self,
config: ArcticConfig,
layer_id: int,
expert_id: int = -1,
is_residual_mlp: bool = False,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True):
reduce_results: bool = True,
prefix: str = ""):
super().__init__()
self.hidden_size = config.hidden_size
self.expert_id = expert_id
self.layer_id = layer_id

self.ffn_dim = config.intermediate_size if not is_residual_mlp \
else self.hidden_size
@@ -85,13 +84,14 @@ class ArcticMoE(nn.Module):

def __init__(self,
config: ArcticConfig,
layer_id: int,
tp_size: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True):
reduce_results: bool = True,
prefix: str = ""):
super().__init__()

layer_id = extract_layer_index(prefix)
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
self.hidden_size = config.hidden_size
self.num_experts = config.num_local_experts
@@ -109,15 +109,16 @@ def __init__(self,

if not self.is_moe_layer:
self.mlp = ArcticMLP(config,
layer_id=layer_id,
quant_config=quant_config,
reduce_results=reduce_results)
reduce_results=reduce_results,
prefix=f"{prefix}.mlp")
else:
self.gate = ReplicatedLinear(self.hidden_size,
self.num_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.gate")
if self.is_quant:
self.ws = DeepSpeedFPParameter(
torch.Size((self.num_experts, 2 * self.intermediate_size,
@@ -220,14 +221,12 @@ class ArcticAttention(nn.Module):
def __init__(
self,
config: ArcticConfig,
layer_idx: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size

tp_size = get_tensor_model_parallel_world_size()
@@ -298,26 +297,25 @@ class ArcticDecoderLayer(nn.Module):
def __init__(
self,
config: ArcticConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
layer_idx = extract_layer_index(prefix)
is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
self.use_residual = config.use_residual and is_moe_layer
self.self_attn = ArcticAttention(config,
layer_idx,
cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.block_sparse_moe = ArcticMoE(
config,
layer_id=layer_idx,
quant_config=quant_config,
reduce_results=(not self.use_residual))
reduce_results=(not self.use_residual),
prefix=f"{prefix}.block_sparse_moe",
)

self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -328,9 +326,9 @@ def __init__(
self.residual_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.residual_mlp = ArcticMLP(config,
layer_id=layer_idx,
is_residual_mlp=True,
reduce_results=False)
reduce_results=False,
prefix=f"{prefix}.residual_mlp")

def forward(
self,
@@ -384,11 +382,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
org_num_embeddings=self.vocab_size)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
lambda prefix: ArcticDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
19 changes: 11 additions & 8 deletions vllm/model_executor/models/deepseek.py
@@ -49,7 +49,7 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

@@ -63,6 +63,7 @@ def __init__(
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -92,6 +93,7 @@ def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
@@ -260,12 +262,12 @@ class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@@ -285,13 +287,16 @@ def __init__(
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
self.mlp = DeepseekMoE(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -347,11 +352,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config=quant_config,
prefix=prefix),
lambda prefix: DeepseekDecoderLayer(
config, cache_config, quant_config=quant_config, prefix=prefix
),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
15 changes: 5 additions & 10 deletions vllm/model_executor/models/gemma2.py
@@ -42,7 +42,8 @@
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

@@ -85,7 +86,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class Gemma2Attention(nn.Module):

def __init__(self,
layer_idx: int,
config: Gemma2Config,
hidden_size: int,
num_heads: int,
@@ -98,7 +98,6 @@ def __init__(self,
attn_logits_soft_cap: Optional[float] = None,
prefix: str = "") -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -145,6 +144,7 @@ def __init__(self,

# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
layer_idx = extract_layer_index(prefix)
use_sliding_window = (layer_idx % 2 == 0 and
config.interleaved_sliding_window is not None)
sliding_window = config.interleaved_sliding_window if \
@@ -178,7 +178,6 @@ class Gemma2DecoderLayer(nn.Module):

def __init__(
self,
layer_idx: int,
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
@@ -187,7 +186,6 @@ def __init__(
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Gemma2Attention(
layer_idx=layer_idx,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@@ -262,11 +260,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[-1]),
config,
cache_config,
quant_config,
prefix=prefix),
lambda prefix: Gemma2DecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

8 changes: 2 additions & 6 deletions vllm/model_executor/models/olmoe.py
@@ -181,7 +181,6 @@ class OlmoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -264,11 +263,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: OlmoeDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config,
quant_config,
prefix=prefix),
lambda prefix: OlmoeDecoderLayer(
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=1e-5)

6 changes: 2 additions & 4 deletions vllm/model_executor/models/qwen2_moe.py
@@ -53,7 +53,7 @@
from vllm.utils import print_warning_once

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
from .utils import (extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

@@ -244,7 +244,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@@ -269,6 +268,7 @@ def __init__(

# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# `mlp_only_layers` in the config.
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
if (layer_idx not in mlp_only_layers) and (
@@ -337,8 +337,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen2MoeDecoderLayer(config=config,
layer_idx=int(
prefix.split(".")[-1]),
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
21 changes: 21 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str:
The string "prefix.name" if prefix was non-empty, otherwise just "name".
"""
return name if not prefix else f"{prefix}.{name}"


def extract_layer_index(layer_name: str) -> int:
"""
Extract the layer index from the module name.
Examples:
- "encoder.layers.0" -> 0
- "encoder.layers.1.self_attn" -> 1
- "2.self_attn" -> 2
- "model.encoder.layers.0.sub.1" -> ValueError
"""
subnames = layer_name.split(".")
int_vals: List[int] = []
for subname in subnames:
try:
int_vals.append(int(subname))
except ValueError:
continue
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]
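A quick usage sketch of the new helper (not part of the diff; behavior follows the docstring and the assertion above):

    from vllm.model_executor.models.utils import extract_layer_index

    assert extract_layer_index("encoder.layers.0") == 0
    assert extract_layer_index("model.layers.15.self_attn") == 15
    # Names with zero or more than one integer component trip the assertion:
    # extract_layer_index("model.encoder.layers.0.sub.1")  # raises AssertionError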