Webgl support #5

Open · wants to merge 5 commits into base: to-onnxruntime-web-webgl

Changes from all commits
5 changes: 2 additions & 3 deletions README.md
@@ -53,10 +53,9 @@ Cython is convenient.
5. Import it and hook it in [like this](https://github.com/Hiroshiba/voicevox_core/blob/f4844efc65b1a4875442091955af84f671e16887/example/python/run.py#L21-L25)

## Converting the models to ONNX
* `python run.py --yukarin_s_model_dir "model/yukarin_s" --yukarin_sa_model_dir "model/yukarin_sa" --yukarin_sosoa_model_dir "model/yukarin_sosoa" --hifigan_model_dir "model/hifigan" --speaker_ids 5 --method=convert` converts the models to ONNX. The ONNX files are saved to the yukarin_s, yukarin_sa, and yukarin_sosoa folders inside the model folder
* `python run.py --yukarin_s_model_dir "model/yukarin_s" --yukarin_sa_model_dir "model/yukarin_sa" --yukarin_sosoa_model_dir "model/yukarin_sosoa" --hifigan_model_dir "model/hifigan" --speaker_ids 5 --method=convert` converts the models to ONNX. The ONNX files are saved to the yukarin_s, yukarin_sa, yukarin_sosoa, and hifigan folders inside the model folder
- Any value works for the `speaker_ids` option. Whichever value you pass, the generated ONNX models support every `speaker_id`, so there is no need to rerun with a different value or to pass multiple ids.
- The yukarin_sosoa folder gets a `decode.onnx` that folds in hifi_gan
- int64 to int32: `python onnx-typecast/convert.py model/yukarin_sosoa/decode.onnx model/yukarin_sosoa/decode.onnx`
- int64 to int32: `python onnx-typecast/convert.py model/hifigan/hifigan.onnx model/hifigan/hifigan.onnx`

* To run with ONNX, pass `--method=onnx`: `python run.py --yukarin_s_model_dir "model/yukarin_s" --yukarin_sa_model_dir "model/yukarin_sa" --yukarin_sosoa_model_dir "model/yukarin_sosoa" --hifigan_model_dir "model/hifigan" --speaker_ids 5 --method=onnx`
- Passing multiple values to `speaker_ids` saves audio for each speaker, the same as a normal run.
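After converting, a quick sanity check of the exported graphs is useful; below is a minimal sketch (not part of this PR) that loads the two models this branch produces and prints their I/O names, matching the names used by the exports further down:

```python
# Sanity-check sketch: validate the exported graphs and confirm the
# I/O names ("f0"/"phoneme"/"speaker_id" -> "spec", "spec" -> "wave").
import onnx

for path in ["model/yukarin_sosoa/yukarin_sosoa.onnx", "model/hifigan/hifigan.onnx"]:
    model = onnx.load(path)
    onnx.checker.check_model(model)  # structural validation only
    print(path)
    print("  inputs :", [i.name for i in model.graph.input])
    print("  outputs:", [o.name for o in model.graph.output])
```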
3 changes: 3 additions & 0 deletions requirements.txt
@@ -13,3 +13,6 @@ git+https://github.com/Hiroshiba/pyopenjtalk@69e5f354634f98098113f9cac5a6ea73644
onnx
typer<0.4
colorlog==4.7.2

--extra-index-url https://pypi.ngc.nvidia.com
onnx_graphsurgeon
79 changes: 45 additions & 34 deletions vv_core_inference/make_decode_forwarder.py
@@ -9,6 +9,7 @@

from vv_core_inference.make_yukarin_sosoa_forwarder import make_yukarin_sosoa_wrapper
from vv_core_inference.utility import to_tensor, OPSET
from vv_core_inference.surgeon import surgeon


class AttrDict(dict):
@@ -17,33 +18,37 @@ def __init__(self, *args, **kwargs):
self.__dict__ = self


class WrapperDecodeForwarder(nn.Module):
class WrapperHifiGan(nn.Module):
def __init__(
self,
yukarin_sosoa_forwarder: nn.Module,
hifi_gan_forwarder: nn.Module,
):
super().__init__()
self.yukarin_sosoa_forwarder = yukarin_sosoa_forwarder
self.hifi_gan_forwarder = hifi_gan_forwarder

@torch.no_grad()
def forward(
self,
f0: torch.Tensor,
phoneme: torch.Tensor,
speaker_id: torch.Tensor,
spec: torch.Tensor,
):
# forward sosoa
spec = self.yukarin_sosoa_forwarder(
f0=f0, phoneme=phoneme, speaker_id=speaker_id
)

# forward hifi gan
x = spec.transpose(1, 0)
wave = self.hifi_gan_forwarder(x.unsqueeze(0))[0, 0]
return wave

def make_hifi_gan_wrapper(hifigan_model_dir: Path, device) -> nn.Module:
config = AttrDict(json.load(hifigan_model_dir.joinpath("config.json").open()))
predictor = HifiGanPredictor(config).to(device)
checkpoint_dict = torch.load(
hifigan_model_dir.joinpath("model.pth"),
map_location=device,
)
predictor.load_state_dict(checkpoint_dict["generator"])
predictor.eval()
predictor.remove_weight_norm()
print("hifi-gan loaded!")
return WrapperHifiGan(predictor)


def make_decode_forwarder(
yukarin_sosoa_model_dir: Path, hifigan_model_dir: Path, device, convert=False
@@ -54,23 +59,8 @@ def make_decode_forwarder(
)

# hifi-gan
vocoder_model_config = AttrDict(
json.loads((hifigan_model_dir / "config.json").read_text())
)

hifi_gan_predictor = HifiGanPredictor(vocoder_model_config).to(device)
checkpoint_dict = torch.load(
hifigan_model_dir.joinpath("model.pth"),
map_location=device,
)
hifi_gan_predictor.load_state_dict(checkpoint_dict["generator"])
hifi_gan_predictor.eval()
hifi_gan_predictor.remove_weight_norm()
print("hifi-gan loaded!")

decode_forwarder = WrapperDecodeForwarder(
yukarin_sosoa_forwarder=yukarin_sosoa_wrapper,
hifi_gan_forwarder=hifi_gan_predictor,
hifi_gan_wrapper = make_hifi_gan_wrapper(
hifigan_model_dir=hifigan_model_dir, device=device
)

def _dispatcher(
@@ -84,21 +74,42 @@ def _dispatcher(
phoneme = to_tensor(phoneme, device=device)
if speaker_id is not None:
speaker_id = to_tensor(speaker_id, device=device)

spec = yukarin_sosoa_wrapper(
f0=f0, phoneme=phoneme, speaker_id=speaker_id
)
wave = hifi_gan_wrapper(spec)

if convert:
torch.onnx.export(
decode_forwarder,
yukarin_sosoa_wrapper,
(f0, phoneme, speaker_id),
yukarin_sosoa_model_dir.joinpath("decode.onnx"),
yukarin_sosoa_model_dir.joinpath("yukarin_sosoa.onnx"),
opset_version=OPSET,
do_constant_folding=True,
input_names=["f0", "phoneme", "speaker_id"],
output_names=["wave"],
output_names=["spec"],
dynamic_axes={
"f0": {0: "length"},
"phoneme": {0: "length"},
"spec": {0: "length"}
})
print("decode/yukarin_sosoa has been converted to ONNX")
torch.onnx.export(
hifi_gan_wrapper,
(spec,),
hifigan_model_dir.joinpath("hifigan.onnx"),
opset_version=OPSET,
do_constant_folding=True,
input_names=["spec"],
output_names=["wave"],
dynamic_axes={
"spec": {0: "length"},
"wave": {0: "outlength"}
})
print("decode has been converted to ONNX")
return decode_forwarder(f0, phoneme, speaker_id).cpu().numpy()
fname = str(hifigan_model_dir.joinpath("hifigan.onnx"))
surgeon(fname, fname)
print("decode/hifigan has been converted to ONNX")
return wave.cpu().numpy()

return _dispatcher
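For context, a minimal driver sketch (not in this diff): it exercises the new split export path the same way `run.py --method=convert` does through its CLI. The feature dimensions and the dispatcher signature are assumptions, mirrored from onnx_decode_forwarder.py:

```python
from pathlib import Path

import numpy as np

from vv_core_inference.make_decode_forwarder import make_decode_forwarder

# convert=True makes the first call export yukarin_sosoa.onnx and
# hifigan.onnx (running surgeon on the latter) before returning audio.
forwarder = make_decode_forwarder(
    yukarin_sosoa_model_dir=Path("model/yukarin_sosoa"),
    hifigan_model_dir=Path("model/hifigan"),
    device="cpu",
    convert=True,
)

length = 256
f0 = np.zeros((length, 1), dtype=np.float32)        # placeholder shapes;
phoneme = np.zeros((length, 45), dtype=np.float32)  # 45 phoneme classes is an assumption
wave = forwarder(length, f0=f0, phoneme=phoneme, speaker_id=5)
```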
5 changes: 3 additions & 2 deletions vv_core_inference/make_yukarin_sosoa_forwarder.py
@@ -126,8 +126,9 @@ def forward(

h = self.pre(h)

mask = torch.ones_like(f0).squeeze()
h, _ = self.encoder(h, mask)
# mask = torch.ones_like(f0).squeeze()
# h, _ = self.encoder(h, mask)
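# an all-ones mask masks nothing, so pass None and keep the mask ops out of the exported graph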
h, _ = self.encoder(h, None)

output1 = self.post(h)
output2 = output1 + self.postnet(output1.transpose(1, 2)).transpose(1, 2)
8 changes: 6 additions & 2 deletions vv_core_inference/onnx_decode_forwarder.py
@@ -6,7 +6,8 @@
import onnxruntime

def make_decode_forwarder(yukarin_sosoa_model_dir: Path, hifigan_model_dir: Path, device, convert=False):
session = onnxruntime.InferenceSession(str(yukarin_sosoa_model_dir.joinpath("decode.onnx")))
session_sosoa = onnxruntime.InferenceSession(str(yukarin_sosoa_model_dir.joinpath("yukarin_sosoa.onnx")))
session_hifi = onnxruntime.InferenceSession(str(hifigan_model_dir.joinpath("hifigan.onnx")))

def _dispatcher(
length: int,
@@ -20,9 +21,12 @@ def _dispatcher(
if speaker_id is not None:
speaker_id = np.asarray(speaker_id)
speaker_id = speaker_id.reshape((1,)).astype(np.int64)
return session.run(["wave"], {
spec = session_sosoa.run(["spec"], {
"f0": f0,
"phoneme": phoneme,
"speaker_id": speaker_id,
})[0]
return session_hifi.run(["wave"], {
"spec": spec,
})[0]
return _dispatcher
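One hedged aside: the GPU builds of onnxruntime (since 1.9) require the execution providers to be listed explicitly when constructing a session, while CPU builds pick a default silently. If session construction warns or fails, this variant pins the always-available CPU provider; the PR itself keeps the defaults:

```python
import onnxruntime

# Explicit provider selection; CPUExecutionProvider is always available.
session = onnxruntime.InferenceSession(
    "model/yukarin_sosoa/yukarin_sosoa.onnx",
    providers=["CPUExecutionProvider"],
)
```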
116 changes: 116 additions & 0 deletions vv_core_inference/surgeon.py
@@ -0,0 +1,116 @@
import numpy as np
import onnx
import onnx_graphsurgeon as gs

@gs.Graph.register()
def replace_ConvTranspose(self, node):
assert node.op == "ConvTranspose"
in_tensor, weight, bias = node.inputs
in_tensor.outputs.clear()
weight.outputs.clear()
bias.outputs.clear()
out_tensor = node.outputs[0]
out_tensor.inputs.clear()

kernel_size = node.attrs["kernel_shape"]

assert len(kernel_size) == 1, "only supports conv_transpose1d"
kernel_size = kernel_size[0]
groups = node.attrs["group"]
dilation = node.attrs["dilations"][0]
padding = node.attrs["pads"]
stride = node.attrs["strides"][0]

assert groups == 1
assert dilation == 1
assert padding[0] == padding[1]
padding = padding[0]

weight_numpy = weight.values
weight_numpy_conv = np.ascontiguousarray(weight_numpy.transpose(1,0,2)[:,:,::-1])

print("replace", node.name)
h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["expanded"], attrs={"axes": [3]})[0]
h2 = self.layer(op="Pad", inputs=[h1, [0, 0, 0, 0, 0, 0, 0, stride-1]], outputs=["pad_inner"])[0]
h3 = self.layer(op="Reshape", inputs=[h2, [0, 0, -1]], outputs=["unpooled"])[0]
h4 = self.layer(op="Pad", inputs=[h3, np.array([0, 0, kernel_size - padding - 1, 0, 0, kernel_size - padding - stride], np.int64)], outputs=["pad_outer"])[0]
return self.layer(op="Conv", inputs=[h4, weight_numpy_conv, bias.values], outputs=[out_tensor], attrs={
"dilations": [1],
"group": 1,
"kernel_shape": [kernel_size],
"pads": [0, 0],
"strides": [1]
})

@gs.Graph.register()
def replace_Conv(self, node):
# 1d -> 2d (webgl only supports conv2d)
assert node.op == "Conv"
in_tensor, weight, bias = node.inputs
in_tensor.outputs.clear()
weight.outputs.clear()
bias.outputs.clear()
out_tensor = node.outputs[0]
out_tensor.inputs.clear()

kernel_size = node.attrs["kernel_shape"]
assert len(kernel_size) == 1, "only supports conv1d"
kernel_size = kernel_size[0]
groups = node.attrs["group"]
dilation = node.attrs["dilations"][0]
padding = node.attrs["pads"]
stride = node.attrs["strides"][0]

print("replace", node.name)
h1 = self.layer(op="Unsqueeze", inputs=[in_tensor], outputs=["in_2d"], attrs={"axes": [3]})[0]
h2 = self.layer(op="Conv", inputs=[h1, weight.values[:, :, :, None], bias], outputs=["out_2d"], attrs={
"dilations": [dilation, dilation],
"group": groups,
"kernel_shape": [kernel_size, 1],
"pads": [padding[0], 0, padding[1], 0],
"strides": [stride, stride],
})[0]
return self.layer(op="Squeeze", inputs=[h2], outputs=[out_tensor], attrs={"axes": [3]})

def fold_unsqueeze(node):
if node.op != "Squeeze":
return
squeeze = node
axes = node.attrs["axes"]
if not (len(node.outputs[0].outputs) == 1 and node.o().op == "LeakyRelu"):
return
relu = node.o()
if not (len(relu.outputs[0].outputs) == 1 and relu.o().op == "Unsqueeze" and relu.o().attrs["axes"] == axes):
return
unsqueeze = relu.o()

in_node = squeeze.i()
in_node.outputs = squeeze.outputs
squeeze.outputs.clear()

relu.outputs = unsqueeze.outputs
unsqueeze.outputs.clear()
print("eliminate", node.name)


def surgeon(filename, outname):
graph = gs.import_onnx(onnx.load(filename))
# ConvTranspose -> Conv
targets = [node for node in graph.nodes if node.op == "ConvTranspose"]
for node in targets:
graph.replace_ConvTranspose(node)
graph.cleanup().toposort()
# Conv1d -> Conv2d
targets = [node for node in graph.nodes if node.op == "Conv"]
for node in targets:
graph.replace_Conv(node)
graph.cleanup().toposort()
# fold --Squeeze--LeakyRelu--Unsqueeze-- into --LeakyRelu--
targets = [node for node in graph.nodes if node.op == "Squeeze"]
for node in targets:
fold_unsqueeze(node)
graph.cleanup()
onnx.save(gs.export_onnx(graph), outname)

# surgeon("model/hifigan/hifigan.onnx", "model/hifigan/hifigan_modified.onnx")
# surgeon("model/hifigan/hifigan.onnx", "../vv_check_web/public/hifigan.onnx")