Commit
use FusedMoE layer
Signed-off-by: xffxff <[email protected]>
xffxff committed Nov 22, 2024
1 parent a34edd6 commit 06103f9
Showing 1 changed file with 152 additions and 103 deletions.
255 changes: 152 additions & 103 deletions vllm/model_executor/models/aria.py
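In short, the commit deletes the hand-rolled Experts module below (manual tensor-parallel sharding of w1/w2 plus a direct fused_moe call) and routes through vLLM's FusedMoE layer instead, keeping only the router projection in MoELayer. A minimal sketch of the new wiring, mirroring the diff (the class name MoESketch and the bare config/quant_config arguments are illustrative; the shared-expert branch is omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.layers.fused_moe import FusedMoE


class MoESketch(nn.Module):
    """Sketch of the new routing path: router logits -> FusedMoE."""

    def __init__(self, config, quant_config=None):
        super().__init__()
        # The router stays a plain parameter; FusedMoE now owns the expert
        # weights, top-k selection, and tensor-parallel sharding.
        self.router_weight = nn.Parameter(
            torch.empty(config.moe_num_experts, config.hidden_size))
        self.experts = FusedMoE(
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            quant_config=quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits = F.linear(hidden_states, self.router_weight)
        return self.experts(hidden_states, router_logits)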
@@ -53,6 +53,14 @@
from vllm.transformers_utils.configs.aria import AriaVisionConfig, AriaMoELMConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)


class AriaVisionTransformer(Idefics2VisionTransformer):
@@ -162,7 +170,9 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
super().__init__()
self.num_heads = num_heads
self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=False)
self.kv_proj = MergedColumnParallelLinear(kv_dim, [embed_dim, embed_dim], bias=False)
self.kv_proj = MergedColumnParallelLinear(kv_dim,
[embed_dim, embed_dim],
bias=False)

self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
self.linear = RowParallelLinear(embed_dim, embed_dim)
@@ -290,8 +300,9 @@ def forward(self, x, attn_mask=None):
out = self.ffn(self.ln_ffn(attention_out))

return out

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".kv_proj", ".k_proj", 0),
@@ -318,100 +329,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
return loaded_params



class Experts(nn.Module):

def __init__(self, config: AriaMoELMConfig):
super().__init__()
self.config = config

self.router_weight = nn.Parameter(
torch.empty(
(self.config.moe_num_experts, self.config.hidden_size)))

self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
if self.tp_size > config.moe_num_experts:
raise ValueError(
f"Tensor model parallel size {self.tp_size} is greater than the number of experts {config.moe_num_experts}"
)

self.w1 = nn.Parameter(
torch.empty((
config.moe_num_experts,
config.moe_intermediate_size * 2 // self.tp_size,
config.hidden_size,
)))
self.w2 = nn.Parameter(
torch.empty((
config.moe_num_experts,
config.hidden_size,
config.moe_intermediate_size // self.tp_size,
)))
set_weight_attrs(self.router_weight,
{"weight_loader": self._weight_loader_for_router})
set_weight_attrs(self.w1,
{"weight_loader": self._weight_loader_for_w1})
set_weight_attrs(self.w2,
{"weight_loader": self._weight_loader_for_w2})

def _weight_loader_for_router(self, param: nn.Parameter,
loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)

def _weight_loader_for_w1(self, param: nn.Parameter,
loaded_weight: torch.Tensor):
# the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size)
if self.tp_size > 1:
up, gate = loaded_weight.chunk(2, dim=-1)
up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank]
gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank]
up_and_gate = torch.cat([up_current_rank, gate_current_rank],
dim=-1).transpose(1, 2)
param.data.copy_(up_and_gate)
else:
param.data.copy_(loaded_weight.transpose(1, 2))

def _weight_loader_for_w2(self, param: nn.Parameter,
loaded_weight: torch.Tensor):
# the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size)
if self.tp_size > 1:
down_current_rank = loaded_weight.chunk(self.tp_size,
dim=1)[self.tp_rank]
param.data.copy_(down_current_rank.transpose(1, 2))
else:
param.data.copy_(loaded_weight.transpose(1, 2))

def forward(self, hidden_states):
router_output = torch.nn.functional.linear(hidden_states,
self.router_weight)

def custom_routing_function(hidden_states, router_output, topk,
renormalize):
top_logits, top_indices = torch.topk(router_output,
k=self.config.moe_topk,
dim=1)
scores = torch.softmax(top_logits, dim=-1, dtype=torch.float32)
return scores, top_indices.to(torch.int32)

hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
final_hidden_states = fused_moe(
hidden_states,
self.w1,
self.w2,
router_output,
self.config.moe_topk,
False,
inplace=True,
custom_routing_function=custom_routing_function,
)
final_hidden_states = final_hidden_states.view(hidden_states_shape)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states


class MoELayer(nn.Module):
"""
Mixture of Experts (MoE) Layer for the AriaMoE model.
@@ -430,7 +347,17 @@ def __init__(
super().__init__()
self.config = config

self.experts = Experts(config)
self.router_weight = nn.Parameter(
torch.empty(
(self.config.moe_num_experts, self.config.hidden_size)))

self.experts = FusedMoE(
num_experts=config.moe_num_experts,
top_k=config.moe_topk,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
)
self.shared_experts = LlamaMLP(
config.hidden_size,
config.moe_intermediate_size * config.moe_num_shared_experts,
@@ -449,8 +376,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
torch.Tensor: Output tensor after passing through the MoE layer.
"""

router_output = torch.nn.functional.linear(hidden_states,
self.router_weight)

shared_expert_output = self.shared_experts(hidden_states)
sparse_expert_output = self.experts(hidden_states)
sparse_expert_output = self.experts(hidden_states, router_output)

return sparse_expert_output + shared_expert_output
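
For a rough sense of the shapes involved (hypothetical token count; layer stands for a constructed MoELayer instance):

import torch

x = torch.randn(8, layer.config.hidden_size)                 # (num_tokens, hidden_size)
logits = torch.nn.functional.linear(x, layer.router_weight)  # (num_tokens, moe_num_experts)
sparse = layer.experts(x, logits)                            # (num_tokens, hidden_size)
out = sparse + layer.shared_experts(x)                       # same shape as x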

@@ -535,6 +465,125 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=f"{prefix}.layers",
)

# Adapted from FusedMoE.make_expert_params_mapping, modified to match the
# "experts.experts.{expert_id}" prefix that the checkpoint uses for the
# expert weight names.
def _make_expert_params_mapping(
self, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, str]]:

return [
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.experts.{expert_id}.{weight_name}.", expert_id, shard_id
) for expert_id in range(num_experts)
for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]
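
For illustration, with the checkpoint projection names passed in below (gate_proj, down_proj, up_proj), the first three entries this helper yields, for expert 0, are:

# (param_name,     weight_name,                    expert_id, shard_id)
# ("experts.w13_", "experts.experts.0.gate_proj.", 0,         "w1")
# ("experts.w2_",  "experts.experts.0.down_proj.", 0,         "w2")
# ("experts.w13_", "experts.experts.0.up_proj.",   0,         "w3")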

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = self._make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.moe_num_experts)

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# The checkpoint contains mlp.experts.experts[0].gate_proj.
# Expert weights are handled below via expert_params_mapping, so we
# must skip them here BEFORE renaming; otherwise the name would become
# mlp.experts.experts[0].gate_up_proj, which expert_params_mapping
# would then turn into mlp.experts.experts[0].gate_gate_up_proj and
# the load would fail.
if (("mlp.experts.experts." in name)
and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

if is_pp_missing_parameter(name, self):
continue

param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
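
To make the branching above concrete, a sketch of how one hypothetical expert weight name flows through load_weights (the exact prefix depends on the remapping that AutoWeightsLoader applies further down):

# name = "layers.0.mlp.experts.experts.3.up_proj.weight"   (hypothetical)
# 1. stacked_params_mapping matches ".up_proj", but the
#    "mlp.experts.experts." guard skips it before any renaming happens.
# 2. expert_params_mapping matches "experts.experts.3.up_proj.", so the name
#    becomes "layers.0.mlp.experts.w13_weight".
# 3. That parameter's FusedMoE weight_loader is then called with
#    shard_id="w3" and expert_id=3.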


def build_mm_projector(config):
"""
Expand Down Expand Up @@ -617,7 +666,8 @@ def input_processor(ctx, llm_inputs):
max_image_size = multi_modal_data.pop("max_image_size", 980)
_split_image = multi_modal_data.pop("split_image", False)

assert isinstance(max_image_size, (int, float)), "max_image_size should be float or int"
assert isinstance(max_image_size,
(int, float)), "max_image_size should be float or int"
images = (multi_modal_data["image"] if isinstance(
multi_modal_data["image"], list) else [multi_modal_data["image"]])

@@ -682,6 +732,7 @@ def __init__(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
self.pad_token_id = (self.config.pad_token_id
if self.config.pad_token_id is not None else -1)
self.unpadded_vocab_size = config.text_config.vocab_size
@@ -760,9 +811,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"language_model.lm_head": "lm_head",
},
orig_to_new_suffix={
"experts.fc1.weight": "experts.w1",
"experts.fc2.weight": "experts.w2",
"router.weight": "experts.router_weight",
"router.weight": "router_weight",
},
)

