Commit 8883323: fix

echarlaix committed Oct 9, 2023
1 parent 52e0c69 commit 8883323
Showing 2 changed files with 30 additions and 19 deletions.
47 changes: 29 additions & 18 deletions optimum/exporters/onnx/model_patcher.py
@@ -358,10 +358,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(config, model, model_kwargs)
-
         self.patch = self.real_config.task == "text-generation" and self.real_config.use_past
-        if self.patch:
-            self._orig_func = getattr(self._model_to_patch, self._orig_func_name)

     def __enter__(self):
         super().__enter__()
@@ -381,10 +378,12 @@ def __init__(
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        self._model_to_patch = model.transformer
-        self._patch_func = _prepare_attn_mask
-        self._orig_func_name = "_prepare_attn_mask"
         super().__init__(config, model, model_kwargs)
+        if self.patch:
+            self._model_to_patch = model.transformer
+            self._patch_func = _prepare_attn_mask
+            self._orig_func_name = "_prepare_attn_mask"
+            self._orig_func = self._model_to_patch._prepare_attn_mask


 class OPTModelPatcher(CausalAttentionMaskModelPatcher):
@@ -394,11 +393,14 @@ def __init__(
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        self._model_to_patch = model.model.decoder
-        self._patch_func = _prepare_decoder_attention_mask
-        self._orig_func_name = "_prepare_decoder_attention_mask"
         super().__init__(config, model, model_kwargs)
+
+        if self.patch:
+            self._model_to_patch = model.model.decoder
+            self._patch_func = _prepare_decoder_attention_mask
+            self._orig_func_name = "_prepare_decoder_attention_mask"
+            self._orig_func = self._model_to_patch._prepare_decoder_attention_mask


 class LlamaModelPatcher(CausalAttentionMaskModelPatcher):
     def __init__(
Expand All @@ -407,11 +409,14 @@ def __init__(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
self._model_to_patch = model.model
self._patch_func = _prepare_decoder_attention_mask
self._orig_func_name = "_prepare_decoder_attention_mask"
super().__init__(config, model, model_kwargs)

if self.patch:
self._model_to_patch = model.model
self._patch_func = _prepare_decoder_attention_mask
self._orig_func_name = "_prepare_decoder_attention_mask"
self._orig_func = self._model_to_patch._prepare_decoder_attention_mask


class MistralModelPatcher(CausalAttentionMaskModelPatcher):
def __init__(
@@ -420,11 +425,14 @@ def __init__(
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        self._model_to_patch = model.model
-        self._patch_func = _prepare_decoder_sliding_window_attention_mask
-        self._orig_func_name = "_prepare_decoder_attention_mask"
         super().__init__(config, model, model_kwargs)
+
+        if self.patch:
+            self._model_to_patch = model.model
+            self._patch_func = _prepare_decoder_sliding_window_attention_mask
+            self._orig_func_name = "_prepare_decoder_attention_mask"
+            self._orig_func = self._model_to_patch._prepare_decoder_attention_mask


 class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher):
     def __init__(
@@ -433,7 +441,10 @@ def __init__(
         model: Union["PreTrainedModel", "TFPreTrainedModel"],
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
-        self._model_to_patch = model.model.decoder
-        self._patch_func = _prepare_decoder_attention_mask
-        self._orig_func_name = "_prepare_decoder_attention_mask"
         super().__init__(config, model, model_kwargs)
+
+        if self.patch:
+            self._model_to_patch = model.model.decoder
+            self._patch_func = _prepare_decoder_attention_mask
+            self._orig_func_name = "_prepare_decoder_attention_mask"
+            self._orig_func = self._model_to_patch._prepare_decoder_attention_mask
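
Taken together, the change in this file moves every patch-target assignment (_model_to_patch, _patch_func, _orig_func_name) and the capture of the original function (_orig_func) to after the call to super().__init__, guarded by self.patch, so these attributes are only set when the export actually needs the attention-mask patch. Below is a minimal, self-contained sketch of this kind of patch/restore context manager; it only illustrates the pattern, under hypothetical names (ToyPatcher, ToyDecoder, toy_prepare_mask), and is not the optimum implementation.

# Illustration of the patch/restore pattern: attributes are assigned only when
# patching is enabled, the original method is saved, swapped on __enter__, and
# restored on __exit__. ToyPatcher, ToyDecoder and toy_prepare_mask are
# hypothetical stand-ins, not part of optimum.
import types


def toy_prepare_mask(self, *args, **kwargs):
    # Replacement installed while the patcher is active.
    return "patched"


class ToyPatcher:
    def __init__(self, module, patch: bool):
        self.patch = patch
        if self.patch:
            self._model_to_patch = module
            self._patch_func = toy_prepare_mask
            self._orig_func_name = "prepare_mask"
            self._orig_func = getattr(module, self._orig_func_name)

    def __enter__(self):
        if self.patch:
            # Bind the replacement as a method on the patched module.
            setattr(
                self._model_to_patch,
                self._orig_func_name,
                types.MethodType(self._patch_func, self._model_to_patch),
            )
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.patch:
            # Put the saved original back.
            setattr(self._model_to_patch, self._orig_func_name, self._orig_func)


class ToyDecoder:
    def prepare_mask(self):
        return "original"


decoder = ToyDecoder()
with ToyPatcher(decoder, patch=True):
    assert decoder.prepare_mask() == "patched"
assert decoder.prepare_mask() == "original"

In the sketch, capturing _orig_func up front is what makes the restore on exit possible; the diff applies the same idea by saving the original function alongside the other patch attributes inside the if self.patch: block.
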
2 changes: 1 addition & 1 deletion optimum/utils/modeling_utils.py
@@ -154,7 +154,7 @@ def _prepare_decoder_sliding_window_attention_mask(
     past_key_values_length: int,
     sliding_window: int,
 ):
-    from transformers.models.mistral.modeling_mistral import _make_sliding_window_causal_mask, _expand_mask
+    from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask

     # create causal mask
     # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
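
For context, the two helpers whose import order changes above are the transformers Mistral utilities used right after this point to build the additive attention mask, combining a causal, window-limited part with an expanded padding part in the [bsz, 1, tgt_seq_len, src_seq_len] layout noted in the comment. The snippet below is a rough, self-contained illustration of what a sliding-window causal mask looks like in plain PyTorch; it deliberately ignores past key values and padding, does not call the private transformers helpers, and the exact window-boundary convention of those helpers may differ by one.

# Illustration only: an additive causal mask where position i may attend to
# position j only if j <= i and i - j < sliding_window. Disallowed positions
# get the most negative finite value so they vanish after softmax.
import torch


def sliding_window_causal_mask(tgt_len: int, sliding_window: int, dtype=torch.float32):
    i = torch.arange(tgt_len).unsqueeze(-1)  # query positions, shape [tgt_len, 1]
    j = torch.arange(tgt_len).unsqueeze(0)   # key positions, shape [1, tgt_len]
    allowed = (j <= i) & (i - j < sliding_window)
    mask = torch.zeros(tgt_len, tgt_len, dtype=dtype)
    mask = mask.masked_fill(~allowed, torch.finfo(dtype).min)
    # Broadcastable to [bsz, 1, tgt_seq_len, src_seq_len].
    return mask[None, None, :, :]


print(sliding_window_causal_mask(5, sliding_window=3)[0, 0])
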
