From d1e34487f8c87d455b6f5cc808d1877c4612b08d Mon Sep 17 00:00:00 2001
From: Divakar Verma
Date: Mon, 16 Sep 2024 19:20:05 +0000
Subject: [PATCH] yapf re-format

---
 vllm/model_executor/models/dbrx.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index bb9f8c46ed88d..397a46a486f72 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -53,6 +53,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class DbrxExperts(FusedMoE):
+
     def __init__(
         self,
         config: DbrxConfig,
@@ -106,6 +107,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         ).transpose(1, 2)
         param_data[:] = loaded_weight[:, :, shard]
 
+
 class DbrxMoE(nn.Module):
     """A tensor-parallel MoE implementation for DBRX.
 
@@ -128,10 +130,9 @@ def __init__(
 
         self.router = DbrxRouter(config, self.params_dtype)
 
-        self.experts = DbrxExperts(
-            config=config,
-            quant_config=quant_config,
-            params_dtype=self.params_dtype)
+        self.experts = DbrxExperts(config=config,
+                                   quant_config=quant_config,
+                                   params_dtype=self.params_dtype)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -389,11 +390,11 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        
+
         expert_params_mapping = [(
             "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
             f"mlp.{weight_name}",
-            ) for weight_name in ["w1", "v1", "w2"]]
+        ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
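---
Reviewer note: the hunks above are pure yapf formatting churn (blank-line and
indentation fixes) with no behavioral change. Below is a minimal sketch of
reproducing the re-format locally through yapf's Python API, assuming yapf is
installed and picks up the repository's style configuration; it is illustrative
only, not vLLM's actual format script.

    # Illustrative reproduction, not part of the commit.
    from yapf.yapflib.yapf_api import FormatFile

    # Re-format the file in place; `changed` reports whether anything moved.
    # With in_place=True the first return value (the formatted source) is None.
    _, _encoding, changed = FormatFile("vllm/model_executor/models/dbrx.py",
                                       in_place=True)
    print("reformatted" if changed else "already clean")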