Skip to content

Commit

Permalink
fix legacy
Browse files Browse the repository at this point in the history
  • Loading branch information
fxmarty committed Nov 2, 2023
1 parent 3f11d1a commit 6c53370
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
1 change: 1 addition & 0 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ def with_behavior(
use_past_in_inputs=use_past_in_inputs,
behavior=behavior,
preprocessors=self._preprocessors,
legacy=legacy,
)
onnx_config.variant = self.variant
return onnx_config
Expand Down
9 changes: 5 additions & 4 deletions optimum/exporters/onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def get_decoder_models_for_export(

models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy)

onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype}
onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype, "legacy": legacy}

if legacy:
onnx_config = config.__class__(
Expand Down Expand Up @@ -386,14 +386,14 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel
models_for_export = _get_submodels_for_export_sam(model, config.variant)

if config.variant == "monolith":
onnx_config = config.__class__(model.config, task=config.task)
onnx_config = config.__class__(model.config, task=config.task, legacy=legacy)
models_for_export["model"] = (models_for_export["model"], onnx_config)
else:
vision_encoder_onnx_config = config.__class__(
model.config, task=config.task, variant=config.variant, vision_encoder=True
model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=legacy
)
prompt_encoder_mask_decoder_onnx_config = config.__class__(
model.config, task=config.task, variant=config.variant, vision_encoder=False
model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=legacy
)
models_for_export["vision_encoder"] = (models_for_export["vision_encoder"], vision_encoder_onnx_config)
models_for_export["prompt_encoder_mask_decoder"] = (
Expand Down Expand Up @@ -451,6 +451,7 @@ def get_speecht5_models_for_export(
behavior=config._behavior, # Irrelevant here.
preprocessors=config._preprocessors,
is_postnet_and_vocoder=True,
legacy=legacy,
)
postnet_and_vocoder_onnx_config.variant = config.variant
models_for_export["decoder_postnet_and_vocoder"] = (
Expand Down

0 comments on commit 6c53370

Please sign in to comment.