diff --git a/phaseaug.py b/phaseaug.py
index bd7a42d..309750a 100644
--- a/phaseaug.py
+++ b/phaseaug.py
@@ -100,9 +100,10 @@ def stft_rot_istft(self, x, phi):
     # phi: [B,nfft//2+1]
     # also possible for x :[B,C,T] but we did not generalize it.
     def forward(self, x, phi=None):
+        B = x.shape[0]
         x = x.squeeze(1) #[B,t]
         if phi is None:
-            phi = self.sample_phi(self, X.shape[0])
+            phi = self.sample_phi(self, B)
         phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
         phi = phi.unsqueeze(-1) #[B,F,1]
         x_aug = self.stft_rot_istft(self, x, phi)
@@ -114,7 +115,7 @@ def forward_sync(self, x, x_hat, phi=None):
         B = x.shape[0]
         x = torch.cat([x, x_hat], dim=0) #[2B,1,t]
         if phi is None:
-            phi = self.sample_phi(self, X.shape[0] // 2) #[2B, nfft//2+1]
+            phi = self.sample_phi(self, B) #[2B, nfft//2+1]
         phi = torch.cat([phi, phi], dim=0)
         x_augs = self.forward(x, phi).split(B, dim=0)
         return x_augs
diff --git a/phaseaug/phaseaug.py b/phaseaug/phaseaug.py
index bd7a42d..309750a 100644
--- a/phaseaug/phaseaug.py
+++ b/phaseaug/phaseaug.py
@@ -100,9 +100,10 @@ def stft_rot_istft(self, x, phi):
     # phi: [B,nfft//2+1]
     # also possible for x :[B,C,T] but we did not generalize it.
     def forward(self, x, phi=None):
+        B = x.shape[0]
         x = x.squeeze(1) #[B,t]
         if phi is None:
-            phi = self.sample_phi(self, X.shape[0])
+            phi = self.sample_phi(self, B)
         phi[:, 0] = 0. # we are multiplying phi_ref to mu, so it is always zero in our scheme
         phi = phi.unsqueeze(-1) #[B,F,1]
         x_aug = self.stft_rot_istft(self, x, phi)
@@ -114,7 +115,7 @@ def forward_sync(self, x, x_hat, phi=None):
         B = x.shape[0]
         x = torch.cat([x, x_hat], dim=0) #[2B,1,t]
         if phi is None:
-            phi = self.sample_phi(self, X.shape[0] // 2) #[2B, nfft//2+1]
+            phi = self.sample_phi(self, B) #[2B, nfft//2+1]
         phi = torch.cat([phi, phi], dim=0)
         x_augs = self.forward(x, phi).split(B, dim=0)
         return x_augs
diff --git a/setup.py b/setup.py
index a7e517e..5dc85f4 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 
 setup(
     name = 'phaseaug',
-    version = '1.0.0',
+    version = '1.0.1',
     description = 'PhaseAug: A Differentiable Augmentation for Speech Synthesis to Simulate One-to-Many Mapping',
     long_description=long_description,
     long_description_content_type="text/markdown",