Skip to content

Commit

Permalink
update excite
Browse files Browse the repository at this point in the history
  • Loading branch information
takenori-y committed Nov 14, 2024
1 parent a138358 commit 8aa14ea
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
6 changes: 6 additions & 0 deletions diffsptk/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,11 @@ def entropy(p, out_format="nat"):
def excite(
p,
frame_period=80,
*,
voiced_region="pulse",
unvoiced_region="gauss",
polarity="auto",
init_phase="zeros",
):
"""Generate a simple excitation signal.
Expand All @@ -539,6 +541,9 @@ def excite(
polarity : ['auto', 'unipolar', 'bipolar']
Polarity.
init_phase : ['zeros', 'random']
Initial phase.
Returns
-------
out : Tensor [shape=(..., NxP)]
Expand All @@ -551,6 +556,7 @@ def excite(
voiced_region=voiced_region,
unvoiced_region=unvoiced_region,
polarity=polarity,
init_phase=init_phase,
)


Expand Down
28 changes: 21 additions & 7 deletions diffsptk/modules/excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,25 @@ class ExcitationGeneration(nn.Module):
'square']
Value on voiced region.
unvoiced_region : ['gauss', 'zeros']
unvoiced_region : ['zeros', 'gauss']
Value on unvoiced region.
polarity : ['auto', 'unipolar', 'bipolar']
Polarity.
init_phase : ['zeros', 'random']
Initial phase.
"""

def __init__(
self,
frame_period,
*,
voiced_region="pulse",
unvoiced_region="gauss",
polarity="auto",
init_phase="zeros",
):
super().__init__()

Expand All @@ -62,13 +67,15 @@ def __init__(
"triangle",
"square",
)
assert unvoiced_region in ("gauss", "zeros")
assert unvoiced_region in ("zeros", "gauss")
assert polarity in ("auto", "unipolar", "bipolar")
assert init_phase in ("zeros", "random")

self.frame_period = frame_period
self.voiced_region = voiced_region
self.unvoiced_region = unvoiced_region
self.polarity = polarity
self.init_phase = init_phase

def forward(self, p):
"""Generate a simple excitation signal.
Expand Down Expand Up @@ -98,10 +105,11 @@ def forward(self, p):
self.voiced_region,
self.unvoiced_region,
self.polarity,
self.init_phase,
)

@staticmethod
def _forward(p, frame_period, voiced_region, unvoiced_region, polarity):
def _forward(p, frame_period, voiced_region, unvoiced_region, polarity, init_phase):
# Make a mask that represents the voiced region.
base_mask = torch.clip(p, min=0, max=1)
mask = torch.ne(base_mask, UNVOICED_SYMBOL)
Expand All @@ -127,6 +135,12 @@ def _forward(p, frame_period, voiced_region, unvoiced_region, polarity):
s = torch.cumsum(q.double(), dim=-1)
bias, _ = torch.cummax(s * ~mask, dim=-1)
phase = (s - bias).to(p.dtype)
if init_phase == "zeros":
pass
elif init_phase == "random":
phase += torch.rand_like(p[..., :1])
else:
raise ValueError(f"init_phase {init_phase} is not supported.")

# Generate excitation signal using phase.
if polarity == "auto":
Expand Down Expand Up @@ -168,7 +182,7 @@ def get_pulse_pos(p):
if unipolar:
e[mask] = torch.abs(2 * torch.fmod(phase[mask] + 0.5, 1) - 1)
else:
e[mask] = 2 * torch.abs(2 * torch.fmod(phase[mask] + 1.75, 1) - 1) - 1
e[mask] = 2 * torch.abs(2 * torch.fmod(phase[mask] + 0.75, 1) - 1) - 1
elif voiced_region == "square":
if unipolar:
e[mask] = torch.le(torch.fmod(phase[mask], 1), 0.5).to(e.dtype)
Expand All @@ -177,10 +191,10 @@ def get_pulse_pos(p):
else:
raise ValueError(f"voiced_region {voiced_region} is not supported.")

if unvoiced_region == "gauss":
e[~mask] = torch.randn(torch.sum(~mask), device=e.device)
elif unvoiced_region == "zeros":
if unvoiced_region == "zeros":
pass
elif unvoiced_region == "gauss":
e[~mask] = torch.randn(torch.sum(~mask), device=e.device)
else:
raise ValueError(f"unvoiced_region {unvoiced_region} is not supported.")
return e
Expand Down
11 changes: 8 additions & 3 deletions tests/test_excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,14 @@ def compute_error(infile):
["pulse", "sinusoidal", "sawtooth", "inverted-sawtooth", "triangle", "square"],
)
@pytest.mark.parametrize("polarity", ["unipolar", "bipolar"])
def test_waveform(voiced_region, polarity, P=80, verbose=False):
@pytest.mark.parametrize("init_phase", ["zeros", "random"])
def test_waveform(voiced_region, polarity, init_phase, P=80, verbose=False):
excite = diffsptk.ExcitationGeneration(
P, voiced_region=voiced_region, unvoiced_region="zeros", polarity=polarity
P,
voiced_region=voiced_region,
unvoiced_region="zeros",
polarity=polarity,
init_phase=init_phase,
)
pitch = torch.from_numpy(
U.call(f"x2x +sd tools/SPTK/asset/data.short | pitch -s 16 -p {P} -o 0 -a 2")
Expand All @@ -101,4 +106,4 @@ def test_waveform(voiced_region, polarity, P=80, verbose=False):
if voiced_region == "pulse":
e = e / e.abs().max()
if verbose:
sf.write(f"excite_{voiced_region}_{polarity}.wav", e, 16000)
sf.write(f"excite_{voiced_region}_{polarity}_{init_phase}.wav", e, 16000)

0 comments on commit 8aa14ea

Please sign in to comment.