added flash_attention_recompute arg to provide an option to enable/disable
Local Lab User committed Aug 21, 2024
1 parent 3e7ff03 commit c26393d
Showing 5 changed files with 42 additions and 11 deletions.
14 changes: 10 additions & 4 deletions examples/image-to-text/README.md
@@ -145,7 +145,8 @@ python3 run_pipeline.py \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention
--use_flash_attention \
--flash_attention_recompute
```


@@ -156,7 +157,8 @@ python3 run_pipeline.py \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 \
--use_flash_attention
--use_flash_attention \
--flash_attention_recompute
```


@@ -168,7 +170,9 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 --use_flash_attention
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```

Here is an example of quantizing the model based on previous measurements for Llava-v1.6-mistral-7b:
@@ -177,5 +181,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python run_pipeline.py \
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf \
--image_path "https://llava-vl.github.io/static/images/view.jpg" \
--use_hpu_graphs \
--bf16 --use_flash_attention
--bf16 \
--use_flash_attention \
--flash_attention_recompute
```
6 changes: 6 additions & 0 deletions examples/image-to-text/run_pipeline.py
@@ -96,6 +96,11 @@ def main():
action="store_true",
help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
parser.add_argument(
"--flash_attention_recompute",
action="store_true",
help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.",
)

args = parser.parse_args()

@@ -156,6 +161,7 @@ def main():
"max_new_tokens": args.max_new_tokens,
"ignore_eos": args.ignore_eos,
"use_flash_attention": args.use_flash_attention,
"flash_attention_recompute": args.flash_attention_recompute,
}
if args.use_hpu_graphs:
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
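For orientation, here is a minimal, self-contained sketch of how the two `store_true` flags travel from the command line into the kwargs dict that `run_pipeline.py` forwards to generation. The parsed argv below is illustrative only, and the pipeline construction itself is omitted.

```python
import argparse

# Both options default to False, so existing invocations keep their behavior.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_flash_attention",
    action="store_true",
    help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)
parser.add_argument(
    "--flash_attention_recompute",
    action="store_true",
    help="Enable Flash Attention recompute mode on first-token generation to reduce memory use.",
)

# Illustrative argv; in run_pipeline.py this comes from the real command line.
args = parser.parse_args(["--use_flash_attention", "--flash_attention_recompute"])

# The flags are forwarded verbatim as generation kwargs.
generate_kwargs = {
    "use_flash_attention": args.use_flash_attention,
    "flash_attention_recompute": args.flash_attention_recompute,
}
print(generate_kwargs)  # {'use_flash_attention': True, 'flash_attention_recompute': True}
```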
13 changes: 10 additions & 3 deletions optimum/habana/transformers/models/clip/modeling_clip.py
@@ -78,6 +78,7 @@ def forward(
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
Copied from CLIPAttention.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
@@ -100,8 +101,7 @@ def forward(
if FusedSDPA and use_flash_attention:
import habana_frameworks.torch.hpu as ht

use_recompute = not self.training
with ht.sdp_kernel(enable_recompute=use_recompute):
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, self.dropout, False, 1, "fast"
)
@@ -178,6 +178,7 @@ def forward(
causal_attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Copied from CLIPEncoderLayer.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
@@ -193,6 +194,7 @@ def forward(
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
hidden_states = residual + hidden_states

@@ -219,6 +221,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutput]:
"""
Copied from CLIPEncoder.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
@@ -245,7 +248,6 @@ def forward(
attention_mask,
causal_attention_mask,
output_attentions,
use_flash_attention=use_flash_attention,
)
else:
layer_outputs = encoder_layer(
@@ -254,6 +256,7 @@ def forward(
causal_attention_mask,
output_attentions=output_attentions,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

hidden_states = layer_outputs[0]
@@ -279,6 +282,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
"""
Copied from CLIPVisionTransformer.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
@@ -303,6 +307,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

last_hidden_state = encoder_outputs[0]
@@ -328,6 +333,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
"""
Copied from CLIPVisionModel.forward: https://github.com/huggingface/transformers/blob/ab0f050b42d903f34d6eb97f3f8c0c07f0517ad2/src/transformers/models/clip/modeling_clip.py
@@ -342,4 +348,5 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
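The substantive change in the CLIP attention above is that recompute mode is no longer derived from `self.training`; the caller now controls it through `flash_attention_recompute`. Below is a minimal sketch of that pattern, assuming a Gaudi environment with `habana_frameworks` installed and `FusedSDPA` importable from `habana_frameworks.torch.hpex.kernels` (the import path and argument order mirror optimum-habana; this is an illustration, not the module's actual method).

```python
import habana_frameworks.torch.hpu as ht
from habana_frameworks.torch.hpex.kernels import FusedSDPA


def fused_clip_attention(query, key, value, attention_mask, dropout_p,
                         flash_attention_recompute=False):
    # Previously: enable_recompute = not self.training
    # Now the flag is threaded down from generate() through the vision tower.
    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
        # Same positional arguments as the call in the diff above:
        # attn_mask, dropout, is_causal=False, scale=1, softmax_mode="fast"
        return FusedSDPA.apply(query, key, value, attention_mask,
                               dropout_p, False, 1, "fast")
```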
10 changes: 8 additions & 2 deletions optimum/habana/transformers/models/llava/modeling_llava.py
@@ -124,6 +124,7 @@ def forward(
image_offset: Optional[int] = None,
tokens_pos: Optional[torch.LongTensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llava/modeling_llava.py
@@ -154,7 +155,10 @@ def forward(
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1:
image_outputs = self.vision_tower(
pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
pixel_values,
output_hidden_states=True,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
# this is not memory efficient at all (output_hidden_states=True); it will save all the hidden states.
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
@@ -184,7 +188,7 @@ def forward(
return_dict=return_dict,
token_idx=token_idx + image_offset,
use_flash_attention=use_flash_attention,
flash_attention_recompute=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

if input_ids.shape[1] != 1 and pixel_values is not None:
@@ -296,6 +300,7 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}
use_flash_attention = kwargs.get("use_flash_attention", False)
flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
model_inputs.update(
{
"position_ids": position_ids,
Expand All @@ -307,6 +312,7 @@ def prepare_inputs_for_generation(
"image_offset": image_offset,
"tokens_pos": tokens_pos,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)

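Both `prepare_inputs_for_generation` overrides (here and in the Llava-Next counterpart below) read the new flag from `kwargs` with a `False` default, so the feature is strictly opt-in: callers that never pass it keep the previous behavior. A toy sketch of that pattern (not the real method signature) is shown below.

```python
def _collect_flash_attention_kwargs(**kwargs):
    # Mirrors the kwargs.get(..., False) pattern from the diff: missing
    # flags silently fall back to the old (disabled) behavior.
    return {
        "use_flash_attention": kwargs.get("use_flash_attention", False),
        "flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
    }


assert _collect_flash_attention_kwargs() == {
    "use_flash_attention": False,
    "flash_attention_recompute": False,
}
assert _collect_flash_attention_kwargs(flash_attention_recompute=True) == {
    "use_flash_attention": False,
    "flash_attention_recompute": True,
}
```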
10 changes: 8 additions & 2 deletions optimum/habana/transformers/models/llava_next/modeling_llava_next.py
@@ -55,6 +55,7 @@ def forward(
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L433
@@ -83,7 +84,7 @@ def forward(
return_dict=return_dict,
token_idx=token_idx + self.image_offset,
use_flash_attention=use_flash_attention,
flash_attention_recompute=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

if inputs_embeds.shape[1] != 1 and pixel_values is not None:
@@ -248,6 +249,7 @@ def prepare_inputs_for_generation(
)
else:
use_flash_attention = kwargs.get("use_flash_attention", False)
flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
position_ids = kwargs.get("position_ids", None)
labels = kwargs.get("labels", None)
if past_key_values is None and pixel_values is not None and input_ids.shape[1] != 1:
@@ -268,7 +270,10 @@ def prepare_inputs_for_generation(
batch_size, num_patches, num_channels, height, width = pixel_values.shape
reshaped_pixel_values = pixel_values.view(batch_size * num_patches, num_channels, height, width)
image_features = self.vision_tower(
reshaped_pixel_values, output_hidden_states=True, use_flash_attention=use_flash_attention
reshaped_pixel_values,
output_hidden_states=True,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)

selected_image_feature = image_features.hidden_states[vision_feature_layer]
@@ -390,6 +395,7 @@ def prepare_inputs_for_generation(
"image_sizes": image_sizes,
"labels": labels,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)

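Putting it together, once optimum-habana's Gaudi overrides are applied, `flash_attention_recompute` is just another generation kwarg. The sketch below loosely mirrors `run_pipeline.py`, assuming a Gaudi device and an optimum-habana install; the real script adds HPU graph wrapping and optional quantization, so treat this as an illustration rather than a drop-in replacement.

```python
import torch
from transformers import pipeline
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

# Patch transformers with the Gaudi-specific forwards shown in this commit.
adapt_transformers_to_gaudi()

generator = pipeline(
    "image-to-text",
    model="llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.bfloat16,
    device="hpu",
)

generate_kwargs = {
    "max_new_tokens": 100,
    "use_flash_attention": True,        # --use_flash_attention
    "flash_attention_recompute": True,  # --flash_attention_recompute
}

output = generator(
    "https://llava-vl.github.io/static/images/view.jpg",
    generate_kwargs=generate_kwargs,
)
print(output)
```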
