feat(train): support ipu #264

Draft · wants to merge 22 commits into base: main
Changes from all commits
21 changes: 4 additions & 17 deletions .idea/workspace.xml

Some generated files are not rendered by default.

20 changes: 9 additions & 11 deletions src/so_vits_svc_fork/f0.py
@@ -29,8 +29,6 @@ def normalize_f0(
factor = torch.ones(f0.shape[0], 1).to(f0.device)
# normalize f0 based on means and factor
f0_norm = (f0 - means.unsqueeze(-1)) * factor.unsqueeze(-1)
if torch.isnan(f0_norm).any():
exit(0)
return f0_norm * x_mask
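
Note: this hunk drops the NaN guard that called exit(0) from inside the model, which killed the whole training process on the first bad batch; host-side control flow on tensor values also tends not to survive graph compilation on accelerators such as the IPU. If a guard is still wanted, a non-fatal alternative could look like this (a sketch, not part of this PR; safe_normalize is a hypothetical helper):

import torch

def safe_normalize(f0_norm: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
    # Scrub NaNs from the normalized f0 instead of aborting the process.
    return torch.nan_to_num(f0_norm, nan=0.0) * x_mask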


@@ -218,17 +216,17 @@ def compute_f0(
def f0_to_coarse(f0: torch.Tensor | float):
is_torch = isinstance(f0, torch.Tensor)
f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (
f0_mel_max - f0_mel_min
) + 1
# f0_mel[f0_mel > 0] = ...
f0_mel = (f0_mel - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1

f0_mel[f0_mel <= 1] = 1
f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
# f0_mel[f0_mel <= 1] = 1
# f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
f0_mel = torch.clamp(f0_mel, 1, f0_bin - 1)
f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64)  # np.int was removed in NumPy 1.24; np.int64 keeps the old behavior
assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
f0_coarse.max(),
f0_coarse.min(),
)
# assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (
# f0_coarse.max(),
# f0_coarse.min(),
# )
return f0_coarse
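
Note: the masked in-place assignments and the runtime assert are replaced by one unconditional affine map plus torch.clamp. A plausible motivation (the diff does not say) is that data-dependent indexing and asserts are hostile to ahead-of-time-compiled backends like the IPU, while the clamp form is branch-free. A quick equivalence sketch of my own, assuming the usual f0_min = 50 Hz / f0_max = 1100 Hz constants from this codebase:

import numpy as np
import torch

f0_bin = 256
f0_mel_min = 1127 * np.log(1 + 50 / 700)
f0_mel_max = 1127 * np.log(1 + 1100 / 700)

def coarse_old(f0: torch.Tensor) -> torch.Tensor:
    # Pre-PR version: two masked writes plus a masked affine map.
    f0_mel = 1127 * (1 + f0 / 700).log()
    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
    f0_mel[f0_mel <= 1] = 1
    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
    return (f0_mel + 0.5).long()

def coarse_new(f0: torch.Tensor) -> torch.Tensor:
    # PR version: branch-free affine map plus clamp; unvoiced frames
    # (f0 == 0) map below 1 and are clamped to bin 1, exactly as before.
    f0_mel = 1127 * (1 + f0 / 700).log()
    f0_mel = (f0_mel - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
    f0_mel = torch.clamp(f0_mel, 1, f0_bin - 1)
    return (f0_mel + 0.5).long()

f0 = torch.tensor([0.0, 55.0, 220.0, 440.0, 1100.0])
assert torch.equal(coarse_old(f0), coarse_new(f0))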


166 changes: 41 additions & 125 deletions src/so_vits_svc_fork/modules/commons.py
@@ -1,27 +1,49 @@
import math

import torch
from torch.nn import functional as F
from torch import Tensor


def slice_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
x_slice = torch.zeros((x.size()[:-1] + (length,)), dtype=x.dtype, device=x.device)
ends = starts + length
for i, (start, end) in enumerate(zip(starts, ends)):
x_slice[i, ...] = x[i, ..., start:end]
return x_slice


def slice_pitch_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, idx_str:idx_end]
return ret
def slice_2d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
batch_size, num_features, seq_len = x.shape
ends = starts + length
idxs = (
torch.arange(seq_len)
.unsqueeze(0)
.unsqueeze(1)
.repeat(batch_size, num_features, 1)
)
mask = (idxs >= starts.unsqueeze(-1).unsqueeze(-1)) & (
idxs < ends.unsqueeze(-1).unsqueeze(-1)
)
return x[mask].reshape(batch_size, num_features, length)
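
Note: slice_2d_segments replaces the per-sample Python loop with a fixed sequence of tensor ops, which keeps the traced graph static for compile-once backends; one caveat is that torch.arange here lands on the default device, so non-CPU inputs may need device=x.device as in _slice_segments_v3 below. A usage sketch of my own against a loop-based reference, with slice_2d_segments as defined above:

import torch

x = torch.arange(2 * 3 * 10, dtype=torch.float).reshape(2, 3, 10)
starts = torch.tensor([2, 5])
length = 4

# Loop-based reference, i.e. what the removed slice-by-loop code computed.
ref = torch.stack([x[i, :, s : s + length] for i, s in enumerate(starts)])
assert torch.equal(slice_2d_segments(x, starts, length), ref)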


def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
return ret, ret_pitch, ids_str
def slice_1d_segments(x: Tensor, starts: Tensor, length: int) -> Tensor:
batch_size, seq_len = x.shape
ends = starts + length
idxs = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
mask = (idxs >= starts.unsqueeze(-1)) & (idxs < ends.unsqueeze(-1))
return x[mask].reshape(batch_size, length)
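
A matching sketch (mine) for the 1-D case, e.g. per-sample pitch windows, with slice_1d_segments as defined above:

import torch

pitch = torch.arange(2 * 10, dtype=torch.float).reshape(2, 10)
starts = torch.tensor([1, 6])
out = slice_1d_segments(pitch, starts, length=3)
assert torch.equal(out[0], pitch[0, 1:4]) and torch.equal(out[1], pitch[1, 6:9])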


def _slice_segments_v3(x: Tensor, starts: Tensor, length: int) -> Tensor:
shape = x.shape[:-1] + (length,)
ends = starts + length
idxs = torch.arange(x.shape[-1], device=x.device).reshape((1,) * (x.dim() - 1) + (-1,))
# add trailing singleton dims so starts/ends broadcast against idxs
unsqueeze_dims = x.dim() - starts.dim()
starts = starts.reshape(starts.shape + (1,) * unsqueeze_dims)
ends = ends.reshape(ends.shape + (1,) * unsqueeze_dims)
mask = ((idxs >= starts) & (idxs < ends)).expand(*x.shape)
return x[mask].reshape(shape)
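
Note: _slice_segments_v3 folds the 1-D and 2-D variants into one implementation: the per-sample bounds get trailing singleton dims and broadcast against a shared time-index row, so any (batch, ..., time) layout works. A shape-level sketch (mine), with _slice_segments_v3 as defined above:

import torch

x3 = torch.randn(2, 4, 16)  # (batch, features, time)
x2 = torch.randn(2, 16)     # (batch, time)
starts = torch.tensor([3, 9])
assert _slice_segments_v3(x3, starts, 5).shape == (2, 4, 5)
assert _slice_segments_v3(x2, starts, 5).shape == (2, 5)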


def init_weights(m, mean=0.0, std=0.01):
@@ -40,89 +62,6 @@ def convert_pad_shape(pad_shape):
return pad_shape


def intersperse(lst, item):
result = [item] * (len(lst) * 2 + 1)
result[1::2] = lst
return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
return kl


def rand_gumbel(shape):
"""Sample from the Gumbel distribution, protect from overflows."""
uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
return g


def slice_segments(x, ids_str, segment_size=4):
ret = torch.zeros_like(x[:, :, :segment_size])
for i in range(x.size(0)):
idx_str = ids_str[i]
idx_end = idx_str + segment_size
ret[i] = x[i, :, idx_str:idx_end]
return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size + 1
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str


def rand_spec_segments(x, x_lengths=None, segment_size=4):
b, d, t = x.size()
if x_lengths is None:
x_lengths = t
ids_str_max = x_lengths - segment_size
ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
ret = slice_segments(x, ids_str, segment_size)
return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
signal = F.pad(signal, [0, 0, 0, channels % 2])
signal = signal.view(1, channels, length)
return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
b, channels, length = x.size()
signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
return mask
@@ -138,36 +77,13 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
return acts
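
For context, fused_add_tanh_sigmoid_multiply is the WaveNet-style gated activation used throughout VITS-derived code; the hunk above only shows its tail. As commonly defined in such codebases (a paraphrase of mine, not this repo's exact code):

import torch

def fused_add_tanh_sigmoid_multiply_ref(
    input_a: torch.Tensor, input_b: torch.Tensor, n_channels: int
) -> torch.Tensor:
    # Gated activation on (batch, channels, time): tanh on the first
    # n_channels of the summed inputs, sigmoid on the rest, multiplied.
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels, :])
    s_act = torch.sigmoid(in_act[:, n_channels:, :])
    return t_act * s_act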


def shift_1d(x):
x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
return x


def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
"""
duration: [b, 1, t_x]
mask: [b, 1, t_y, t_x]
"""
duration.device

b, _, t_y, t_x = mask.shape
cum_duration = torch.cumsum(duration, -1)

cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path.unsqueeze(1).transpose(2, 3) * mask
return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
11 changes: 6 additions & 5 deletions src/so_vits_svc_fork/modules/mel_processing.py
@@ -1,4 +1,4 @@
"""from logging import getLogger
from logging import getLogger

import torch
import torch.utils.data
@@ -42,7 +42,8 @@ def mel_spectrogram_torch(audio: torch.Tensor, hps: HParams) -> torch.Tensor:
power=1.0,
window_fn=torch.hann_window,
normalized=False,
).to(audio.device)(audio)"""
).to(audio.device)(audio)
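
Note: this hunk removes the docstring quotes that had the torchaudio path commented out, making torchaudio.transforms.MelSpectrogram the live implementation of mel_spectrogram_torch. A standalone sketch of the same pattern (parameter values here are illustrative, not the project's config):

import torch
import torchaudio

mel_fn = torchaudio.transforms.MelSpectrogram(
    sample_rate=44100,
    n_fft=2048,
    win_length=2048,
    hop_length=512,
    f_min=0.0,
    f_max=22050.0,
    n_mels=80,
    power=1.0,
    window_fn=torch.hann_window,
    normalized=False,
)
audio = torch.randn(1, 44100)  # one second of dummy audio
mel = mel_fn.to(audio.device)(audio)  # (1, n_mels, frames)

Constructing the transform once and reusing it avoids rebuilding the mel filterbank on every call, a cost the inline construction in the hunk pays each time.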


from logging import getLogger

@@ -87,7 +88,7 @@ def spectral_de_normalize_torch(magnitudes):
hann_window = {}


def spectrogram_torch(y, hps, center=False):
def spectrogram_torch_old(y, hps, center=False):
if torch.min(y) < -1.0:
LOG.info("min value is ", torch.min(y))
if torch.max(y) > 1.0:
@@ -127,7 +128,7 @@ def spectrogram_torch(y, hps, center=False):
return spec


def spec_to_mel_torch(spec, hps):
def spec_to_mel_torch_old(spec, hps):
sampling_rate = hps.data.sampling_rate
n_fft = hps.data.filter_length
num_mels = hps.data.n_mel_channels
@@ -148,7 +149,7 @@ def spec_to_mel_torch(spec, hps):
return spec


def mel_spectrogram_torch(y, hps, center=False):
def mel_spectrogram_torch_old(y, hps, center=False):
sampling_rate = hps.data.sampling_rate
n_fft = hps.data.filter_length
num_mels = hps.data.n_mel_channels