My version of the beta-VAE model, implemented with PyTorch rather than Keras
DanPetrovich committed May 14, 2024
1 parent 94887ee commit 43c12ef
Showing 2 changed files with 110 additions and 0 deletions.
55 changes: 55 additions & 0 deletions beta_vae_torch.py
@@ -0,0 +1,55 @@
import torch
import torch.nn as nn


class BetaVAE(nn.Module):
    def __init__(self, input_dim, latent_dim, beta=1.0):
        super(BetaVAE, self).__init__()
        self.beta = beta  # weight on the KL term; beta = 1 recovers a standard VAE

# Encoder
self.encoder = nn.Sequential(
nn.Linear(input_dim, 512),
nn.ReLU(),
nn.Linear(512, 256),
nn.ReLU()
)

        # Heads producing the mean and log-variance of q(z|x)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

# Decoder
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, input_dim),
nn.Sigmoid()
)

def encode(self, x):
h = self.encoder(x)
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar

def decode(self, z):
return self.decoder(z)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps the sampling step differentiable w.r.t. mu and logvar.
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        # Encode, sample a latent code, and decode the reconstruction.
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


def loss_function(recon_x, x, mu, logvar, beta):
    # Reconstruction term: binary cross-entropy summed over all elements.
    BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between q(z|x) = N(mu, sigma^2) and the N(0, I)
    # prior: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta > 1 up-weights the KL term, giving the beta-VAE objective.
    return BCE + beta * KLD
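
A minimal training sketch (not part of the commit), assuming inputs already scaled to [0, 1] to match the Sigmoid output and BCE loss; `data_loader`, the Adam optimizer choice, and all hyperparameters here are illustrative assumptions, not part of the committed code.

import torch
from beta_vae_torch import BetaVAE, loss_function

# Illustrative sizes; input_dim must match the flattened feature dimension.
model = BetaVAE(input_dim=784, latent_dim=32, beta=4.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for x in data_loader:  # hypothetical DataLoader yielding batches of shape (B, 784) in [0, 1]
    recon_x, mu, logvar = model(x)
    loss = loss_function(recon_x, x, mu, logvar, model.beta)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()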

55 changes: 55 additions & 0 deletions pyod/models/beta_vae_torch.py
@@ -0,0 +1,55 @@
(contents are a verbatim duplicate of beta_vae_torch.py above)
