Skip to content

Commit

Permalink
add: latest code
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadmustafaanis committed Apr 4, 2024
1 parent 6c223f7 commit 6372e66
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 12 deletions.
17 changes: 14 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,20 @@

def vae_loss(recon_x, x, mu, logvar):
# Reconstruction loss (assuming Bernoulli distribution)
recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")
# KL divergence
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
try:
recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum")
# KL divergence
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
except Exception as E:
print(E)
try:
recon_loss = F.mse_loss(recon_x, x, reduction="sum")
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
except Exception as E:
print(E)
# set loss to a default value
recon_loss = torch.tensor(0.1).to(recon_x.device)
kl_div = torch.tensor(0.1).to(recon_x.device)
return recon_loss + kl_div


Expand Down
20 changes: 11 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,17 @@ def show_images(shuffled, original, reconstructed):
model.eval()
val_loss = 0

with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
recon_batch, _, _ = model(data)

loss = vae_loss(recon_batch, target, mu, logvar)
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)
wandb.log({"epoch": epoch, "val_loss": avg_val_loss})
# with torch.no_grad():
# for data, target in val_loader:
# data, target = data.to(device), target.to(device)
# recon_batch, mu, logvar = model(data)
# try:
# loss = vae_loss(recon_batch, target, mu, logvar)
# val_loss += loss.item()
# except Exception:
# continue
# avg_val_loss = val_loss / len(val_loader)
# wandb.log({"epoch": epoch, "val_loss": avg_val_loss})

shuffled_img, original_img, reconstructed_img = (
data[0].cpu(),
Expand Down

0 comments on commit 6372e66

Please sign in to comment.