Add ONNX and ORT support for Falcon #1391

Merged · 7 commits · Oct 18, 2023
Changes from 1 commit
1 change: 1 addition & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -41,6 +41,7 @@ Supported architectures:
- Donut-Swin
- Electra
- Encoder Decoder
- Falcon
- Flaubert
- GPT-2
- GPT-BigCode
69 changes: 65 additions & 4 deletions optimum/exporters/onnx/model_configs.py
@@ -34,6 +34,7 @@
DummyVisionEmbeddingsGenerator,
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
MultiQueryPastKeyValuesGenerator,
NormalizedConfig,
NormalizedEncoderDecoderConfig,
NormalizedSeq2SeqConfig,
@@ -54,7 +55,7 @@
TextSeq2SeqOnnxConfig,
VisionOnnxConfig,
)
from .model_patcher import SAMModelPatcher, WavLMModelPatcher
from .model_patcher import FalconModelPatcher, SAMModelPatcher, WavLMModelPatcher


if TYPE_CHECKING:
@@ -234,9 +235,6 @@ class BloomOnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
"""
Refer to OnnxConfigWithPast in base.py
"""
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

@@ -287,6 +285,69 @@ def flatten_past_key_values(self, flattened_output, name, idx, t):
flattened_output[f"{name}.{idx}.key_value"] = t


class FalconOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
MultiQueryPastKeyValuesGenerator,
) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES
DEFAULT_ONNX_OPSET = 14 # Falcon uses aten::triu that requires opset>=14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_PKV_GENERATOR_CLASS = MultiQueryPastKeyValuesGenerator

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
)
# For some reason, Falcon config.num_kv_heads cannot be trusted; see modeling_falcon.py in transformers
self._normalized_config.num_kv_heads = (
self._normalized_config.num_kv_heads
if (self._normalized_config.new_decoder_architecture or not self._normalized_config.multi_query)
else 1
)

# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention, which is not supported by the ONNX export
[Review comment (Member) on lines +391 to +392]: nit: I would move this comment inside the method.
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return FalconModelPatcher(self, model, model_kwargs=model_kwargs)

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}
inputs_or_outputs[f"{name}.{i}.value"] = {
0: "batch_size x num_heads",
1: decoder_sequence_name,
}
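
For reference, the fused axis above makes Falcon's multi-query cache tensors 3-D rather than the usual 4-D [batch, num_heads, seq, head_dim] layout. A minimal sketch of a matching past key/value pair (the concrete shapes are an assumption for illustration, not taken from this diff):

import torch

# Axis 0 is the fused "batch_size x num_heads" dynamic axis declared above;
# axis 1 is the past sequence length. With multi_query=True, num_kv_heads is 1.
batch_size, num_kv_heads, past_len, head_dim = 2, 1, 16, 64
past_key = torch.randn(batch_size * num_kv_heads, past_len, head_dim)
past_value = torch.randn(batch_size * num_kv_heads, past_len, head_dim)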


class T5DummySeq2SeqPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
encoder_shape = (
201 changes: 200 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
@@ -15,8 +15,11 @@
import dataclasses
import functools
import inspect
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

import transformers
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon.modeling_falcon import build_alibi_tensor
from transformers.utils import is_torch_available


@@ -223,6 +226,202 @@ def patched_forward(*args, **kwargs):
self.patched_forward = patched_forward


def _make_causal_mask_falcon_patched(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
target_length, target_length+past_key_values_length]`.
"""
batch_size, target_length = input_ids_shape

# NOTE: ONNX Runtime is not able to run ONNX Trilu node with bool input. As a workaround, we pass a float input
# and cast to bool here. Reference: https://github.com/microsoft/onnxruntime/issues/16189
mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float, device=device), diagonal=1).to(
torch.bool
)

# If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
# This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
# way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
mask = torch.cat([past_mask, mask], dim=-1)
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
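
To see the workaround in isolation, a standalone sketch (not part of the diff): build the upper-triangular mask in float and only then cast to bool, so the exported Trilu node stays float-typed while the eager result matches the all-bool version.

import torch

target_length = 4
# Float triu, then cast: ONNX Runtime cannot execute a Trilu node on bool inputs.
mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float), diagonal=1).to(torch.bool)
# Same result as computing triu directly on a bool tensor in eager mode:
reference = torch.ones((target_length, target_length), dtype=torch.bool).triu(diagonal=1)
assert torch.equal(mask, reference)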


def falcon_model_forward_without_kv_reformatting(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

if past_key_values is None:
past_key_values = tuple([None] * len(self.h))

# NOTE: here we removed the _convert_to_rw_cache call

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

hidden_states = inputs_embeds

presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None

# Compute alibi tensor: check build_alibi_tensor documentation
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)

if self.use_alibi:
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
# NOTE: here we use expand(batch_size, -1) instead of transformers' view(-1, seq_length), which is bugged
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
else:
position_ids = position_ids.view(-1, seq_length).long()

causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)

hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

# Add last hidden state
hidden_states = self.ln_f(hidden_states)

if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

# NOTE: here we removed the _convert_cache_to_standard_format call

if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
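
The two removed conversion calls (_convert_to_rw_cache and _convert_cache_to_standard_format in transformers) normally translate between the standard per-head cache layout and Falcon's fused internal layout. A small illustration of that relationship, with assumed shapes for clarity:

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16
standard = torch.randn(batch_size, num_heads, seq_len, head_dim)
# Fused layout the model works on internally; keeping it end to end avoids
# extra reshape/transpose nodes in the exported ONNX graph.
fused = standard.reshape(batch_size * num_heads, seq_len, head_dim)
assert fused.shape == (batch_size * num_heads, seq_len, head_dim)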


class FalconModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

# This is kind of ugly and bug-prone if other FalconModel instances are instantiated.
if config.task == "text-generation":
model.transformer.__class__.forward = falcon_model_forward_without_kv_reformatting

self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask
transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched
self._model = model

self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call"
self.orig_forward = getattr(self._model, self.orig_forward_name)

allow_past_in_outputs = hasattr(self.real_config, "use_past") and self.real_config.use_past

@functools.wraps(self.orig_forward)
def patched_forward(*args, **kwargs):
model_kwargs = self.model_kwargs
# setting output_attentions=True in the model input to avoid calling torch.nn.functional.scaled_dot_product_attention
# in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/falcon/modeling_falcon.py#L425
model_kwargs["output_attentions"] = True
signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=model_kwargs)

outputs = self.orig_forward(*args, **kwargs)

filtered_outputs = {}
for name, value in outputs.items():
onnx_output_name = config.torch_to_onnx_output_map.get(name, name)
if (
onnx_output_name in config.outputs
or (allow_past_in_outputs and name.startswith("past_key_values"))
or any(key.startswith(onnx_output_name) for key in config.outputs.keys())
):
filtered_outputs[name] = value
return filtered_outputs

self.patched_forward = patched_forward

def __exit__(self, exc_type, exc_value, traceback):
self.restore_ops()
transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal
setattr(self._model, self.orig_forward_name, self.orig_forward)
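
For context, a ModelPatcher is applied as a context manager during export, so these monkey-patches are only active while the graph is traced. A rough sketch (the export call and surrounding objects are assumptions, not shown in this diff):

patcher = onnx_config.patch_model_for_export(model)  # a FalconModelPatcher here
with patcher:
    # While active, model.forward is patched_forward and transformers'
    # _make_causal_mask is the ONNX-friendly version defined above.
    run_onnx_export(model)  # hypothetical export call, for illustration only
# On __exit__, the original forward and _make_causal_mask are restored.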


class WavLMModelPatcher(ModelPatcher):
def __init__(
self,
9 changes: 9 additions & 0 deletions optimum/exporters/tasks.py
@@ -509,6 +509,15 @@ class TasksManager:
"text2text-generation-with-past",
onnx="EncoderDecoderOnnxConfig",
),
"falcon": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"question-answering",
"text-generation",
"text-generation-with-past",
"token-classification",
onnx="FalconOnnxConfig",
),
"flaubert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
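
With this merged, Falcon can be exported and run through ONNX Runtime like the other supported decoders. A minimal end-to-end sketch (the checkpoint name is illustrative; the export and use_cache arguments follow the usual optimum API):

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "tiiuae/falcon-7b"  # illustrative checkpoint
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("ONNX Runtime is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(generated)[0])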