-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 6c223f7
Showing
4 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
wandb | ||
data | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
import torchvision | ||
from torch.utils.data import Dataset | ||
from torchvision import transforms | ||
|
||
|
||
class ShuffledCIFAR10(Dataset): | ||
def __init__(self, train=True): | ||
self.dataset = torchvision.datasets.CIFAR10( | ||
root="./data", train=train, download=True, transform=transforms.ToTensor() | ||
) | ||
self.permutation = torch.randperm(32 * 32) | ||
|
||
def __len__(self): | ||
return len(self.dataset) | ||
|
||
def __getitem__(self, idx): | ||
img, label = self.dataset[idx] | ||
img_flat = img.view(-1, 32 * 32) | ||
shuffled_img_flat = img_flat[:, self.permutation] | ||
shuffled_img = shuffled_img_flat.view_as(img) | ||
return shuffled_img, img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torchvision.models.resnet import BasicBlock, ResNet | ||
|
||
|
||
def vae_loss(recon_x, x, mu, logvar): | ||
# Reconstruction loss (assuming Bernoulli distribution) | ||
recon_loss = F.binary_cross_entropy(recon_x, x, reduction="sum") | ||
# KL divergence | ||
kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) | ||
return recon_loss + kl_div | ||
|
||
|
||
class UnFlatten(nn.Module): | ||
def forward(self, input): | ||
return input.view(input.size(0), 512, 1, 1) | ||
|
||
|
||
class VAE(nn.Module): | ||
def __init__(self, image_channels=3, latent_dim=128): | ||
super(VAE, self).__init__() | ||
# Encoder (ResNet-18) | ||
self.encoder = nn.Sequential( | ||
*list(ResNet(BasicBlock, [2, 2, 2, 2]).children())[:-1], nn.Flatten() | ||
) | ||
self.fc_mu = nn.Linear(512, latent_dim) | ||
self.fc_logvar = nn.Linear(512, latent_dim) | ||
|
||
# Decoder | ||
self.decoder_input = nn.Linear(latent_dim, 512) | ||
self.decoder = nn.Sequential( | ||
UnFlatten(), # Output: 512x1x1 | ||
nn.ConvTranspose2d( | ||
512, 256, kernel_size=4, stride=2, padding=1 | ||
), # Output: 256x2x2 | ||
nn.ReLU(), | ||
nn.ConvTranspose2d( | ||
256, 128, kernel_size=4, stride=2, padding=1 | ||
), # Output: 128x4x4 | ||
nn.ReLU(), | ||
nn.ConvTranspose2d( | ||
128, 64, kernel_size=4, stride=2, padding=1 | ||
), # Output: 64x8x8 | ||
nn.ReLU(), | ||
nn.ConvTranspose2d( | ||
64, 32, kernel_size=4, stride=2, padding=1 | ||
), # Output: 32x16x16 | ||
nn.ReLU(), | ||
nn.ConvTranspose2d( | ||
32, image_channels, kernel_size=4, stride=2, padding=1 | ||
), # Output: 3x32x32 | ||
nn.Sigmoid(), | ||
) | ||
|
||
def reparameterize(self, mu, logvar): | ||
std = torch.exp(0.5 * logvar) | ||
eps = torch.randn_like(std) | ||
return mu + eps * std | ||
|
||
def forward(self, x): | ||
# Encode | ||
x_encoded = self.encoder(x) | ||
mu = self.fc_mu(x_encoded) | ||
logvar = self.fc_logvar(x_encoded) | ||
z = self.reparameterize(mu, logvar) | ||
# Decode | ||
x_reconstructed = self.decoder(self.decoder_input(z)) | ||
return x_reconstructed, mu, logvar |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
from torch import optim | ||
from torch.utils.data import DataLoader | ||
from tqdm import tqdm | ||
|
||
import wandb | ||
from data import ShuffledCIFAR10 | ||
from model import VAE, vae_loss | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
EPOCHS = 100000 # can we have grokking kinda effect with this insane number??? | ||
BATCH_SIZE = 64 | ||
train_dataset = ShuffledCIFAR10(train=True) | ||
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) | ||
|
||
val_dataset = ShuffledCIFAR10(train=False) | ||
val_loader = DataLoader( | ||
val_dataset, batch_size=BATCH_SIZE, shuffle=False | ||
) # No need to shuffle the validation dataset | ||
|
||
|
||
model = VAE().to(device) | ||
optimizer = optim.Adam(model.parameters(), lr=0.001) | ||
|
||
|
||
def show_images(shuffled, original, reconstructed): | ||
fig, axs = plt.subplots(1, 3, figsize=(9, 3)) | ||
for ax, img, title in zip( | ||
axs, | ||
[shuffled, original, reconstructed], | ||
["Shuffled", "Original", "Reconstructed"], | ||
): | ||
ax.imshow(np.transpose(img.numpy(), (1, 2, 0))) | ||
ax.set_title(title) | ||
ax.axis("off") | ||
plt.show() | ||
|
||
|
||
wandb.init( | ||
project="image_reconstruction_vae", | ||
config={ | ||
"epochs": EPOCHS, | ||
"batch_size": BATCH_SIZE, | ||
"image_channels": 3, | ||
"CUDA_LAUNCH_BLOCKING=1": True, | ||
}, | ||
) | ||
|
||
for epoch in tqdm(range(EPOCHS)): | ||
model.train() | ||
train_loss = 0 | ||
for data, target in train_loader: | ||
data, target = data.to(device), target.to(device) | ||
optimizer.zero_grad() | ||
recon_batch, mu, logvar = model(data) | ||
loss = vae_loss(recon_batch, target, mu, logvar) | ||
loss.backward() | ||
train_loss += loss.item() | ||
optimizer.step() | ||
|
||
avg_train_loss = train_loss / len(train_loader) | ||
print(f"Epoch {epoch}, Training Loss: {avg_train_loss}") | ||
wandb.log({"epoch": epoch, "train_loss": avg_train_loss}) | ||
|
||
model.eval() | ||
val_loss = 0 | ||
|
||
with torch.no_grad(): | ||
for data, target in val_loader: | ||
data, target = data.to(device), target.to(device) | ||
recon_batch, _, _ = model(data) | ||
|
||
loss = vae_loss(recon_batch, target, mu, logvar) | ||
val_loss += loss.item() | ||
avg_val_loss = val_loss / len(val_loader) | ||
wandb.log({"epoch": epoch, "val_loss": avg_val_loss}) | ||
|
||
shuffled_img, original_img, reconstructed_img = ( | ||
data[0].cpu(), | ||
target[0].cpu(), | ||
recon_batch[0].cpu(), | ||
) | ||
# print(len(data)) | ||
wandb.log( | ||
{ | ||
"reconstructed_images": [ | ||
wandb.Image(recon_batch[i].cpu(), caption="Reconstructed Image") | ||
for i in range(5) | ||
], | ||
"original_images": [ | ||
wandb.Image(target[i].cpu(), caption="Original Image") for i in range(5) | ||
], | ||
"shuffled_images": [ | ||
wandb.Image(data[i].cpu(), caption="Shuffled Image") for i in range(5) | ||
], | ||
} | ||
) |