Skip to content

Commit

Permalink
Add moonshine ONNX config
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Dec 16, 2024
1 parent 4daa408 commit 3269458
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
27 changes: 27 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,6 +1628,33 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# def inputs(self) -> Dict[str, Dict[int, str]]:
# return {"input_features": {0: "batch_size", 1: "sequence_classification"}}

class MoonshineOnnxConfig(AudioToTextOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}

if self._behavior is not ConfigBehavior.DECODER:
common_inputs["input_values"] = {0: "batch_size", 1: "num_samples"}

if self._behavior is not ConfigBehavior.ENCODER:
if self.use_past_in_inputs:
common_inputs["decoder_input_ids"] = {0: "batch_size"}
self.add_past_key_values(common_inputs, direction="inputs")
else:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}

if self._behavior is ConfigBehavior.DECODER:
common_inputs["encoder_outputs"] = {0: "batch_size", 1: "encoder_sequence_length"}

return common_inputs



class WhisperOnnxConfig(AudioToTextOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
Expand Down
7 changes: 7 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,13 @@ class TasksManager:
"image-classification",
onnx="MobileNetV2OnnxConfig",
),
"moonshine": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
onnx="MoonshineOnnxConfig",
),
"mpnet": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
Expand Down

0 comments on commit 3269458

Please sign in to comment.