Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conditional mace #57

Merged
merged 43 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
9f5c589
Use torch ode.
rousseab Jun 8, 2024
c8fd617
ODE sampling experiment.
rousseab Jun 9, 2024
b773948
Introducing a distinction between 'sampler' and 'generator'.
rousseab Jun 9, 2024
4fc3216
Refactoring to distinguish 'sampler' and 'generator'
rousseab Jun 9, 2024
7d65da4
Fix broken test.
rousseab Jun 9, 2024
e4a15fc
Fix typo.
rousseab Jun 9, 2024
29f86ac
One more level of abstraction in the generators.
rousseab Jun 9, 2024
db877a2
the ODE position generator.
rousseab Jun 9, 2024
5c17d6e
Mark some tests as slow so they can be skipped.
rousseab Jun 9, 2024
14786a0
Create sample trajectory object for ODE generator.
rousseab Jun 9, 2024
3ed1d6d
Updated the sampling callback to offer options.
rousseab Jun 9, 2024
a324c60
Fix test.
rousseab Jun 10, 2024
04daed3
Fix config.
rousseab Jun 10, 2024
721834c
Improved visuals.
rousseab Jun 10, 2024
bb78aa8
Experiment example.
rousseab Jun 10, 2024
45939f4
Looser atol on the ODE solver.
rousseab Jun 10, 2024
d02cbf2
Put times on the right device.
rousseab Jun 10, 2024
dad4e32
Improve device stuff.
rousseab Jun 10, 2024
71788d3
Improve device stuff.
rousseab Jun 10, 2024
9ddd6f9
Improve device stuff.
rousseab Jun 10, 2024
cc8831c
Improve device stuff.
rousseab Jun 10, 2024
962f79b
Improve device stuff.
rousseab Jun 10, 2024
c613948
Merge branch 'main' into ode_sampling
rousseab Jun 10, 2024
e30e358
Read from pickle.
rousseab Jun 10, 2024
328cd29
adding condition onf forces in diffiusion mace and cleaning the confi…
sblackburn-mila Jun 10, 2024
842e260
more example clean-up
sblackburn-mila Jun 10, 2024
0e0fc99
fixing a bug with biases in condition_embedding_layer
sblackburn-mila Jun 10, 2024
7e304aa
adding condition onf forces in diffiusion mace and cleaning the confi…
sblackburn-mila Jun 10, 2024
37ccedd
more example clean-up
sblackburn-mila Jun 10, 2024
225ea0d
fixing a bug with biases in condition_embedding_layer
sblackburn-mila Jun 10, 2024
df6ffb6
weird merge issue
sblackburn-mila Jun 10, 2024
18cb6a2
removing bias in conditional_layers
sblackburn-mila Jun 10, 2024
e74cbf2
fixing unit test
sblackburn-mila Jun 10, 2024
e418df8
Analysis scripts.
rousseab Jun 10, 2024
77eccc4
Different paths.
rousseab Jun 11, 2024
11d62f0
Yet another analysis script.
rousseab Jun 11, 2024
0b31dc2
Merge pull request #58 from mila-iqia/ode_sampling
sblackburn86 Jun 11, 2024
cbc5871
rm mila cluster config_diffusion_mace
sblackburn-mila Jun 11, 2024
3ddc3f5
more example clean-up
sblackburn-mila Jun 10, 2024
d887cd1
fixing a bug with biases in condition_embedding_layer
sblackburn-mila Jun 10, 2024
f77d70c
adding yaml config
sblackburn-mila Jun 11, 2024
3421ebb
fixing unit test
sblackburn-mila Jun 10, 2024
183eb07
weird git thing part 2
sblackburn-mila Jun 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""Perfect Score ODE sampling.

