From 02259a82eb5972559bfee58101dfa554609eff7f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?F=C3=A9lix=20Marty?= <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 21 Sep 2023 14:52:34 +0200
Subject: [PATCH] nit

---
 optimum/exporters/onnx/model_configs.py | 10 +++++-----
 optimum/exporters/onnx/model_patcher.py |  3 ++-
 optimum/exporters/onnx/utils.py         |  6 +++---
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index febb4c40073..51d5823774f 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -1194,10 +1194,10 @@ class SpeechT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
     # TODO: DO NOT CUT OUTPUT_SEQUENCE LENGTH WITH PAST!!!!!
 
     VARIANTS = {
-        "transformers-like": "The following components are exported following Transformers implementation:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
-        "without-cache": "The same as `transformers-like`, without KV cache support. This is not a recommende export as slower than `transformers-like`.",
+        "with-past": "The export follows the Transformers implementation using the KV cache, with the following components exported:\n\t - encoder_model.onnx: corresponds to the encoding part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2544-L2556.\n\t - decoder_model.onnx: corresponds to the decoder part in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2572-L2602.\n\t - decoder_with_past_model.onnx: same as the above, with past_key_values input (KV cache filled).\n\t - decoder_postnet_and_vocoder.onnx: Decoder speech postnet and vocoder (e.g. a SpeechT5HifiGan) to generate speech from the spectrogram, as in https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/speecht5/modeling_speecht5.py#L2605-L2614.",
+        "without-past": "The same as `with-past`, just without KV cache support. This is not a recommended export as slower than `with-past`.",
     }
-    DEFAULT_VARIANT = "transformers-like"
+    DEFAULT_VARIANT = "with-past"
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -1214,7 +1214,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
             common_inputs["encoder_hidden_states"] = {1: "encoder_sequence_length"}
             common_inputs["encoder_attention_mask"] = {1: "encoder_sequence_length"}
 
-            if self.variant == "transformers-like" and self.use_past_in_inputs:
+            if self.variant == "with-past" and self.use_past_in_inputs:
                 # TODO: check PKV shape
                 self.add_past_key_values(common_inputs, direction="inputs")
         elif self.is_postnet_and_vocoder:
@@ -1237,7 +1237,7 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
             common_outputs["prob"] = {}  # No dynamic shape here.
             common_outputs["spectrum"] = {}  # No dynamic shape here.
 
-            if self.variant == "transformers-like" and self.use_past:
+            if self.variant == "with-past" and self.use_past:
                 # When exporting decoder models with use_cache=True, both the decoder without past and with past have the KV cache as an output.
                 self.add_past_key_values(common_outputs, direction="outputs")
         elif self.is_postnet_and_vocoder:
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index b9abe29421a..33d92ebb6b6 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -372,7 +372,7 @@ def patched_forward(
             output_sequence=None,
             spectrogram=None,
         ):
-            use_cache = self.real_config.use_past and self.real_config.variant == "transformers-like"
+            use_cache = self.real_config.use_past and self.real_config.variant == "with-past"
 
             if self.real_config._behavior == "encoder":
                 encoder_attention_mask = torch.ones_like(input_ids)
@@ -432,6 +432,7 @@ def patched_forward(
                     # TODO: PKV here
                 }
             elif self.real_config.is_postnet_and_vocoder:
+                # NOTE: the following concatenation is expected to be handled outside of the ONNX:
                 # spectrogram = torch.cat(spectrogram, dim=0).unsqueeze(0)
                 spectrogram = spectrogram.unsqueeze(0)
                 spectrogram = model.speech_decoder_postnet.postnet(spectrogram)
diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index a24cde52135..1ae682cce9f 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -403,7 +403,7 @@ def get_speecht5_models_for_export(
     models_for_export["encoder_model"] = model
     models_for_export["decoder_model"] = model
 
-    if config.variant == "transformers-like":
+    if config.variant == "with-past":
         models_for_export["decoder_with_past_model"] = model
 
     vocoder = SpeechT5HifiGan.from_pretrained(model_kwargs["vocoder"])
@@ -413,12 +413,12 @@
     encoder_onnx_config = config.with_behavior("encoder")
 
-    use_past = config.variant == "transformers-like"
+    use_past = config.variant == "with-past"
     decoder_onnx_config = config.with_behavior("decoder", use_past=use_past, use_past_in_inputs=False)
 
     models_for_export[ONNX_ENCODER_NAME] = (models_for_export[ONNX_ENCODER_NAME], encoder_onnx_config)
     models_for_export[ONNX_DECODER_NAME] = (models_for_export[ONNX_DECODER_NAME], decoder_onnx_config)
 
-    if config.variant == "transformers-like":
+    if config.variant == "with-past":
         decoder_onnx_config_with_past = config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)
         models_for_export[ONNX_DECODER_WITH_PAST_NAME] = (
             models_for_export[ONNX_DECODER_WITH_PAST_NAME],