Pull request #55: https://github.com/ziatdinovmax/pyroVED/issues/54
Merged: 11 commits, Jan 29, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -127,3 +127,5 @@ dmypy.json

# Pyre type checker
.pyre/

*tar.gz
533 changes: 533 additions & 0 deletions examples/VAE_gp.ipynb

Large diffs are not rendered by default.

61 changes: 58 additions & 3 deletions pyroved/models/ivae.py
Owner (ziatdinovmax):
Almost there... Instead of returning z, we should be returning z_decoded, which is z passed through a trained decoder. In this particular case, we can obtain it simply as z_decoded = self.manifold2d(d, plot=False).

I'm curious if there's a specific reason you don't want to train GP inside utils/gp.py?
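For orientation, a minimal sketch of what this suggestion amounts to, with vae standing in for a trained iVAE instance (illustrative only; the merged implementation is in the ivae.py diff below):

    # Sketch, not the merged code: vae is assumed to be a trained iVAE instance
    d = 12                                     # latent grid size
    z_decoded = vae.manifold2d(d, plot=False)  # grid passed through the trained decoder
    # z_decoded holds one decoded image per grid point, i.e. d*d entries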

Contributor Author (@utkarshp1161, Jan 29, 2024):

> Almost there... Instead of returning z, we should be returning z_decoded, which is z passed through a trained decoder. In this particular case, we can obtain it simply as z_decoded = self.manifold2d(d, plot=False).

I had something else in mind: returning the latent coordinates rather than the decoded ones. Changed as per this suggestion.

> I'm curious if there's a specific reason you don't want to train GP inside utils/gp.py?

Moved the GP training into utils. Right, this way it's more modular.

Owner (ziatdinovmax):

> I had something else in mind: returning the latent coordinates rather than the decoded ones. Changed as per this suggestion.

In principle, we can have all three. Something like

return (z, z_decoded), predictions

Let me know if you would like to add this. Other than that, I'm ready to merge it.

@@ -9,17 +9,18 @@
"""

from typing import Optional, Tuple, Union, List

import pyro
import pyro.distributions as dist
import torch

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from pyroved.models.base import baseVAE
from pyroved.nets import fcDecoderNet, fcEncoderNet, sDecoderNet
from pyroved.utils import (
    generate_grid, generate_latent_grid, get_sampler,
    plot_img_grid, plot_spect_grid, set_deterministic_mode,
-    to_onehot, transform_coordinates
+    to_onehot, transform_coordinates, gp_model
)


@@ -307,3 +308,57 @@ def manifold2d(self, d: int,
        elif self.ndim == 1:
            plot_spect_grid(loc, d, **kwargs)
        return loc

    def predict_on_latent(self, train_data: torch.Tensor,
                          gp_labels: torch.Tensor, gp_iterations: int = 1,
                          d: int = 12, plot: bool = False):
        """
        Predicts on the latent grid using a trained GP.

        Args:
            train_data: Training data used to train the VAE
            gp_labels: Labels for the training data
            gp_iterations: Number of iterations for GP training
            d: Grid size
            plot: If True, plot the decoded manifold and the GP predictions

        Returns:
            (z, z_decoded): Latent grid coordinates and the decoded grid
            predictions: GP predictions on the latent grid
        """
        # Convert the training data and labels to float tensors
        X = torch.as_tensor(train_data, dtype=torch.float32)
        y = torch.as_tensor(gp_labels, dtype=torch.float32)

        # Use the VAE's encoder to map X into the latent space
        # (the encoder returns the latent mean as its first element)
        encoded_X = self.encode(X)[0]

        # Train a GP regression model on the encoded data
        gpr = gp_model(input_dim=encoded_X.shape[1], encoded_X=encoded_X,
                       y=y, gp_iterations=gp_iterations)

        # Generate the latent grid
        z, (grid_x, grid_y) = generate_latent_grid(d)
        z = torch.as_tensor(z, dtype=torch.float32)

        # Predict on the latent grid using the trained GP
        gpr.eval()
        with torch.no_grad():
            predictions, _ = gpr(z)

        # Decode the latent grid into data space
        z_decoded = self.manifold2d(d, plot=False)

        if plot:
            # Plot the decoded manifold
            self.manifold2d(d=d, cmap='viridis')
            # Plot the GP predictions as a 2D heatmap
            predictions_reshaped = predictions.reshape(d, d)
            plt.figure(figsize=(8, 8))
            heatmap = plt.imshow(predictions_reshaped, cmap='viridis', aspect='auto')
            plt.colorbar(heatmap, label='Prediction Value')
            plt.xticks(fontsize=14)
            plt.yticks(fontsize=14)
            plt.xlabel("$z_1$", fontsize=14)
            plt.ylabel("$z_2$", fontsize=14)
            plt.title('Predictions Visualization')
            plt.show()

        return (z, z_decoded), predictions
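For reference, a minimal usage sketch of the new method. The shapes and constructor arguments mirror the test added below; the random data and the SVItrainer pre-training step are assumptions about a typical pyroVED workflow, not part of this PR:

    import torch
    from pyroved import models, trainers, utils

    # Hypothetical data: 100 samples of size 5x5 with binary labels
    train_data = torch.randn(100, 5, 5)
    gp_labels = torch.randint(0, 2, (100,)).float()

    # Train the VAE first (assumed standard pyroVED workflow)
    vae = models.iVAE((5, 5), latent_dim=2, invariances=None, seed=0)
    loader = utils.init_dataloader(train_data, batch_size=32)
    trainer = trainers.SVItrainer(vae)
    for _ in range(10):
        trainer.step(loader)

    # GP prediction over a 12x12 latent grid
    (z, z_decoded), predictions = vae.predict_on_latent(
        train_data, gp_labels, gp_iterations=200, d=12, plot=False)
    print(z.shape)            # torch.Size([144, 2]): latent grid coordinates
    print(z_decoded.shape)    # torch.Size([144, 5, 5]): decoded grid
    print(predictions.shape)  # torch.Size([144]): GP predictive mean on the grid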
1 change: 1 addition & 0 deletions pyroved/utils/__init__.py
@@ -7,6 +7,7 @@
Concat, _to_device)
from .prob import get_sampler
from .viz import plot_grid_traversal, plot_img_grid, plot_spect_grid
from .gp import gp_model

__all__ = ['generate_grid', 'transform_coordinates', 'generate_latent_grid',
           'get_sampler', 'init_dataloader', 'init_ssvae_dataloaders',
28 changes: 28 additions & 0 deletions pyroved/utils/gp.py
@@ -0,0 +1,28 @@
import pyro
import pyro.contrib.gp as gp
import torch
from tqdm import tqdm


def gp_model(input_dim: int = None, encoded_X: torch.Tensor = None,
             y: torch.Tensor = None, gp_iterations: int = 1):
    """
    Returns a GP regression model trained on the encoded data.

    Args:
        input_dim: Dimensionality of the input data.
        encoded_X: Encoded data.
        y: Target data.
        gp_iterations: Number of training iterations.

    Returns:
        gpr: Trained GP regression model.
    """
    # Define and train the GP model
    print("Training GP model...")
    if input_dim is None:
        input_dim = encoded_X.shape[1]
    kernel = gp.kernels.RBF(input_dim=input_dim)
    gpr = gp.models.GPRegression(encoded_X, y, kernel)
    optimizer = torch.optim.Adam(gpr.parameters(), lr=0.005)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    for _ in tqdm(range(gp_iterations)):
        optimizer.zero_grad()
        # Recompute the ELBO loss each step so the graph is rebuilt
        # with the current parameter values
        loss = loss_fn(gpr.model, gpr.guide)
        loss.backward()
        optimizer.step()
    print("GP model trained.")

    return gpr
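A standalone usage sketch of the helper, with random tensors standing in for encoded data (this mirrors the unit test added in tests/test_utils.py below):

    import torch
    from pyroved.utils import gp_model

    encoded_X = torch.randn(50, 2)  # stand-in for latent means from a trained encoder
    y = torch.randn(50)             # continuous regression targets

    gpr = gp_model(input_dim=encoded_X.shape[1], encoded_X=encoded_X, y=y,
                   gp_iterations=100)
    with torch.no_grad():
        mean, var = gpr(encoded_X)  # predictive mean and diagonal variance
    assert mean.shape == y.shape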
23 changes: 22 additions & 1 deletion tests/test_models.py
@@ -643,4 +643,25 @@ def test_save_load_basevae(invariances):
    vae.save_weights("my_weights")
    vae.load_weights("my_weights.pt")
    weights_loaded = vae.state_dict()
    assert_(assert_weights_equal(weights_loaded, weights_init))



def test_ivae_predict_on_latent():
    num_samples = 10
    train_data = torch.randn(num_samples, 5, 5)  # Example training data
    gp_labels = torch.randint(0, 2, (num_samples,))  # Example GP labels
    gp_iterations = 1
    d = 12
    in_dim = (5, 5)

    vae = models.iVAE(in_dim, latent_dim=2, invariances=None, seed=0)
    (z, z_decoded), predictions = vae.predict_on_latent(
        train_data, gp_labels, gp_iterations, d, plot=False)
    assert isinstance(z, torch.Tensor), "z should be a torch.Tensor"
    assert isinstance(predictions, torch.Tensor), "predictions should be a torch.Tensor"
    assert z_decoded.dim() == 3, "z_decoded should be a 3-dimensional tensor"
    assert predictions.dim() == 1, "predictions should be a 1-dimensional tensor"
    # Check the shapes
    expected_z_decoded_shape = (d * d, 5, 5)
    assert z_decoded.shape == expected_z_decoded_shape, \
        f"Shape of z_decoded should be {expected_z_decoded_shape}"
    assert predictions.shape[0] == d * d, \
        "Length of predictions should match the number of grid points"
16 changes: 16 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,16 @@
import torch

from pyroved.utils import gp_model


def test_gp_model_output_shape():
    input_dim = 3
    num_samples = 5
    encoded_X = torch.randn(num_samples, input_dim)  # Random tensor for encoded_X
    y = torch.randn(num_samples)  # Random tensor for y
    gpr = gp_model(input_dim, encoded_X, y)
    with torch.no_grad():
        predictions, _ = gpr(encoded_X)
    assert predictions.shape == y.shape, "Output tensor shape mismatch"