Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Nov 2, 2023
1 parent 6c53370 commit 967b8c9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
6 changes: 3 additions & 3 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,14 +667,14 @@ def overwrite_shape_and_generate_input(
# models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name
# while models from TextDecoderOnnxConfig use input_ids, hence the check for both

# TODO: The check `self.task != "text-generation" and not self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs.
# TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs.
# This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models.
if (
self.use_past
and self.use_past_in_inputs
and self.use_cache_branch is not False
and input_name in ["decoder_input_ids", "input_ids", "position_ids"]
and ((self.task == "text-generation" and not self.legacy) or self.task != "text-generation")
and ((self.task == "text-generation" and self.legacy) or self.task != "text-generation")
):
sequence_length = dummy_input_gen.sequence_length
# Use a sequence length of 1 when the KV cache is already populated.
Expand Down Expand Up @@ -830,7 +830,7 @@ def with_behavior(
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=legacy,
legacy=self.legacy,
)
onnx_config.variant = self.variant
return onnx_config
Expand Down
9 changes: 7 additions & 2 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,11 @@ def __enter__(self):

def __exit__(self, exc_type, exc_value, traceback):
# TODO: Remove this if once transformers is much above 4.35
if AttentionMaskConverter is not None:
AttentionMaskConverter._make_causal_mask = self.original_make_causal
# TODO: We should unpatch it - however `self._make_causal_mask` may still be called later which raises issues with this simple patch strategy.
# We need to find a proper solution.
# if AttentionMaskConverter is not None:
# AttentionMaskConverter._make_causal_mask = self.original_make_causal
pass

def __init__(
self,
Expand All @@ -428,6 +431,7 @@ def __init__(

class FalconModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
self.patch_ops()

if self.real_config.task == "text-generation":
Expand All @@ -436,6 +440,7 @@ def __enter__(self):
)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self.restore_ops()

setattr(self._model, self.orig_forward_name, self.orig_forward)
Expand Down
15 changes: 10 additions & 5 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,12 @@ def get_decoder_models_for_export(

models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy)

onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype, "legacy": legacy}
onnx_kwargs = {
"task": config.task,
"float_dtype": config.float_dtype,
"int_dtype": config.int_dtype,
"legacy": legacy,
}

if legacy:
onnx_config = config.__class__(
Expand Down Expand Up @@ -386,14 +391,14 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel
models_for_export = _get_submodels_for_export_sam(model, config.variant)

if config.variant == "monolith":
onnx_config = config.__class__(model.config, task=config.task, legacy=legacy)
onnx_config = config.__class__(model.config, task=config.task, legacy=config.legacy)
models_for_export["model"] = (models_for_export["model"], onnx_config)
else:
vision_encoder_onnx_config = config.__class__(
model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=legacy
model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=config.legacy
)
prompt_encoder_mask_decoder_onnx_config = config.__class__(
model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=legacy
model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=config.legacy
)
models_for_export["vision_encoder"] = (models_for_export["vision_encoder"], vision_encoder_onnx_config)
models_for_export["prompt_encoder_mask_decoder"] = (
Expand Down Expand Up @@ -451,7 +456,7 @@ def get_speecht5_models_for_export(
behavior=config._behavior, # Irrelevant here.
preprocessors=config._preprocessors,
is_postnet_and_vocoder=True,
legacy=legacy,
legacy=config.legacy,
)
postnet_and_vocoder_onnx_config.variant = config.variant
models_for_export["decoder_postnet_and_vocoder"] = (
Expand Down

0 comments on commit 967b8c9

Please sign in to comment.