diff --git a/egs/ljspeech/TTS/vits/export-onnx.py b/egs/ljspeech/TTS/vits/export-onnx.py index 2068adeea0..bca6aec99e 100755 --- a/egs/ljspeech/TTS/vits/export-onnx.py +++ b/egs/ljspeech/TTS/vits/export-onnx.py @@ -176,7 +176,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur), model_filename, verbose=False, opset_version=opset_version, @@ -184,8 +184,8 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", - "noise_scale_dur", "alpha", + "noise_scale_dur", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/ljspeech/TTS/vits/test_onnx.py b/egs/ljspeech/TTS/vits/test_onnx.py index 686fee2a03..fcbc1d6632 100755 --- a/egs/ljspeech/TTS/vits/test_onnx.py +++ b/egs/ljspeech/TTS/vits/test_onnx.py @@ -92,8 +92,8 @@ def __call__(self, tokens: torch.Tensor, tokens_lens: torch.Tensor) -> torch.Ten self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), }, )[0] return torch.from_numpy(out) diff --git a/egs/vctk/TTS/vits/export-onnx.py b/egs/vctk/TTS/vits/export-onnx.py index 7c9664cc14..cfc74fd0ac 100755 --- a/egs/vctk/TTS/vits/export-onnx.py +++ b/egs/vctk/TTS/vits/export-onnx.py @@ -187,7 +187,7 @@ def export_model_onnx( torch.onnx.export( model, - (tokens, tokens_lens, noise_scale, noise_scale_dur, speaker, alpha), + (tokens, tokens_lens, noise_scale, alpha, noise_scale_dur, speaker), model_filename, verbose=False, opset_version=opset_version, @@ -195,9 +195,9 @@ def export_model_onnx( "tokens", "tokens_lens", "noise_scale", + "alpha", "noise_scale_dur", "speaker", - "alpha", ], output_names=["audio"], dynamic_axes={ diff --git a/egs/vctk/TTS/vits/test_onnx.py b/egs/vctk/TTS/vits/test_onnx.py index 757e67fc1c..d85c0a27bd 100755 --- a/egs/vctk/TTS/vits/test_onnx.py +++ b/egs/vctk/TTS/vits/test_onnx.py @@ -101,9 +101,9 @@ def __call__( self.model.get_inputs()[0].name: tokens.numpy(), self.model.get_inputs()[1].name: tokens_lens.numpy(), self.model.get_inputs()[2].name: noise_scale.numpy(), - self.model.get_inputs()[3].name: noise_scale_dur.numpy(), - self.model.get_inputs()[4].name: speaker.numpy(), - self.model.get_inputs()[5].name: alpha.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + self.model.get_inputs()[4].name: noise_scale_dur.numpy(), + self.model.get_inputs()[5].name: speaker.numpy(), }, )[0] return torch.from_numpy(out)