This little ad hoc experiment explores sampling with an ODE solver, using the 'analytic' score.
It works very well!
"""

import logging

import matplotlib.pyplot as plt
import torch
from einops import einops

from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
from crystal_diffusion.analysis.analytic_score.utils import (get_exact_samples,
get_unit_cells)
from crystal_diffusion.generators.ode_position_generator import \
ExplodingVarianceODEPositionGenerator
from crystal_diffusion.models.score_networks.analytical_score_network import (
AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters)
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

logger = logging.getLogger(__name__)

plt.style.use(PLOT_STYLE_PATH)

if torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

spatial_dimension = 3
number_of_atoms = 2
kmax = 1
spring_constant = 1000.
batch_size = 1000
total_time_steps = 41
if __name__ == '__main__':

noise_parameters = NoiseParameters(total_time_steps=total_time_steps,
sigma_min=0.001,
sigma_max=0.5)

equilibrium_relative_coordinates = torch.stack([0.25 * torch.ones(spatial_dimension),
0.75 * torch.ones(spatial_dimension)]).to(device)
inverse_covariance = torch.zeros(number_of_atoms, spatial_dimension, number_of_atoms, spatial_dimension).to(device)
for atom_i in range(number_of_atoms):
for alpha in range(spatial_dimension):
inverse_covariance[atom_i, alpha, atom_i, alpha] = spring_constant

score_network_parameters = AnalyticalScoreNetworkParameters(
number_of_atoms=number_of_atoms,
spatial_dimension=spatial_dimension,
kmax=kmax,
equilibrium_relative_coordinates=equilibrium_relative_coordinates,
inverse_covariance=inverse_covariance)

sigma_normalized_score_network = AnalyticalScoreNetwork(score_network_parameters)

position_generator = ExplodingVarianceODEPositionGenerator(noise_parameters,
number_of_atoms,
spatial_dimension,
sigma_normalized_score_network,
record_samples=True)

times = torch.linspace(0, 1, 1001)
sigmas = position_generator._get_exploding_variance_sigma(times)
ode_prefactor = position_generator._get_ode_prefactor(sigmas)

fig0 = plt.figure(figsize=PLEASANT_FIG_SIZE)
fig0.suptitle('ODE parameters')

ax1 = fig0.add_subplot(121)
ax2 = fig0.add_subplot(122)
ax1.set_title('$\\sigma$ Parameter')
ax2.set_title('$\\gamma$ Parameter')
ax1.plot(times, sigmas, '-')
ax2.plot(times, ode_prefactor, '-')

ax1.set_ylabel('$\\sigma(t)$')
ax2.set_ylabel('$\\gamma(t)$')
for ax in [ax1, ax2]:
ax.set_xlabel('Diffusion Time')
ax.set_xlim([-0.01, 1.01])

fig0.tight_layout()
plt.show()

unit_cell = get_unit_cells(acell=1., spatial_dimension=spatial_dimension, number_of_samples=batch_size).to(device)
relative_coordinates = position_generator.sample(number_of_samples=batch_size, device=device, unit_cell=unit_cell)

batch_times = position_generator.sample_trajectory_recorder.data['time'][0]
batch_relative_coordinates = position_generator.sample_trajectory_recorder.data['relative_coordinates'][0]
batch_flat_relative_coordinates = einops.rearrange(batch_relative_coordinates, "b t n d -> b t (n d)")

fig1 = plt.figure(figsize=PLEASANT_FIG_SIZE)
fig1.suptitle('ODE Trajectories')

ax = fig1.add_subplot(111)
ax.set_xlabel('Diffusion Time')
ax.set_ylabel('Raw Relative Coordinate')
ax.yaxis.tick_right()
ax.spines['top'].set_visible(True)
ax.spines['right'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_visible(True)

time = batch_times[0] # all time arrays are the same
for flat_relative_coordinates in batch_flat_relative_coordinates[::20]:
for i in range(number_of_atoms * spatial_dimension):
coordinate = flat_relative_coordinates[:, i]
ax.plot(time.cpu(), coordinate.cpu(), '-', color='b', alpha=0.05)

ax.set_xlim([1.01, -0.01])
plt.show()

exact_samples = get_exact_samples(equilibrium_relative_coordinates,
inverse_covariance,
batch_size).cpu()

fig2 = plt.figure(figsize=PLEASANT_FIG_SIZE)
fig2.suptitle('Comparing ODE and Expected Marignal Distributions')
ax1 = fig2.add_subplot(131, aspect='equal')
ax2 = fig2.add_subplot(132, aspect='equal')
ax3 = fig2.add_subplot(133, aspect='equal')

xs = einops.rearrange(relative_coordinates, 'b n d -> (b n) d').cpu()
zs = einops.rearrange(exact_samples, 'b n d -> (b n) d').cpu()
ax1.set_title('XY Projection')
ax1.plot(xs[:, 0], xs[:, 1], 'ro', alpha=0.5, mew=0, label='ODE Solver')
ax1.plot(zs[:, 0], zs[:, 1], 'go', alpha=0.05, mew=0, label='Exact Samples')

ax2.set_title('XZ Projection')
ax2.plot(xs[:, 0], xs[:, 2], 'ro', alpha=0.5, mew=0, label='ODE Solver')
ax2.plot(zs[:, 0], zs[:, 2], 'go', alpha=0.05, mew=0, label='Exact Samples')

ax3.set_title('YZ Projection')
ax3.plot(xs[:, 1], xs[:, 2], 'ro', alpha=0.5, mew=0, label='ODE Solver')
ax3.plot(zs[:, 1], zs[:, 2], 'go', alpha=0.05, mew=0, label='Exact Samples')

for ax in [ax1, ax2, ax3]:
ax.set_xlim(-0.01, 1.01)
ax.set_ylim(-0.01, 1.01)
ax.vlines(x=[0, 1], ymin=0, ymax=1, color='k', lw=2)
ax.hlines(y=[0, 1], xmin=0, xmax=1, color='k', lw=2)

ax1.legend(loc=0)
fig2.tight_layout()
plt.show()

fig3 = plt.figure(figsize=PLEASANT_FIG_SIZE)
ax1 = fig3.add_subplot(131)
ax2 = fig3.add_subplot(132)
ax3 = fig3.add_subplot(133)
fig3.suptitle("Marginal Distributions of t=0 Samples")

common_params = dict(histtype='stepfilled', alpha=0.5, bins=50)

ax1.hist(xs[:, 0], **common_params, facecolor='r', label='ODE solver')
ax2.hist(xs[:, 1], **common_params, facecolor='r', label='ODE solver')
ax3.hist(xs[:, 2], **common_params, facecolor='r', label='ODE solver')

ax1.hist(zs[:, 0], **common_params, facecolor='g', label='Exact')
ax2.hist(zs[:, 1], **common_params, facecolor='g', label='Exact')
ax3.hist(zs[:, 2], **common_params, facecolor='g', label='Exact')

ax1.set_xlabel('X')
ax2.set_xlabel('Y')
ax3.set_xlabel('Z')

for ax in [ax1, ax2, ax3]:
ax.set_xlim(-0.01, 1.01)

ax1.legend(loc=0)
fig3.tight_layout()
plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from crystal_diffusion.analysis.analytic_score.utils import (
get_exact_samples, get_random_equilibrium_relative_coordinates,
get_random_inverse_covariance, get_samples_harmonic_energy, get_unit_cells)
from crystal_diffusion.generators.predictor_corrector_position_generator import \
AnnealedLangevinDynamicsGenerator
from crystal_diffusion.models.score_networks.analytical_score_network import (
AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters)
from crystal_diffusion.samplers.predictor_corrector_position_sampler import \
AnnealedLangevinDynamicsSampler
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -64,8 +64,8 @@
record_samples=False,
positions_require_grad=False)

pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network,
**sampler_parameters)
pc_sampler = AnnealedLangevinDynamicsGenerator(sigma_normalized_score_network=sigma_normalized_score_network,
**sampler_parameters)

unit_cell = get_unit_cells(acell=1., spatial_dimension=spatial_dimension, number_of_samples=number_of_samples)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from crystal_diffusion.analysis.analytic_score.utils import (
get_exact_samples, get_random_equilibrium_relative_coordinates,
get_random_inverse_covariance, get_unit_cells)
from crystal_diffusion.generators.predictor_corrector_position_generator import \
AnnealedLangevinDynamicsGenerator
from crystal_diffusion.models.score_networks.analytical_score_network import (
AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters)
from crystal_diffusion.samplers.predictor_corrector_position_sampler import \
AnnealedLangevinDynamicsSampler
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,8 +60,8 @@
record_samples=False,
positions_require_grad=False)

pc_sampler = AnnealedLangevinDynamicsSampler(sigma_normalized_score_network=sigma_normalized_score_network,
**sampler_parameters)
pc_sampler = AnnealedLangevinDynamicsGenerator(sigma_normalized_score_network=sigma_normalized_score_network,
**sampler_parameters)

unit_cell = get_unit_cells(acell=1., spatial_dimension=spatial_dimension, number_of_samples=number_of_samples)

Expand Down
3 changes: 2 additions & 1 deletion crystal_diffusion/analysis/analytic_score/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_random_inverse_covariance(spring_constant_scale: float, number_of_atoms:
def get_exact_samples(equilibrium_relative_coordinates: torch.Tensor, inverse_covariance: torch.Tensor,
number_of_samples: int) -> torch.Tensor:
"""Sample the exact harmonic energy."""
device = equilibrium_relative_coordinates.device
natom, spatial_dimension, _, _ = inverse_covariance.shape

flat_dim = natom * spatial_dimension
Expand All @@ -53,7 +54,7 @@ def get_exact_samples(equilibrium_relative_coordinates: torch.Tensor, inverse_co

sigmas = 1. / torch.sqrt(eigenvalues)

z_scores = torch.randn(number_of_samples, flat_dim)
z_scores = torch.randn(number_of_samples, flat_dim).to(device)

sigma_z_scores = z_scores * sigmas.unsqueeze(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from crystal_diffusion import TOP_DIR
from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
from crystal_diffusion.utils.sample_trajectory import SampleTrajectory
from crystal_diffusion.utils.sample_trajectory import \
PredictorCorrectorSampleTrajectory

plt.style.use(PLOT_STYLE_PATH)

Expand All @@ -34,10 +35,10 @@
energies = torch.load(energy_sample_directory / f"energies_sample_epoch={epoch}.pt")

pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt"
sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path)
sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path)

pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt"
sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path)
sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path)

list_predictor_coordinates = sample_trajectory.data['predictor_x_i']
float_datatype = list_predictor_coordinates[0].dtype
Expand Down
7 changes: 4 additions & 3 deletions crystal_diffusion/analysis/positions_to_cif_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from pymatgen.core import Lattice, Structure

from crystal_diffusion import TOP_DIR
from crystal_diffusion.utils.sample_trajectory import SampleTrajectory
from crystal_diffusion.utils.sample_trajectory import \
PredictorCorrectorSampleTrajectory

# Hard coding some paths to local results. Modify as needed...
epoch = 35
Expand All @@ -25,10 +26,10 @@

if __name__ == '__main__':
pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt"
sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path)
sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path)

pickle_path = trajectory_data_directory / f"diffusion_position_sample_epoch={epoch}.pt"
sample_trajectory = SampleTrajectory.read_from_pickle(pickle_path)
sample_trajectory = PredictorCorrectorSampleTrajectory.read_from_pickle(pickle_path)

basis_vectors = sample_trajectory.data['unit_cell'][sample_idx].numpy()
lattice = Lattice(matrix=basis_vectors, pbc=(True, True, True))
Expand Down
Loading
Loading