Commit: nit

fxmarty committed Sep 21, 2023
1 parent be26f71 commit 02259a8
Showing 3 changed files with 10 additions and 9 deletions.
10 changes: 5 additions & 5 deletions optimum/exporters/onnx/model_configs.py
@@ -1194,10 +1194,10 @@ class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
# TODO: DO NOT CUT OUTPUT_SEQUENCE LENGTH WITH PAST!!!!!

VARIANTS = {
"transformers-like": "The following components are exported following Transformers implementation:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
"without-cache": "The same as `transformers-like`, without KV cache support. This is not a recommende export as slower than `transformers-like`.",
"with-past": "The export follows the Transformers implementation using the KV cache, with the following components exported:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
"without-past": "The same as `with-past`, just without KV cache support. This is not a recommended export as slower than `with-past`.",
}
DEFAULT_VARIANT = "transformers-like"
DEFAULT_VARIANT = "with-past"

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
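
The VARIANTS docstring above decomposes the SpeechT5 text-to-speech export into four ONNX components. Below is a minimal sketch of how to inspect what each component expects, assuming the files were exported to ./speecht5_onnx/ (the directory name is hypothetical; only the component file names come from the docstring). This is the quickest way to see the effect of the with-past variant: decoder_with_past_model.onnx takes past_key_values.* inputs, while the plain decoder does not.

    import os
    import onnxruntime as ort

    EXPORT_DIR = "speecht5_onnx"  # hypothetical output directory
    COMPONENTS = [
        "encoder_model.onnx",
        "decoder_model.onnx",
        "decoder_with_past_model.onnx",  # exported only by the "with-past" variant
        "decoder_postnet_and_vocoder.onnx",
    ]

    for name in COMPONENTS:
        path = os.path.join(EXPORT_DIR, name)
        if not os.path.exists(path):
            continue  # e.g. a "without-past" export has no with-past decoder
        session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        print(name)
        print("  inputs: ", [(i.name, i.shape) for i in session.get_inputs()])
        print("  outputs:", [(o.name, o.shape) for o in session.get_outputs()])
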
@@ -1214,7 +1214,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}

if self.variant == "transformers-like" and self.use_past_in_inputs:
if self.variant == "with-past" and self.use_past_in_inputs:
# TODO: check PKV shape
self.add_past_key_values(common_inputs, direction="inputs")
elif self.is_postnet_and_vocoder:
@@ -1237,7 +1237,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs["prob"] = {} # No dynamic shape here.
common_outputs["spectrum"] = {} # No dynamic shape here.

if self.variant == "transformers-like" and self.use_past:
if self.variant == "with-past" and self.use_past:
# When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
self.add_past_key_values(common_outputs, direction="outputs")
elif self.is_postnet_and_vocoder:
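
For reference, here is a sketch of the dynamic-axes mappings these inputs/outputs properties assemble. The encoder_* entries are taken from the hunk above; the past_key_values entries follow Optimum's usual past_key_values.{i}.{decoder,encoder}.{key,value} naming for seq2seq exports and are an assumption here (note the "TODO: check PKV shape" above), so the exact names and axes for SpeechT5 may differ.

    common_inputs = {
        "encoder_hidden_states": {1: "encoder_sequence_length"},
        "encoder_attention_mask": {1: "encoder_sequence_length"},
    }

    # What add_past_key_values(common_inputs, direction="inputs") plausibly adds,
    # one key/value pair per decoder layer; axis 2 is the growing time dimension:
    num_decoder_layers = 6  # arbitrary example value
    for i in range(num_decoder_layers):
        common_inputs[f"past_key_values.{i}.decoder.key"] = {2: "past_decoder_sequence_length"}
        common_inputs[f"past_key_values.{i}.decoder.value"] = {2: "past_decoder_sequence_length"}
        common_inputs[f"past_key_values.{i}.encoder.key"] = {2: "encoder_sequence_length"}
        common_inputs[f"past_key_values.{i}.encoder.value"] = {2: "encoder_sequence_length"}
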
3 changes: 2 additions & 1 deletion optimum/exporters/onnx/model_patcher.py
@@ -372,7 +372,7 @@ def patched_forward(
output_sequence=None,
spectrogram=None,
):
use_cache = self.real_config.use_past and self.real_config.variant == "transformers-like"
use_cache = self.real_config.use_past and self.real_config.variant == "with-past"
if self.real_config._behavior == "encoder":
encoder_attention_mask = torch.ones_like(input_ids)

@@ -432,6 +432,7 @@ def patched_forward(
# TODO: PKV here
}
elif self.real_config.is_postnet_and_vocoder:
# NOTE: the following concatenation is expected to be handled outside of the ONNX model:
# spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
spectrogram = spectrogram.unsqueeze(0)
spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
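
The NOTE added in this hunk means the caller must assemble the full spectrogram on the host before invoking the postnet-and-vocoder component. Below is a minimal sketch of that step with NumPy and onnxruntime; the file name comes from the VARIANTS docstring, while the frame shape (SpeechT5's reduction factor of 2 over 80 mel bins) and the single-output assumption are illustrative, not guaranteed by this commit.

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        "speecht5_onnx/decoder_postnet_and_vocoder.onnx",
        providers=["CPUExecutionProvider"],
    )

    # Frames collected over the generation loop, one per decoder step; each is
    # (reduction_factor, num_mel_bins). Dummy data stands in for real decoder output.
    frames = [np.zeros((2, 80), dtype=np.float32) for _ in range(10)]

    # The torch.cat(...) the NOTE refers to, done outside the ONNX graph; the
    # unsqueeze(0) happens inside the exported model.
    spectrogram = np.concatenate(frames, axis=0)

    input_name = session.get_inputs()[0].name  # avoid hard-coding the input name
    waveform = session.run(None, {input_name: spectrogram})[0]
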
6 changes: 3 additions & 3 deletions optimum/exporters/onnx/utils.py
@@ -403,7 +403,7 @@ def get_speecht5_models_for_export(
models_for_export["encoder_model"] = model
models_for_export["decoder_model"] = model

if config.variant == "transformers-like":
if config.variant == "with-past":
models_for_export["decoder_with_past_model"] = model

vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"])
@@ -413,12 +413,12 @@

encoder_onnx_config = config.with_behavior("encoder")

use_past = config.variant == "transformers-like"
use_past = config.variant == "with-past"
decoder_onnx_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)

models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config)
models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config)
if config.variant == "transformers-like":
if config.variant == "with-past":
decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
models_for_export[ONNX_DECODER_WITH_PAST_NAME],
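
After this change, the mapping returned by get_speecht5_models_for_export for the default with-past variant plausibly looks as sketched below. The first three keys match Optimum's ONNX_ENCODER_NAME, ONNX_DECODER_NAME, and ONNX_DECODER_WITH_PAST_NAME constants used above; the vocoder entry is wired up outside the shown hunks, so its key and config name are assumptions.

    # Placeholders stand in for the real objects built in this function:
    model = None                          # the shared SpeechT5ForTextToSpeech module
    encoder_onnx_config = None            # config.with_behavior("encoder")
    decoder_onnx_config = None            # with_behavior("decoder", use_past=True, use_past_in_inputs=False)
    decoder_onnx_config_with_past = None  # with_behavior("decoder", use_past=True, use_past_in_inputs=True)
    postnet_and_vocoder_onnx_config = None  # hypothetical config for the SpeechT5HifiGan part

    models_for_export = {
        "encoder_model": (model, encoder_onnx_config),
        "decoder_model": (model, decoder_onnx_config),  # KV cache as outputs only
        "decoder_with_past_model": (model, decoder_onnx_config_with_past),  # KV cache in and out
        "decoder_postnet_and_vocoder": (model, postnet_and_vocoder_onnx_config),  # hypothetical key
    }
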
