Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
junjun3518 committed Apr 25, 2023
1 parent d940c48 commit 07d1c74
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions phaseaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def stft_rot_istft(self, x, phi):
# phi: [B,nfft//2+1]
# also possible for x :[B,C,T] but we did not generalize it.
def forward(self, x, phi=None):
B = x.shape[0]
x = x.squeeze(1) #[B,t]
if phi is None:
phi = self.sample_phi(self, X.shape[0])
phi = self.sample_phi(self, B)
phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
phi = phi.unsqueeze(-1) #[B,F,1]
x_aug = self.stft_rot_istft(self, x, phi)
Expand All @@ -114,7 +115,7 @@ def forward_sync(self, x, x_hat, phi=None):
B = x.shape[0]
x = torch.cat([x, x_hat], dim=0) #[2B,1,t]
if phi is None:
phi = self.sample_phi(self, X.shape[0] // 2) #[2B, nfft//2+1]
phi = self.sample_phi(self, B) #[2B, nfft//2+1]
phi = torch.cat([phi, phi], dim=0)
x_augs = self.forward(x, phi).split(B, dim=0)
return x_augs
5 changes: 3 additions & 2 deletions phaseaug/phaseaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,10 @@ def stft_rot_istft(self, x, phi):
# phi: [B,nfft//2+1]
# also possible for x :[B,C,T] but we did not generalize it.
def forward(self, x, phi=None):
B = x.shape[0]
x = x.squeeze(1) #[B,t]
if phi is None:
phi = self.sample_phi(self, X.shape[0])
phi = self.sample_phi(self, B)
phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
phi = phi.unsqueeze(-1) #[B,F,1]
x_aug = self.stft_rot_istft(self, x, phi)
Expand All @@ -114,7 +115,7 @@ def forward_sync(self, x, x_hat, phi=None):
B = x.shape[0]
x = torch.cat([x, x_hat], dim=0) #[2B,1,t]
if phi is None:
phi = self.sample_phi(self, X.shape[0] // 2) #[2B, nfft//2+1]
phi = self.sample_phi(self, B) #[2B, nfft//2+1]
phi = torch.cat([phi, phi], dim=0)
x_augs = self.forward(x, phi).split(B, dim=0)
return x_augs
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name = 'phaseaug',
version = '1.0.0',
version = '1.0.1',
description = 'PhaseAug: A Differentiable Augmentation for Speech Synthesis to Simulate One-to-Many Mapping',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 07d1c74

Please sign in to comment.