Commit d1e3448: yapf re-format

divakar-amd committed Sep 16, 2024
1 parent: 85e6084
Showing 1 changed file with 7 additions and 6 deletions.

vllm/model_executor/models/dbrx.py (7 additions, 6 deletions)
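For context, yapf is the formatter that produced the whitespace-only rewrites in the hunks below. Here is a minimal sketch of running yapf programmatically over the call this commit re-wraps, assuming only that yapf is installed; the "pep8" style string is a stand-in assumption, since the repository pins its own yapf style config:

# Minimal sketch: run yapf over the snippet this commit re-wraps.
# Assumption: style_config="pep8" stands in for the project's pinned style.
from yapf.yapflib.yapf_api import FormatCode

source = (
    "experts = DbrxExperts(\n"
    "    config=config,\n"
    "    quant_config=quant_config,\n"
    "    params_dtype=params_dtype)\n"
)

# FormatCode returns a (formatted_source, changed) pair.
formatted, changed = FormatCode(source, style_config="pep8")
print(changed)    # True when yapf rewrote the input
print(formatted)  # yapf's canonical wrapping of the call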
@@ -53,6 +53,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class DbrxExperts(FusedMoE):
+
     def __init__(
         self,
         config: DbrxConfig,
@@ -106,6 +107,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
         ).transpose(1, 2)
         param_data[:] = loaded_weight[:, :, shard]
 
+
 class DbrxMoE(nn.Module):
     """A tensor-parallel MoE implementation for DBRX.
@@ -128,10 +130,9 @@ def __init__(
 
         self.router = DbrxRouter(config, self.params_dtype)
 
-        self.experts = DbrxExperts(
-            config=config,
-            quant_config=quant_config,
-            params_dtype=self.params_dtype)
+        self.experts = DbrxExperts(config=config,
+                                   quant_config=quant_config,
+                                   params_dtype=self.params_dtype)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -389,11 +390,11 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-
         expert_params_mapping = [(
             "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
             f"mlp.{weight_name}",
-            ) for weight_name in ["w1", "v1", "w2"]]
+        ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
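As a quick sanity check on the load_weights hunk above, the comprehension whose closing bracket is re-indented there is a plain Python expression; evaluated in isolation (no vLLM imports needed) it yields the three (param name, checkpoint suffix) pairs the loader matches against:

# The expert_params_mapping comprehension from the diff, evaluated on its own.
expert_params_mapping = [(
    "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
    f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]

print(expert_params_mapping)
# [('w13_weight', 'mlp.w1'), ('w13_weight', 'mlp.v1'), ('w2_weight', 'mlp.w2')]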
