From 03faf8f723556a074703ca914ceaa28f92289f81 Mon Sep 17 00:00:00 2001 From: Avik Basu Date: Tue, 16 Jan 2024 19:00:25 -0800 Subject: [PATCH] fix: vae nsamples Signed-off-by: Avik Basu --- numalogic/models/vae/variants/conv.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/numalogic/models/vae/variants/conv.py b/numalogic/models/vae/variants/conv.py index 62d7ad9a..4bf54a22 100644 --- a/numalogic/models/vae/variants/conv.py +++ b/numalogic/models/vae/variants/conv.py @@ -202,8 +202,7 @@ def forward(self, x: Tensor) -> tuple[MultivariateNormal, Tensor]: x = self.configure_shape(x) z_mu, z_logvar = self.encoder(x) p = MultivariateNormal(loc=z_mu, covariance_matrix=torch.diag_embed(z_logvar.exp())) - samples = p.rsample(sample_shape=torch.Size([self.nsamples])) - z = torch.mean(samples, dim=0) + z = p.rsample() x_recon = self.decoder(z) return p, x_recon