Skip to content

Commit

Permalink
Merge pull request #55 from utkarshp1161/main
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax authored Jan 29, 2024
2 parents ce84746 + ce0b58d commit 7807ffb
Show file tree
Hide file tree
Showing 7 changed files with 660 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/

*tar.gz
533 changes: 533 additions & 0 deletions examples/VAE_gp.ipynb

Large diffs are not rendered by default.

61 changes: 58 additions & 3 deletions pyroved/models/ivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
"""

from typing import Optional, Tuple, Union, List

import pyro
import pyro.distributions as dist
import torch

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pyroved.models.base import baseVAE
from pyroved.nets import fcDecoderNet, fcEncoderNet, sDecoderNet
from pyroved.utils import (
generate_grid, generate_latent_grid, get_sampler,
plot_img_grid, plot_spect_grid, set_deterministic_mode,
to_onehot, transform_coordinates
to_onehot, transform_coordinates, gp_model
)


Expand Down Expand Up @@ -307,3 +308,57 @@ def manifold2d(self, d: int,
elif self.ndim == 1:
plot_spect_grid(loc, d, **kwargs)
return loc

def predict_on_latent(self, train_data: torch.Tensor, gp_labels: torch.Tensor, gp_iterations: int = 1, d: int = 12, plot: bool = False):
"""
Predicts on the latent grid using a trained GP
Args:
train_data: Training data used to train the VAE
gp_labels: Labels for training data
gp_iterations: Number of iterations for GP training
d: Grid size
Returns:
z: Latent grid
z_decoded: Decoded latent grid
predictions: Predictions on the latent grid
"""
# Convert X and y to torch tensors
X = torch.tensor(train_data, dtype=torch.float32)
y = torch.tensor(gp_labels, dtype=torch.float32)

# Use VAE's encoder to transform X into the latent space
encoded_X = self.encode(X)[0] # Assuming the encoder returns mean as the first element

gpr = gp_model(input_dim=encoded_X.shape[1], encoded_X=encoded_X, y=y, gp_iterations=gp_iterations)


# Generate the latent grid
z, (grid_x, grid_y) = generate_latent_grid(d)
z = torch.tensor(z, dtype=torch.float32)

# Predict on the latent grid using the trained GP
gpr.eval()
with torch.no_grad():
predictions, _ = gpr(z)
x, y = np.array(z).T
z_decoded = self.manifold2d(d, plot=False)
if plot:
self.manifold2d(d=d, cmap='viridis')
# Plot the second figure in the second subplot
plt.figure(figsize=(8, 8))
predictions_reshaped = predictions.reshape(d, d)

# Plot the 2D array using imshow
plt.figure(figsize=(8, 8))
heatmap = plt.imshow(predictions_reshaped, cmap='viridis', aspect='auto')
plt.colorbar(heatmap, label='Prediction Value')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.xlabel("$z_1$", fontsize=14)
plt.ylabel("$z_2$", fontsize=14)
plt.title('Predictions Visualization')
plt.show()


return (z, z_decoded), predictions
1 change: 1 addition & 0 deletions pyroved/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Concat, _to_device)
from .prob import get_sampler
from .viz import plot_grid_traversal, plot_img_grid, plot_spect_grid
from .gp import gp_model

__all__ = ['generate_grid', 'transform_coordinates', 'generate_latent_grid',
'get_sampler', 'init_dataloader', 'init_ssvae_dataloaders',
Expand Down
28 changes: 28 additions & 0 deletions pyroved/utils/gp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pyro
import pyro.contrib.gp as gp
import torch
from tqdm import tqdm
def gp_model(input_dim: int = None, encoded_X: torch.Tensor = None, y: torch.Tensor = None, gp_iterations: int = 1):
"""
Returns a GP model trained on the encoded data.
Args:
input_dim: Dimensionality of the input data.
encoded_X: Encoded data.
y: Target data.
Returns:
gpr: GP regression model.
"""
# Define and train the GP model
print("Training GP model...")
kernel = gp.kernels.RBF(input_dim=encoded_X.shape[1])
gpr = gp.models.GPRegression(encoded_X, y, kernel)
optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
loss = loss_fn(gpr.model, gpr.guide)
for _ in tqdm(range(gp_iterations)):
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("GP model trained.")

return gpr
23 changes: 22 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,4 +643,25 @@ def test_save_load_basevae(invariances):
vae.save_weights("my_weights")
vae.load_weights("my_weights.pt")
weights_loaded = vae.state_dict()
assert_(assert_weights_equal(weights_loaded, weights_init))
assert_(assert_weights_equal(weights_loaded, weights_init))



def test_ivae_predict_on_latent():
num_samples = 10
train_data = torch.randn(num_samples, 5,5) # Example training data
gp_labels = torch.randint(0, 2, (num_samples,)) # Example GP labels
gp_iterations = 1
d = 12
in_dim = (5,5)

vae = models.iVAE(in_dim, latent_dim=2, invariances=None, seed=0)
(z, z_decoded) , predictions = vae.predict_on_latent(train_data, gp_labels, gp_iterations, d, plot=False)
assert isinstance(z, torch.Tensor), "z should be a torch.Tensor"
assert isinstance(predictions, torch.Tensor), "predictions should be a torch.Tensor"
assert z_decoded.dim() == 3, "z should be a 3-dimensional tensor"
assert predictions.dim() == 1, "predictions should be a 1-dimensional tensor"
# Check the shapes
expected_z_shape = (d * d, 5, 5) # Assuming this is the expected shape
assert z_decoded.shape == expected_z_shape, f"Shape of z should be {expected_z_shape}"
assert predictions.shape[0] == d * d, "Length of predictions should match number of points in grid"
16 changes: 16 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest
import torch
import pyro
import pyro.contrib.gp as gp
from pyroved.utils import gp_model
from pyroved import models

def test_gp_model_output_shape():
input_dim = 3
num_samples = 5
encoded_X = torch.randn(num_samples, input_dim) # Random tensor for encoded_X
y = torch.randn(num_samples) # Random tensor for y
gpr = gp_model(input_dim, encoded_X, y)
with torch.no_grad():
predictions, _ = gpr(encoded_X)
assert predictions.shape == y.shape, "Output tensor shape mismatch"

0 comments on commit 7807ffb

Please sign in to comment.