Transformers 4.44 support #1996

Merged: 8 commits, merged on Sep 2, 2024

Changes from all commits
218 changes: 146 additions & 72 deletions optimum/bettertransformer/models/attention.py
@@ -15,6 +15,9 @@
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from ...utils import check_if_transformers_greater


# TODO (CRITICAL): Layer-wise attention scaling is broken for several archs.
@@ -23,7 +26,7 @@
def raise_on_head_mask(head_mask: Optional[torch.Tensor]):
    if head_mask is not None:
        raise ValueError(
            "layer_head_mask (or head_mask) different than None is unsupported for now with BetterTransformer, please"
            " open a PR or an issue at https://github.com/huggingface/optimum."
        )

@@ -534,88 +537,159 @@ def bart_forward(
        return attn_output, None, past_key_value


if check_if_transformers_greater("4.44"):
    from transformers.cache_utils import Cache
    from transformers.models.bloom.modeling_bloom import dropout_add

    # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
    def bloom_forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        raise_on_head_mask(head_mask)

        if output_attentions is True:
            raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

        batch_size, q_length, _ = hidden_states.shape
        # [batch_size, seq_length, 3 x hidden_size]
        fused_qkv = self.query_key_value(hidden_states)
        # 3 x [batch_size, num_heads, seq_length, head_dim]
        query_layer, key_layer, value_layer = self._reshape(fused_qkv)

        if layer_past is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)

        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])

        if attention_mask is not None:  # no matter the length, we just slice it
            kv_length = cache_position[-1] + 1  # cache position is 0-indexed while length should start from 1
            causal_mask = attention_mask[:, :, :, :kv_length]
            alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=alibi,
            dropout_p=self.dropout_prob_attn if self.training else 0.0,
        )

        # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
        context_layer = context_layer.transpose(1, 2)
        context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

        outputs = (output_tensor, layer_past)

        return outputs

else:
    # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward
    def bloom_forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        **kwargs,
    ):
        raise_on_head_mask(head_mask)

        if output_attentions is True:
            raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

        # [batch_size, seq_length, 3 x hidden_size]
        fused_qkv = self.query_key_value(hidden_states)

        # 3 x [batch_size, seq_length, num_heads, head_dim]
        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

        batch_size, q_length, _, _ = query_layer.shape

        # Permute to [batch_size, num_heads, seq_length, head_dim]
        query_layer = query_layer.transpose(1, 2)

        if layer_past is not None:
            past_key, past_value = layer_past
            past_key = past_key.transpose(1, 2)

            key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
            value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)

            # concatenate along seq_length dimension
            key_layer = torch.cat((past_key, key_layer), dim=1)
            value_layer = torch.cat((past_value, value_layer), dim=1)

            # untangle batch_size from self.num_heads
            key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:])
            value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:])
        else:
            key_layer = key_layer.transpose(1, 2)
            value_layer = value_layer.transpose(1, 2)

        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
        alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=alibi,
            dropout_p=self.dropout_prob_attn if self.training else 0.0,
        )

        # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim]
        context_layer = context_layer.transpose(1, 2)
        context_layer = context_layer.reshape(*context_layer.shape[:2], -1)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + torch.nn.functional.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training)
        output_tensor = residual + output_tensor

        if use_cache is True:
            present = (
                key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2),
                value_layer.reshape(-1, *value_layer.shape[2:]),
            )
        else:
            present = None

        return (output_tensor, present)
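The two bloom_forward definitions above are picked once, at import time, based on the installed transformers version. A minimal sketch of this kind of version gate, assuming only the packaging and transformers packages are available (the file itself uses optimum's check_if_transformers_greater helper rather than this hypothetical function):

import transformers
from packaging import version

def transformers_at_least(target: str) -> bool:
    # Compare the installed transformers version against a target such as "4.44".
    return version.parse(transformers.__version__) >= version.parse(target)

if transformers_at_least("4.44"):
    # New-style path: Bloom attention works with Cache objects and cache_position.
    from transformers.cache_utils import Cache
else:
    # Legacy path: past key/values are plain tuples of tensors.
    Cache = None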
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
@@ -216,6 +216,8 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
        self.dropout_prob_attn = config.attention_dropout

        self.module_mapping = None
        self.layer_idx = getattr(layer, "layer_idx", None)

        submodules = ["query_key_value", "dense", "attention_dropout"]
        for attr in submodules:
            setattr(self, attr, getattr(layer, attr))
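The stored layer_idx mirrors the attribute transformers 4.44 keeps on BloomAttention: the Cache object used in the new bloom_forward needs it to know which per-layer buffer to update. A minimal sketch of that Cache.update call, assuming transformers >= 4.44 and DynamicCache as the concrete cache class (shapes are made up):

import torch
from transformers.cache_utils import DynamicCache

batch_size, num_heads, seq_len, head_dim = 1, 4, 3, 8
key = torch.randn(batch_size, num_heads, seq_len, head_dim)
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

cache = DynamicCache()
layer_idx = 0  # the index stored on the layer, as above
key_out, value_out = cache.update(key, value, layer_idx)
print(key_out.shape)  # torch.Size([1, 4, 3, 8]); grows along dim 2 on later calls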
38 changes: 21 additions & 17 deletions optimum/exporters/onnx/model_configs.py
@@ -338,27 +338,31 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
    ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
    DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
    DEFAULT_ONNX_OPSET = 14  # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        if check_if_transformers_greater("4.44"):
            super().add_past_key_values(inputs_or_outputs, direction)
        else:
            if direction not in ["inputs", "outputs"]:
                raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

            if direction == "inputs":
                decoder_sequence_name = "past_sequence_length"
                name = "past_key_values"
            else:
                decoder_sequence_name = "past_sequence_length + 1"
                name = "present"

            for i in range(self._normalized_config.num_layers):
                inputs_or_outputs[f"{name}.{i}.key"] = {
                    0: "batch_size x num_heads",
                    2: decoder_sequence_name,
                }
                inputs_or_outputs[f"{name}.{i}.value"] = {
                    0: "batch_size x num_heads",
                    1: decoder_sequence_name,
                }


class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
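The else branch of BloomOnnxConfig.add_past_key_values keeps Bloom's legacy ONNX cache layout, where batch and head dimensions are flattened into one axis and keys are stored transposed relative to values; from transformers 4.44 on, the standard layout from the parent config is used instead. A rough illustration of the two layouts with made-up sizes, not taken from any real export:

import torch

batch_size, num_heads, head_dim, past_len = 2, 4, 8, 5

# Legacy Bloom layout ("batch_size x num_heads" axis); keys are transposed w.r.t. values.
legacy_key = torch.zeros(batch_size * num_heads, head_dim, past_len)
legacy_value = torch.zeros(batch_size * num_heads, past_len, head_dim)

# Standard 4D layout described by the parent config from 4.44 on.
standard_key = torch.zeros(batch_size, num_heads, past_len, head_dim)
standard_value = torch.zeros(batch_size, num_heads, past_len, head_dim)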
21 changes: 13 additions & 8 deletions optimum/onnxruntime/modeling_decoder.py
@@ -336,8 +336,7 @@ def prepare_past_key_values(
        dtype = constructor.float16 if self.use_fp16 else constructor.float32

        # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
        if self.__class__.__name__ == "ORTBloomForCausalLM":
            shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
            shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0)
            key = constructor.zeros(shape_key, dtype=dtype)
@@ -354,9 +353,9 @@
            for name, value in zip(self.key_value_output_names, past_key_values):
                shape = [*value.shape]
                index = 1 if "value" in name else 2
                shape[index] += sequence_length
                pkv_output_shape[name] = shape

        elif self.model_type == "gpt_bigcode":
            # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor.
            shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
@@ -371,9 +370,9 @@
                shape = [*value.shape]
                shape[1] += sequence_length
                pkv_output_shape[name] = shape

        else:
            num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads
            shape = (batch_size, num_key_value_heads, 0, embed_size_per_head)
            key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -534,9 +533,9 @@ def _from_pretrained(

        # Since https://github.com/huggingface/optimum/pull/871/
        # changed axis notation/naming during export, we need to update the dims
        for input_name in input_dims.keys():
            if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length":
                input_dims[input_name][2] = "past_sequence_length"
                override_dims = True

        if override_dims:
@@ -559,6 +558,12 @@
                size_threshold=0,
            )

        # Since transformers 4.44, the bloom model has been updated to use the standard cache format
        use_old_bloom_modeling = not check_if_transformers_greater("4.44")
        for input_name in input_dims.keys():
            if input_dims[input_name][0] == "batch_size x num_heads":
                use_old_bloom_modeling = True

        del onnx_model

        model = ORTModel.load_model(
@@ -568,7 +573,7 @@
            provider_options=provider_options,
        )

        if config.model_type == "bloom" and use_old_bloom_modeling:
            init_cls = ORTBloomForCausalLM
        elif config.model_type == "falcon":
            init_cls = ORTFalconForCausalLM
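The use_old_bloom_modeling flag is set when either the installed transformers is older than 4.44 or the exported graph still uses the old axis naming, which labels the first cache axis "batch_size x num_heads"; only then does loading dispatch to ORTBloomForCausalLM. An illustrative check against a hypothetical input_dims mapping (the real mapping is read from the ONNX model; the integer stands in for the static head dimension):

legacy_input_dims = {
    "past_key_values.0.key": ["batch_size x num_heads", 64, "past_sequence_length"],
    "past_key_values.0.value": ["batch_size x num_heads", "past_sequence_length", 64],
}

use_old_bloom_modeling = any(dims[0] == "batch_size x num_heads" for dims in legacy_input_dims.values())
print(use_old_bloom_modeling)  # True -> dispatch to the legacy ORTBloomForCausalLM path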