
Commit

Update modeling_mixtral.py
tthakkal authored Aug 16, 2024
1 parent a40667e commit 6d06161
Showing 1 changed file with 2 additions and 14 deletions.
optimum/habana/transformers/models/mixtral/modeling_mixtral.py (16 changes: 2 additions & 14 deletions)
@@ -471,7 +471,6 @@ def forward(
         reuse_cache: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -481,10 +480,7 @@ def forward(
         - add new args reuse_cache
         - add new args flash_attention_recompute
         - add new args cache_idx
-        - add new args lazy_mode
         """
-        if lazy_mode:
-            htcore.mark_step()
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)
@@ -504,16 +500,12 @@
             cache_idx=cache_idx,
         )
         hidden_states = residual + hidden_states
-        if lazy_mode:
-            htcore.mark_step()

         # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states, router_logits = self.block_sparse_moe(hidden_states)
         hidden_states = residual + hidden_states
-        if lazy_mode:
-            htcore.mark_step()

         outputs = (hidden_states,)

@@ -554,7 +546,6 @@ def forward(
         reuse_cache: Optional[bool] = False,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         """
         Copied from MixtralModel.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py#L1069
@@ -608,7 +599,6 @@ def forward(

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
@@ -618,6 +608,8 @@
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

+
+
         if self.config._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
@@ -678,7 +670,6 @@ def forward(
                     reuse_cache=reuse_cache,
                     flash_attention_recompute=flash_attention_recompute,
                     cache_idx=cache_idx,
-                    lazy_mode=lazy_mode,
                 )

             hidden_states = layer_outputs[0]
@@ -753,7 +744,6 @@ def forward(
         reuse_cache: Optional[bool] = None,
         flash_attention_recompute: Optional[bool] = False,
         cache_idx: int = None,
-        lazy_mode: Optional[bool] = True,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -782,7 +772,6 @@ def forward(
             reuse_cache=reuse_cache,
             flash_attention_recompute=flash_attention_recompute,
             cache_idx=cache_idx,
-            lazy_mode=lazy_mode,
         )

         hidden_states = outputs[0]
@@ -887,7 +876,6 @@ def prepare_inputs_for_generation(
                 "reuse_cache": reuse_cache,
                 "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                 "cache_idx": kwargs.get("cache_idx"),
-                "lazy_mode": kwargs.get("lazy_mode"),
             }
         )
         return model_inputs

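For context on the call being removed: htcore.mark_step() comes from habana_frameworks.torch.core and, in Intel Gaudi (HPU) lazy mode, flushes the operations recorded so far and executes them as a graph. The snippet below is a minimal illustrative sketch of the per-layer pattern this commit deletes; it is not code from the repository, and run_block is a hypothetical helper.

# Minimal sketch of the removed pattern (illustrative only).
# Assumes an HPU device with habana_frameworks installed; otherwise the import fails.
import habana_frameworks.torch.core as htcore  # provides mark_step()

def run_block(block, hidden_states, lazy_mode=True):
    # Before this commit, the Gaudi Mixtral decoder layer gated explicit
    # graph breaks like this around its attention and MoE sub-blocks.
    if lazy_mode:
        htcore.mark_step()
    hidden_states = block(hidden_states)
    if lazy_mode:
        htcore.mark_step()
    return hidden_states

After this change the decoder layer no longer accepts a lazy_mode argument and no longer inserts these per-layer graph breaks, and prepare_inputs_for_generation stops forwarding "lazy_mode" from kwargs.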