From 069f2495e670bbb852f60f83c2982105e5e53b00 Mon Sep 17 00:00:00 2001 From: Yosshi999 Date: Wed, 22 Jun 2022 07:28:37 +0900 Subject: [PATCH 1/5] split decoder and make surgeon script --- README.md | 5 +- vv_core_inference/make_decode_forwarder.py | 76 ++++++++------- .../make_yukarin_sosoa_forwarder.py | 5 +- vv_core_inference/onnx_decode_forwarder.py | 8 +- vv_core_inference/surgeon.py | 92 +++++++++++++++++++ 5 files changed, 145 insertions(+), 41 deletions(-) create mode 100644 vv_core_inference/surgeon.py diff --git a/README.md b/README.md index 45b576b..1133538 100644 --- a/README.md +++ b/README.md @@ -53,10 +53,9 @@ Cyhton が便利です。 5. import して[このように](https://github.com/Hiroshiba/voicevox_core/blob/f4844efc65b1a4875442091955af84f671e16887/example/python/run.py#L21-L25)つなぎこむ ## モデルをonnxに変換 -* `python run.py --yukarin_s_model_dir "model/yukarin_s" --yukarin_sa_model_dir "model/yukarin_sa" --yukarin_sosoa_model_dir "model/yukarin_sosoa" --hifigan_model_dir "model/hifigan" --speaker_ids 5 --method=convert` でonnxへの変換が可能。modelフォルダ内のyukarin_s, yukarin_sa, yukarin_sosoaフォルダにonnxが保存される。 +* `python run.py --yukarin_s_model_dir "model/yukarin_s" --yukarin_sa_model_dir "model/yukarin_sa" --yukarin_sosoa_model_dir "model/yukarin_sosoa" --hifigan_model_dir "model/hifigan" --speaker_ids 5 --method=convert` でonnxへの変換が可能。modelフォルダ内のyukarin_s, yukarin_sa, yukarin_sosoa, hifiganフォルダにonnxが保存される。 - `speaker_ids`オプションに指定する数値は自由。どの数値を指定しても生成されるonnxモデルは全ての`speaker_id`に対応しており、値を変えて実行しなおしたり、複数のidを指定したりする必要は無い。 - - yukarin_sosoaフォルダにはhifi_ganと合わせた`decode.onnx`が保存される - - int64 to int32 `python onnx-typecast/convert.py model/yukarin_sosoa/decode.onnx model/yukarin_sosoa/decode.onnx` + - int64 to int32 `python onnx-typecast/convert.py model/hifigan/hifigan.onnx model/hifigan/hifigan.onnx` * onnxで実行したい場合は`--method=onnx`とする; `python run.py --yukarin_s_model_dir "model/yukarin_s" --yukarin_sa_model_dir "model/yukarin_sa" --yukarin_sosoa_model_dir "model/yukarin_sosoa" --hifigan_model_dir "model/hifigan" --speaker_ids 5 --method=onnx` - `speaker_ids`に複数の数値を指定すれば、通常実行と同様に各話者の音声が保存される。 diff --git a/vv_core_inference/make_decode_forwarder.py b/vv_core_inference/make_decode_forwarder.py index 7165e97..519c064 100644 --- a/vv_core_inference/make_decode_forwarder.py +++ b/vv_core_inference/make_decode_forwarder.py @@ -17,33 +17,37 @@ def __init__(self, *args, **kwargs): self.__dict__ = self -class WrapperDecodeForwarder(nn.Module): +class WrapperHifiGan(nn.Module): def __init__( self, - yukarin_sosoa_forwarder: nn.Module, hifi_gan_forwarder: nn.Module, ): super().__init__() - self.yukarin_sosoa_forwarder = yukarin_sosoa_forwarder self.hifi_gan_forwarder = hifi_gan_forwarder - + @torch.no_grad() def forward( self, - f0: torch.Tensor, - phoneme: torch.Tensor, - speaker_id: torch.Tensor, + spec: torch.Tensor, ): - # forward sosoa - spec = self.yukarin_sosoa_forwarder( - f0=f0, phoneme=phoneme, speaker_id=speaker_id - ) - # forward hifi gan x = spec.transpose(1, 0) wave = self.hifi_gan_forwarder(x.unsqueeze(0))[0, 0] return wave +def make_hifi_gan_wrapper(hifigan_model_dir: Path, device) -> nn.Module: + config = AttrDict(json.load(hifigan_model_dir.joinpath("config.json").open())) + predictor = HifiGanPredictor(config).to(device) + checkpoint_dict = torch.load( + hifigan_model_dir.joinpath("model.pth"), + map_location=device, + ) + predictor.load_state_dict(checkpoint_dict["generator"]) + predictor.eval() + predictor.remove_weight_norm() + print("hifi-gan loaded!") + return WrapperHifiGan(predictor) + def make_decode_forwarder( yukarin_sosoa_model_dir: Path, hifigan_model_dir: Path, device, convert=False @@ -54,23 +58,8 @@ def make_decode_forwarder( ) # hifi-gan - vocoder_model_config = AttrDict( - json.loads((hifigan_model_dir / "config.json").read_text()) - ) - - hifi_gan_predictor = HifiGanPredictor(vocoder_model_config).to(device) - checkpoint_dict = torch.load( - hifigan_model_dir.joinpath("model.pth"), - map_location=device, - ) - hifi_gan_predictor.load_state_dict(checkpoint_dict["generator"]) - hifi_gan_predictor.eval() - hifi_gan_predictor.remove_weight_norm() - print("hifi-gan loaded!") - - decode_forwarder = WrapperDecodeForwarder( - yukarin_sosoa_forwarder=yukarin_sosoa_wrapper, - hifi_gan_forwarder=hifi_gan_predictor, + hifi_gan_wrapper = make_hifi_gan_wrapper( + hifigan_model_dir=hifigan_model_dir, device=device ) def _dispatcher( @@ -84,21 +73,40 @@ def _dispatcher( phoneme = to_tensor(phoneme, device=device) if speaker_id is not None: speaker_id = to_tensor(speaker_id, device=device) + + spec = yukarin_sosoa_wrapper( + f0=f0, phoneme=phoneme, speaker_id=speaker_id + ) + wave = hifi_gan_wrapper(spec) + if convert: torch.onnx.export( - decode_forwarder, + yukarin_sosoa_wrapper, (f0, phoneme, speaker_id), - yukarin_sosoa_model_dir.joinpath("decode.onnx"), + yukarin_sosoa_model_dir.joinpath("yukarin_sosoa.onnx"), opset_version=OPSET, do_constant_folding=True, input_names=["f0", "phoneme", "speaker_id"], - output_names=["wave"], + output_names=["spec"], dynamic_axes={ "f0": {0: "length"}, "phoneme": {0: "length"}, + "spec": {0: "length"} + }) + print("decode/yukarin_sosoa has been converted to ONNX") + torch.onnx.export( + hifi_gan_wrapper, + (spec,), + hifigan_model_dir.joinpath("hifigan.onnx"), + opset_version=OPSET, + do_constant_folding=True, + input_names=["spec"], + output_names=["wave"], + dynamic_axes={ + "spec": {0: "length"}, "wave": {0: "outlength"} }) - print("decode has been converted to ONNX") - return decode_forwarder(f0, phoneme, speaker_id).cpu().numpy() + print("decode/hifigan has been converted to ONNX") + return wave.cpu().numpy() return _dispatcher \ No newline at end of file diff --git a/vv_core_inference/make_yukarin_sosoa_forwarder.py b/vv_core_inference/make_yukarin_sosoa_forwarder.py index 465dee4..eadbccc 100644 --- a/vv_core_inference/make_yukarin_sosoa_forwarder.py +++ b/vv_core_inference/make_yukarin_sosoa_forwarder.py @@ -126,8 +126,9 @@ def forward( h = self.pre(h) - mask = torch.ones_like(f0).squeeze() - h, _ = self.encoder(h, mask) + # mask = torch.ones_like(f0).squeeze() + # h, _ = self.encoder(h, mask) + h, _ = self.encoder(h, None) output1 = self.post(h) output2 = output1 + self.postnet(output1.transpose(1, 2)).transpose(1, 2) diff --git a/vv_core_inference/onnx_decode_forwarder.py b/vv_core_inference/onnx_decode_forwarder.py index 223f20f..5c7c8ad 100644 --- a/vv_core_inference/onnx_decode_forwarder.py +++ b/vv_core_inference/onnx_decode_forwarder.py @@ -6,7 +6,8 @@ import onnxruntime def make_decode_forwarder(yukarin_sosoa_model_dir: Path, hifigan_model_dir: Path, device, convert=False): - session = onnxruntime.InferenceSession(str(yukarin_sosoa_model_dir.joinpath("decode.onnx"))) + session_sosoa = onnxruntime.InferenceSession(str(yukarin_sosoa_model_dir.joinpath("yukarin_sosoa.onnx"))) + session_hifi = onnxruntime.InferenceSession(str(hifigan_model_dir.joinpath("hifigan_modified.onnx"))) def _dispatcher( length: int, @@ -20,9 +21,12 @@ def _dispatcher( if speaker_id is not None: speaker_id = np.asarray(speaker_id) speaker_id = speaker_id.reshape((1,)).astype(np.int64) - return session.run(["wave"], { + spec = session_sosoa.run(["spec"], { "f0": f0, "phoneme": phoneme, "speaker_id": speaker_id, })[0] + return session_hifi.run(["wave"], { + "spec": spec, + })[0] return _dispatcher diff --git a/vv_core_inference/surgeon.py b/vv_core_inference/surgeon.py new file mode 100644 index 0000000..52ed7da --- /dev/null +++ b/vv_core_inference/surgeon.py @@ -0,0 +1,92 @@ +import numpy as np +import onnx +import onnx_graphsurgeon as gs + +@gs.Graph.register() +def replace_ConvTranspose(self, node): + assert node.op == "ConvTranspose" + in_tensor, weight, bias = node.inputs + in_tensor.outputs.clear() + weight.outputs.clear() + bias.outputs.clear() + out_tensor = node.outputs[0] + out_tensor.inputs.clear() + + kernel_size = node.attrs["kernel_shape"] + + assert len(kernel_size) == 1, "only supports conv_transpose1d" + kernel_size = kernel_size[0] + groups = node.attrs["group"] + dilation = node.attrs["dilations"][0] + padding = node.attrs["pads"] + stride = node.attrs["strides"][0] + + assert groups == 1 + assert dilation == 1 + assert padding[0] == padding[1] + padding = padding[0] + + weight_numpy = weight.values + weight_numpy_conv = np.ascontiguousarray(weight_numpy.transpose(1,0,2)[:,:,::-1]) + + # h1 = self.layer(op="Unsqueeze", inputs=[in_tensor, np.array([-1], np.int64)], outputs=["expanded"], attrs={"axes": [-1]})[0] + h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["expanded"], attrs={"axes": [-1]})[0] + h2 = self.layer(op="Pad", inputs=[h1, [0, 0, 0, 0, 0, 0, 0, stride-1]], outputs=["pad_inner"])[0] + # shape = self.layer(op="Shape", inputs=[h2], outputs=["shape"])[0] + # shape_0_2 = self.layer(op="Slice", inputs=[shape, [0], [2]], outputs=["shape_slice"])[0] + # shape_flatten = self.layer(op="Concat", inputs=[shape_0_2, np.array([-1], np.int64)], outputs=["shape_flatten"], attrs={"axis": 0})[0] + # h3 = self.layer(op="Reshape", inputs=[h2, shape_flatten], outputs=["unpooled"])[0] + h3 = self.layer(op="Reshape", inputs=[h2, [0, 0, -1]], outputs=["unpooled"])[0] + h4 = self.layer(op="Pad", inputs=[h3, np.array([0, 0, kernel_size - padding - 1, 0, 0, kernel_size - padding - stride], np.int64)], outputs=["pad_outer"])[0] + return self.layer(op="Conv", inputs=[h4, weight_numpy_conv, bias.values], outputs=[out_tensor], attrs={ + "dilations": [1], + "group": 1, + "kernel_shape": [kernel_size], + "pads": [0, 0], + "strides": [1] + }) + +@gs.Graph.register() +def replace_Conv(self, node): + # 1d -> 2d (webgl only supports conv2d) + assert node.op == "Conv" + in_tensor, weight, bias = node.inputs + in_tensor.outputs.clear() + weight.outputs.clear() + bias.outputs.clear() + out_tensor = node.outputs[0] + out_tensor.inputs.clear() + + kernel_size = node.attrs["kernel_shape"] + assert len(kernel_size) == 1, "only supports conv1d" + kernel_size = kernel_size[0] + groups = node.attrs["group"] + dilation = node.attrs["dilations"][0] + padding = node.attrs["pads"] + stride = node.attrs["strides"][0] + + h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["in_2d"], attrs={"axes": [3]})[0] + h2 = self.layer(op="Conv", inputs=[h1, weight.values[:, :, :, None], bias], outputs=["out_2d"], attrs={ + "dilations": [dilation, dilation], + "group": groups, + "kernel_shape": [kernel_size, 1], + "pads": [padding[0], 0, padding[1], 0], + "strides": [stride, stride], + })[0] + return self.layer(op="Squeeze", inputs=[h2], outputs=[out_tensor], attrs={"axes": [-1]}) + + +def surgeon(filename): + graph = gs.import_onnx(onnx.load(filename)) + targets = [node for node in graph.nodes if node.op == "ConvTranspose"] + for node in targets: + graph.replace_ConvTranspose(node) + graph.cleanup().toposort() + targets = [node for node in graph.nodes if node.op == "Conv"] + for node in targets: + graph.replace_Conv(node) + graph.cleanup().toposort() + onnx.save(gs.export_onnx(graph), "model/hifigan/hifigan_modified.onnx") + # onnx.save(gs.export_onnx(graph), "../vv_check_web/public/hifigan.onnx") + +surgeon("model/hifigan/hifigan.onnx") \ No newline at end of file From 51636aaf8273d14a37c2bf5fa7ce2221200d7c85 Mon Sep 17 00:00:00 2001 From: Yosshi999 Date: Wed, 22 Jun 2022 07:29:56 +0900 Subject: [PATCH 2/5] remove comments, fix filename --- vv_core_inference/onnx_decode_forwarder.py | 2 +- vv_core_inference/surgeon.py | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/vv_core_inference/onnx_decode_forwarder.py b/vv_core_inference/onnx_decode_forwarder.py index 5c7c8ad..0b05a62 100644 --- a/vv_core_inference/onnx_decode_forwarder.py +++ b/vv_core_inference/onnx_decode_forwarder.py @@ -7,7 +7,7 @@ def make_decode_forwarder(yukarin_sosoa_model_dir: Path, hifigan_model_dir: Path, device, convert=False): session_sosoa = onnxruntime.InferenceSession(str(yukarin_sosoa_model_dir.joinpath("yukarin_sosoa.onnx"))) - session_hifi = onnxruntime.InferenceSession(str(hifigan_model_dir.joinpath("hifigan_modified.onnx"))) + session_hifi = onnxruntime.InferenceSession(str(hifigan_model_dir.joinpath("hifigan.onnx"))) def _dispatcher( length: int, diff --git a/vv_core_inference/surgeon.py b/vv_core_inference/surgeon.py index 52ed7da..4ea6ac4 100644 --- a/vv_core_inference/surgeon.py +++ b/vv_core_inference/surgeon.py @@ -29,13 +29,8 @@ def replace_ConvTranspose(self, node): weight_numpy = weight.values weight_numpy_conv = np.ascontiguousarray(weight_numpy.transpose(1,0,2)[:,:,::-1]) - # h1 = self.layer(op="Unsqueeze", inputs=[in_tensor, np.array([-1], np.int64)], outputs=["expanded"], attrs={"axes": [-1]})[0] h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["expanded"], attrs={"axes": [-1]})[0] h2 = self.layer(op="Pad", inputs=[h1, [0, 0, 0, 0, 0, 0, 0, stride-1]], outputs=["pad_inner"])[0] - # shape = self.layer(op="Shape", inputs=[h2], outputs=["shape"])[0] - # shape_0_2 = self.layer(op="Slice", inputs=[shape, [0], [2]], outputs=["shape_slice"])[0] - # shape_flatten = self.layer(op="Concat", inputs=[shape_0_2, np.array([-1], np.int64)], outputs=["shape_flatten"], attrs={"axis": 0})[0] - # h3 = self.layer(op="Reshape", inputs=[h2, shape_flatten], outputs=["unpooled"])[0] h3 = self.layer(op="Reshape", inputs=[h2, [0, 0, -1]], outputs=["unpooled"])[0] h4 = self.layer(op="Pad", inputs=[h3, np.array([0, 0, kernel_size - padding - 1, 0, 0, kernel_size - padding - stride], np.int64)], outputs=["pad_outer"])[0] return self.layer(op="Conv", inputs=[h4, weight_numpy_conv, bias.values], outputs=[out_tensor], attrs={ From 53f5813bbcbc915e22191c258739ff491cfc57dc Mon Sep 17 00:00:00 2001 From: Yosshi999 Date: Wed, 22 Jun 2022 07:56:11 +0900 Subject: [PATCH 3/5] process surgeon in converting --- vv_core_inference/make_decode_forwarder.py | 3 +++ vv_core_inference/surgeon.py | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/vv_core_inference/make_decode_forwarder.py b/vv_core_inference/make_decode_forwarder.py index 519c064..5a24459 100644 --- a/vv_core_inference/make_decode_forwarder.py +++ b/vv_core_inference/make_decode_forwarder.py @@ -9,6 +9,7 @@ from vv_core_inference.make_yukarin_sosoa_forwarder import make_yukarin_sosoa_wrapper from vv_core_inference.utility import to_tensor, OPSET +from vv_core_inference.surgeon import surgeon class AttrDict(dict): @@ -106,6 +107,8 @@ def _dispatcher( "spec": {0: "length"}, "wave": {0: "outlength"} }) + fname = str(hifigan_model_dir.joinpath("hifigan.onnx")) + surgeon(fname, fname) print("decode/hifigan has been converted to ONNX") return wave.cpu().numpy() diff --git a/vv_core_inference/surgeon.py b/vv_core_inference/surgeon.py index 4ea6ac4..9aab373 100644 --- a/vv_core_inference/surgeon.py +++ b/vv_core_inference/surgeon.py @@ -71,7 +71,7 @@ def replace_Conv(self, node): return self.layer(op="Squeeze", inputs=[h2], outputs=[out_tensor], attrs={"axes": [-1]}) -def surgeon(filename): +def surgeon(filename, outname): graph = gs.import_onnx(onnx.load(filename)) targets = [node for node in graph.nodes if node.op == "ConvTranspose"] for node in targets: @@ -81,7 +81,7 @@ def surgeon(filename): for node in targets: graph.replace_Conv(node) graph.cleanup().toposort() - onnx.save(gs.export_onnx(graph), "model/hifigan/hifigan_modified.onnx") - # onnx.save(gs.export_onnx(graph), "../vv_check_web/public/hifigan.onnx") + onnx.save(gs.export_onnx(graph), outname) -surgeon("model/hifigan/hifigan.onnx") \ No newline at end of file +# surgeon("model/hifigan/hifigan.onnx", "model/hifigan/hifigan_modified.onnx") +# surgeon("model/hifigan/hifigan.onnx", "../vv_check_web/public/hifigan.onnx") \ No newline at end of file From 8e0f194a3c8e2220f50747bd4ea7ee995f526399 Mon Sep 17 00:00:00 2001 From: Yosshi999 Date: Wed, 22 Jun 2022 23:55:29 +0900 Subject: [PATCH 4/5] update requirements --- requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/requirements.txt b/requirements.txt index 76b3ca4..1f31266 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,6 @@ git+https://github.com/Hiroshiba/pyopenjtalk@69e5f354634f98098113f9cac5a6ea73644 onnx typer<0.4 colorlog==4.7.2 + +--extra-index-url https://pypi.ngc.nvidia.com +onnx_graphsurgeon \ No newline at end of file From 05e6d512050cb72a4d6fd49278e289fa8dd542f6 Mon Sep 17 00:00:00 2001 From: Yosshi999 Date: Wed, 22 Jun 2022 23:55:59 +0900 Subject: [PATCH 5/5] refactoring surgeon script and add elimination algorithm --- vv_core_inference/surgeon.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/vv_core_inference/surgeon.py b/vv_core_inference/surgeon.py index 9aab373..0e00e33 100644 --- a/vv_core_inference/surgeon.py +++ b/vv_core_inference/surgeon.py @@ -29,7 +29,8 @@ def replace_ConvTranspose(self, node): weight_numpy = weight.values weight_numpy_conv = np.ascontiguousarray(weight_numpy.transpose(1,0,2)[:,:,::-1]) - h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["expanded"], attrs={"axes": [-1]})[0] + print("replace", node.name) + h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["expanded"], attrs={"axes": [3]})[0] h2 = self.layer(op="Pad", inputs=[h1, [0, 0, 0, 0, 0, 0, 0, stride-1]], outputs=["pad_inner"])[0] h3 = self.layer(op="Reshape", inputs=[h2, [0, 0, -1]], outputs=["unpooled"])[0] h4 = self.layer(op="Pad", inputs=[h3, np.array([0, 0, kernel_size - padding - 1, 0, 0, kernel_size - padding - stride], np.int64)], outputs=["pad_outer"])[0] @@ -60,6 +61,7 @@ def replace_Conv(self, node): padding = node.attrs["pads"] stride = node.attrs["strides"][0] + print("replace", node.name) h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["in_2d"], attrs={"axes": [3]})[0] h2 = self.layer(op="Conv", inputs=[h1, weight.values[:, :, :, None], bias], outputs=["out_2d"], attrs={ "dilations": [dilation, dilation], @@ -68,19 +70,46 @@ def replace_Conv(self, node): "pads": [padding[0], 0, padding[1], 0], "strides": [stride, stride], })[0] - return self.layer(op="Squeeze", inputs=[h2], outputs=[out_tensor], attrs={"axes": [-1]}) + return self.layer(op="Squeeze", inputs=[h2], outputs=[out_tensor], attrs={"axes": [3]}) + +def fold_unsqueeze(node): + if node.op != "Squeeze": + return + squeeze = node + axes = node.attrs["axes"] + if not (len(node.outputs[0].outputs) == 1 and node.o().op == "LeakyRelu"): + return + relu = node.o() + if not (len(relu.outputs[0].outputs) == 1 and relu.o().op == "Unsqueeze" and relu.o().attrs["axes"] == axes): + return + unsqueeze = relu.o() + + in_node = squeeze.i() + in_node.outputs = squeeze.outputs + squeeze.outputs.clear() + + relu.outputs = unsqueeze.outputs + unsqueeze.outputs.clear() + print("eliminate", node.name) def surgeon(filename, outname): graph = gs.import_onnx(onnx.load(filename)) + # ConvTranspose -> Conv targets = [node for node in graph.nodes if node.op == "ConvTranspose"] for node in targets: graph.replace_ConvTranspose(node) graph.cleanup().toposort() + # Conv1d -> Conv2d targets = [node for node in graph.nodes if node.op == "Conv"] for node in targets: graph.replace_Conv(node) graph.cleanup().toposort() + # fold --Squeeze--LeakyRelu--Unsqueeze-- into --LeakyRelu-- + targets = [node for node in graph.nodes if node.op == "Squeeze"] + for node in targets: + fold_unsqueeze(node) + graph.cleanup() onnx.save(gs.export_onnx(graph), outname) # surgeon("model/hifigan/hifigan.onnx", "model/hifigan/hifigan_modified.onnx")