hopefully working ort inference
fxmarty committed Oct 17, 2023
1 parent 244a985 commit 14e2ad8
Showing 8 changed files with 114 additions and 22 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/onnx/config.py
@@ -163,7 +163,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# Decoders based on GPT2 require a position_ids input to avoid
# generating wrong position_ids in the model itself:
# https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802
if not self.no_position_ids and self.task == "text-generation":
if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]:
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs
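Note on the hunk above: a minimal, self-contained sketch (hypothetical helper name, not the optimum code) of the dynamic-axes mapping that the decoder inputs property ends up returning once position_ids is also exported for the feature-extraction task:

def decoder_dynamic_inputs(task: str, no_position_ids: bool = False) -> dict:
    # Base decoder inputs; both axes are dynamic in the exported ONNX graph.
    common_inputs = {
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
    }
    # Mirrors the patched condition above: position_ids is now exported for
    # feature extraction as well as text generation.
    if not no_position_ids and task in ["text-generation", "feature-extraction"]:
        common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
    return common_inputs

print(decoder_dynamic_inputs("feature-extraction"))  # includes position_ids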
20 changes: 19 additions & 1 deletion optimum/exporters/onnx/model_configs.py
@@ -353,6 +353,7 @@ def __init__(
use_past: bool = False,
use_past_in_inputs: bool = False,
preprocessors: Optional[List[Any]] = None,
no_position_ids: bool = False,
):
super().__init__(
config=config,
@@ -362,14 +363,31 @@
use_past=use_past,
use_past_in_inputs=use_past_in_inputs,
preprocessors=preprocessors,
no_position_ids=no_position_ids,
)
# For some reason Falcon config.num_kv_heads can not be trusted, see modeling_falcon.py in transformers
# For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers:
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337
self._normalized_config.num_kv_heads = (
self._normalized_config.num_kv_heads
if (self._normalized_config.new_decoder_architecture or not self._normalized_config.multi_query)
else 1
)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs

if (
not self.no_position_ids
and not self._config.alibi
and self.task in ["text-generation", "feature-extraction"]
):
# When alibi is used, position_ids are not used in Falcon.
# Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}

return common_inputs

# we need to set output_attentions=True in the model input to avoid calling
# torch.nn.functional.scaled_dot_product_attention that is not supported by the ONNX export
def patch_model_for_export(
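Note on the num_kv_heads override above: a standalone sketch (hypothetical function name) of how the effective number of key/value heads is resolved for Falcon, since config.num_kv_heads cannot be trusted for multi-query checkpoints that use the legacy decoder architecture:

def effective_falcon_num_kv_heads(num_kv_heads: int, new_decoder_architecture: bool, multi_query: bool) -> int:
    # Trust config.num_kv_heads only for the new decoder architecture or for
    # non-multi-query checkpoints; otherwise a single KV head is used.
    if new_decoder_architecture or not multi_query:
        return num_kv_heads
    return 1

print(effective_falcon_num_kv_heads(8, new_decoder_architecture=False, multi_query=True))  # 1
print(effective_falcon_num_kv_heads(8, new_decoder_architecture=True, multi_query=True))   # 8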
11 changes: 10 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
@@ -19,10 +19,11 @@

import transformers
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.falcon.modeling_falcon import build_alibi_tensor
from transformers.models.falcon.modeling_falcon import FalconModel, build_alibi_tensor
from transformers.utils import is_torch_available

from ...utils.modeling_utils import (
_falcon_prepare_attn_mask,
_prepare_attn_mask,
_prepare_decoder_attention_mask,
_prepare_decoder_sliding_window_attention_mask,
@@ -390,7 +391,15 @@ def __init__(
model.transformer.__class__.forward = falcon_model_forward_without_kv_reformatting

self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask

transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched

# In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length.
if isinstance(model, FalconModel):
model._prepare_attn_mask = _falcon_prepare_attn_mask
else:
model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask

self._model = model

self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call"
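Note on the patching above: it follows the usual save/patch/restore pattern of the export patchers. A generic, hypothetical sketch of that pattern (illustrative names, not the optimum API):

import contextlib

@contextlib.contextmanager
def patched_attribute(owner, name, replacement):
    # Save the original attribute, install the replacement for the duration of
    # the export, and restore the original afterwards, even on error.
    original = getattr(owner, name)
    setattr(owner, name, replacement)
    try:
        yield
    finally:
        setattr(owner, name, original)

class DummyModel:
    def _prepare_attn_mask(self):
        return "original"

with patched_attribute(DummyModel, "_prepare_attn_mask", lambda self: "patched"):
    print(DummyModel()._prepare_attn_mask())  # patched
print(DummyModel()._prepare_attn_mask())  # original (restored)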
1 change: 1 addition & 0 deletions optimum/exporters/onnx/utils.py
@@ -68,6 +68,7 @@

MODEL_TYPES_REQUIRING_POSITION_IDS = {
"codegen",
"falcon",
"gpt2",
"gpt-bigcode",
"gpt-neo",
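Note on the registry addition above: listing falcon in MODEL_TYPES_REQUIRING_POSITION_IDS means position_ids must be fed at inference time. The sketch below (hypothetical helper, not the optimum code) shows the standard way such ids are derived from the attention mask so that padding tokens are skipped:

import torch

MODEL_TYPES_REQUIRING_POSITION_IDS = {"codegen", "falcon", "gpt2", "gpt-bigcode", "gpt-neo"}

def maybe_build_position_ids(model_type: str, attention_mask: torch.Tensor):
    if model_type not in MODEL_TYPES_REQUIRING_POSITION_IDS:
        return None
    # Cumulative sum of the attention mask gives 0-based positions that skip
    # left padding; padded positions get a dummy value.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
print(maybe_build_position_ids("falcon", mask))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])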
31 changes: 19 additions & 12 deletions optimum/onnxruntime/modeling_decoder.py
@@ -265,7 +265,6 @@ def forward(

if "loss" in self.output_names:
loss = output_buffers["loss"].view(output_shapes["loss"])

else:
inputs["input_ids"] = input_ids.cpu().detach().numpy() if use_torch else input_ids

@@ -337,8 +336,9 @@ def prepare_past_key_values(
else:
num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // self.normalized_config.num_attention_heads

dtype = constructor.float16 if self.use_fp16 else constructor.float32
# TODO: find a way to better handle this controlflow
# TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY.
# "1" is the dummy sequence length
if self.config.model_type == "bloom":
shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head)
@@ -353,14 +353,23 @@
past_key_values = tuple(
key_or_value for _ in range(len(self.key_value_input_names) // 2) for key_or_value in [key, value]
)
elif self.config.model_type in MULTI_QUERY_ATTN_MODELS:
elif self.config.model_type == "gpt_bigcode":
# GPT BigCode uses multi-query attention and stores both key and value in the same cache tensor.
shape_key_and_value = (batch_size, 0, embed_size_per_head * 2)
key_and_value = constructor.zeros(shape_key_and_value, dtype=dtype)

if use_torch:
key_and_value = key_and_value.to(self.device)

past_key_values = tuple(key_and_value for _ in range(len(self.key_value_input_names)))
elif self.config.model_type == "falcon":
shape = (batch_size * self.num_key_value_heads, 0, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)

if use_torch:
key_or_value = key_or_value.to(self.device)

past_key_values = tuple(key_or_value for _ in range(len(self.key_value_input_names)))
else:
shape = (batch_size, num_attention_heads, 0, embed_size_per_head)
key_or_value = constructor.zeros(shape, dtype=dtype)
@@ -729,30 +738,28 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
class ORTFalconForCausalLM(ORTModelForCausalLM):
def __init__(
self,
decoder_session: onnxruntime.InferenceSession,
model: onnxruntime.InferenceSession,
config: "PretrainedConfig",
onnx_paths: List[str],
decoder_with_past_session: Optional[onnxruntime.InferenceSession] = None,
use_cache: bool = True,
use_io_binding: Optional[bool] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
preprocessors: Optional[List] = None,
generation_config: Optional[GenerationConfig] = None,
use_cache: Optional[bool] = None,
**kwargs,
):
super().__init__(
decoder_session=decoder_session,
model=model,
config=config,
onnx_paths=onnx_paths,
decoder_with_past_session=decoder_with_past_session,
use_cache=use_cache,
use_io_binding=use_io_binding,
model_save_dir=model_save_dir,
preprocessors=preprocessors,
generation_config=generation_config,
use_cache=use_cache,
**kwargs,
)
# self.num_kv_heads = config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1
self.num_key_value_heads = (
config.num_kv_heads if (config.new_decoder_architecture or not config.multi_query) else 1
)

# Copied from https://github.com/huggingface/transformers/pull/26199
def _reorder_cache(
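Note on prepare_past_key_values above: a standalone sketch of the empty past key/value tensors built by the branches shown in the hunk (example dimensions; dtype is float16 when the session runs in fp16, and num_key_value_heads is the value resolved in ORTFalconForCausalLM.__init__, i.e. 1 for multi-query Falcon checkpoints):

import torch

batch_size, num_attention_heads, num_key_value_heads, head_dim = 2, 8, 1, 64
dtype = torch.float32

# Default layout: one (batch, heads, 0, head_dim) tensor per key and per value.
default_key_or_value = torch.zeros(batch_size, num_attention_heads, 0, head_dim, dtype=dtype)

# GPT BigCode: multi-query attention, with key and value packed in one tensor.
gpt_bigcode_key_and_value = torch.zeros(batch_size, 0, head_dim * 2, dtype=dtype)

# Falcon: KV heads are folded into the batch dimension.
falcon_key_or_value = torch.zeros(batch_size * num_key_value_heads, 0, head_dim, dtype=dtype)

print(default_key_or_value.shape, gpt_bigcode_key_and_value.shape, falcon_key_or_value.shape)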
38 changes: 38 additions & 0 deletions optimum/utils/modeling_utils.py
@@ -178,3 +178,41 @@ def _prepare_decoder_sliding_window_attention_mask(
)

return combined_attention_mask


def _falcon_prepare_attn_mask(
attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
from transformers.models.falcon.modeling_falcon import (
_expand_mask,
)

# NOTE: there is no "copied from" for falcon in transformers which makes no sense to me.

# Create a causal mask
# The attention mask we receive as input should cover the whole extended sequence, including any past
# cache, so its shape should be [batch_size, seq_length + past_key_values_length]
# The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
raise ValueError(
"Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
f" {past_key_values_length}."
)
combined_attention_mask = None
device = attention_mask.device
_, seq_length = input_shape

# if seq_length > 1:
# NOTE: we remove here the `if seq_length > 1` to allow to use a single decoder.
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)

# [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)

return combined_attention_mask
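Note on _falcon_prepare_attn_mask above: it ORs a causal mask with the expanded padding mask, and now always builds the causal part since the "if seq_length > 1" guard was removed. A self-contained approximation in plain torch (not the transformers implementation; True means "masked"):

import torch

def falcon_style_attention_mask(attention_mask: torch.Tensor, seq_length: int, past_length: int) -> torch.BoolTensor:
    batch_size = attention_mask.shape[0]
    total_length = seq_length + past_length
    # Causal part: query i sits at absolute position past_length + i and may
    # only attend to absolute positions j <= past_length + i.
    rows = torch.arange(seq_length)[:, None]
    cols = torch.arange(total_length)[None, :]
    causal = (cols > rows + past_length)[None, None].expand(batch_size, 1, seq_length, total_length)
    # Padding part: True wherever the incoming attention_mask says "do not attend".
    padding = (attention_mask == 0)[:, None, None, :].expand(batch_size, 1, seq_length, total_length)
    return causal | padding

mask = falcon_style_attention_mask(torch.tensor([[0, 1, 1, 1]]), seq_length=1, past_length=3)
print(mask)  # shape [1, 1, 1, 4]; only the padded first position is masked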
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -268,5 +268,6 @@ def check_supported_model(cls, model_type: str):

@classmethod
def get_normalized_config_class(cls, model_type: str) -> Type:
model_type = model_type.replace("_", "-")
cls.check_supported_model(model_type)
return cls._conf[model_type]
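Note on the one-line change above: it makes the registry lookup tolerant to underscored model types. A tiny sketch of the behaviour (the registry content here is a placeholder):

_conf = {"gpt-bigcode": "NormalizedConfigForGPTBigCode"}  # placeholder registry

def get_normalized_config_class(model_type: str):
    # config.model_type may use underscores ("gpt_bigcode") while registry keys
    # use dashes, so normalize before the membership check.
    model_type = model_type.replace("_", "-")
    if model_type not in _conf:
        raise KeyError(f"{model_type} model type is not supported")
    return _conf[model_type]

print(get_normalized_config_class("gpt_bigcode"))  # resolved after normalization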
32 changes: 25 additions & 7 deletions tests/onnxruntime/test_modeling.py
@@ -2057,14 +2057,17 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1])
onnx_outputs = onnx_model(**tokens, position_ids=position_ids)

self.assertTrue("logits" in onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, torch.Tensor)

with torch.no_grad():
transformers_outputs = transformers_model(**tokens)

self.assertTrue("logits" in onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, torch.Tensor)

# compare tensor outputs
self.assertTrue(torch.allclose(onnx_outputs.logits, transformers_outputs.logits, atol=1e-4))
self.assertTrue(
torch.allclose(onnx_outputs.logits, transformers_outputs.logits, atol=1e-4),
f"Maxdiff: {(onnx_outputs.logits - transformers_outputs.logits).abs()}",
)

# Compare batched generation.
tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -2075,11 +2078,25 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
onnx_model.config.eos_token_id = None
transformers_model.config.eos_token_id = None

new_tokens = 30
if model_arch == "falcon":
# TODO: remove once https://github.com/huggingface/transformers/pull/26873 is released; falcon is broken in transformers
new_tokens = 5
onnx_outputs = onnx_model.generate(
**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30, eos_token_id=None
**tokens,
num_beams=1,
do_sample=False,
min_new_tokens=new_tokens,
max_new_tokens=new_tokens,
eos_token_id=None,
)
transformers_outputs = transformers_model.generate(
**tokens, num_beams=1, do_sample=False, min_new_tokens=30, max_new_tokens=30, eos_token_id=None
**tokens,
num_beams=1,
do_sample=False,
min_new_tokens=new_tokens,
max_new_tokens=new_tokens,
eos_token_id=None,
)

self.assertTrue(torch.allclose(onnx_outputs, transformers_outputs))
@@ -2257,12 +2274,13 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode
text = "My Name is Philipp and i live"
tokens = tokenizer(text, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)
model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"]

model_not_merged = ORTModelForCausalLM.from_pretrained(model_not_merged_dir)
not_merged_onnx_path = Path(model_not_merged_dir, ONNX_WEIGHTS_NAME)
self.assertFalse(has_onnx_input(not_merged_onnx_path, "use_cache_branch"))
self.assertFalse(model_not_merged.use_merged)

model_merged_dir = Path(model_not_merged_dir) / "merged"
model_merged_dir = Path(Path(model_not_merged_dir).parents[0], "merged")
task = model_not_merged.export_feature
if use_cache:
task += "-with-past"
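Finally, an end-to-end usage sketch of what this commit targets (the checkpoint id is a placeholder; this assumes the public ORTModelForCausalLM API with on-the-fly ONNX export):

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

model_id = "path/or/hub-id-of-a-falcon-checkpoint"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the checkpoint to ONNX on the fly; use_cache=True keeps
# the past key/value inputs that this commit wires up for Falcon.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("My name is Philipp and I live", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=5, do_sample=False)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))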
