diff --git a/i6_models/losses/nce.py b/i6_models/losses/nce.py
index 355f0096..144b515c 100644
--- a/i6_models/losses/nce.py
+++ b/i6_models/losses/nce.py
@@ -6,7 +6,6 @@
 from torch import nn
 from torch.nn import functional as F
 from typing import Optional
-import math
 
 
 class NoiseContrastiveEstimationLossV1(nn.Module):
@@ -53,10 +52,10 @@ def forward(self, data: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             samples = self.noise_distribution_sampler.sample(self.num_samples).cuda()
 
             # log-probabilities for the noise distribution k * q(w|h)
-            sampled_prob = math.log(self.num_samples) + self.noise_distribution_sampler.log_prob(
-                samples
-            )  # [num_samples]
-            true_sample_prob = math.log(self.num_samples) + self.noise_distribution_sampler.log_prob(target)  # [B x T]
+            # log(k) as a 0-dim CPU tensor: combines with log-probs on any device (a 1-element CPU tensor does not)
+            ws = torch.log(torch.tensor(float(self.num_samples)))
+            sampled_prob = ws + self.noise_distribution_sampler.log_prob(samples)  # [num_samples]
+            true_sample_prob = ws + self.noise_distribution_sampler.log_prob(target)  # [B x T]
 
             all_classes = torch.cat((target, samples), 0)  # [B x T + num_sampled]
 