From ea044dfced39d960d9d344882a0520f64a67d37d Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 13:53:47 -0500 Subject: [PATCH 01/43] Finish comments with something more definitive. --- .../sample_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 26260536..1fcf2e9e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -211,6 +211,8 @@ def create_samples_and_write_to_disk( output_directory / "trajectories.pt" ) + logger.info("Done!") + if __name__ == "__main__": main() From 2ed2eac5d73c1ad3085f68524fa9fa642c29fa8c Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:15:16 -0500 Subject: [PATCH 02/43] Cleaner repaint. --- .../constrained_langevin_generator.py | 190 ++++++++++------- .../test_constrained_langevin_generator.py | 195 +++++++++++++----- 2 files changed, 252 insertions(+), 133 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 99db71ce..06dc9455 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -1,91 +1,135 @@ -from dataclasses import dataclass - -import numpy as np +import einops import torch from tqdm import tqdm from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ - PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ - ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noisers.atom_types_noiser import \ + AtomTypesNoiser from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ RelativeCoordinatesNoiser +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot -@dataclass(kw_only=True) -class ConstrainedLangevinGeneratorParameters(PredictorCorrectorSamplingParameters): - """Hyper-parameters for diffusion sampling with the predictor-corrector algorithm.""" - - algorithm: str = "constrained_langevin" - constrained_relative_coordinates: ( - np.ndarray - ) # the positions that must be satisfied at the end of sampling. - +class ConstrainedPredictorCorrectorAXLGenerator: + """Constrained Predictor Corrector AXL Generator. -class ConstrainedLangevinGenerator(LangevinGenerator): - """Constrained Annealed Langevin Dynamics Generator. - - This generator implements a basic version of the inpainting algorithm presented in the - paper + This class constrains the input PC generator following a basic version of the inpainting algorithm + presented in the paper "RePaint: Inpainting using Denoising Diffusion Probabilistic Models". """ def __init__( self, - noise_parameters: NoiseParameters, - sampling_parameters: ConstrainedLangevinGeneratorParameters, - axl_network: ScoreNetwork, + generator: LangevinGenerator, + reference_composition: AXL, + constrained_atom_indices: torch.Tensor, ): """Init method.""" - super().__init__(noise_parameters, sampling_parameters, axl_network) + self.generator = generator + + self.number_of_atoms = self.generator.number_of_atoms + self.num_classes = self.generator.num_classes - self.constraint_relative_coordinates = torch.from_numpy( - sampling_parameters.constrained_relative_coordinates - ) # TODO constraint the atom type as well + self.reference_composition = reference_composition + self.constraint_indices = constrained_atom_indices assert ( - len(self.constraint_relative_coordinates.shape) == 2 + len(self.reference_composition.X.shape) == 2 ), "The constrained relative coordinates have the wrong shape" - number_of_constraints, spatial_dimensions = ( - self.constraint_relative_coordinates.shape - ) - assert ( - number_of_constraints <= self.number_of_atoms - ), "There are more constrained positions than atoms!" assert ( - spatial_dimensions <= self.spatial_dimension - ), "The spatial dimension of the constrained positions is inconsistent" + len(self.reference_composition.A.shape) == 1 + ), "The constrained atom types have the wrong shape" - # Without loss of generality, we impose that the first positions are constrained. - # This should have no consequence for a permutation equivariant model. - self.constraint_mask = torch.zeros(self.number_of_atoms, dtype=bool) - self.constraint_mask[:number_of_constraints] = True + assert ( + len(constrained_atom_indices.shape) == 1 + ), "The constrained_atom_indices array has the wrong shape" self.relative_coordinates_noiser = RelativeCoordinatesNoiser() + self.atom_type_noiser = AtomTypesNoiser() def _apply_constraint(self, composition: AXL, device: torch.device) -> AXL: """This method applies the coordinate constraint on the input configuration.""" - x = composition.X - x[:, self.constraint_mask] = self.constraint_relative_coordinates.to(device) - updated_axl = AXL( - A=composition.A, - X=x, + constrained_x = composition.X.clone() + constrained_x[:, self.constraint_indices] = self.reference_composition.X[ + self.constraint_indices + ].to(device) + + constrained_a = composition.A.clone() + constrained_a[:, self.constraint_indices] = self.reference_composition.A[ + self.constraint_indices + ].to(device) + + constrained_composition = AXL( + A=constrained_a, + X=constrained_x, L=composition.L, ) - return updated_axl + return constrained_composition + + def _get_noised_known_composition( + self, i: int, number_of_samples: int, device: torch.device + ) -> AXL: + """This method applies the noise to the known composition.""" + # Initialize compositions that satisfies the constraint, but is otherwise random. + # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. + composition0_known = self.generator.initialize(number_of_samples, device) + composition0_known = self._apply_constraint(composition0_known, device) + + q_bar_matrices_i = einops.repeat( + self.generator.noise.q_bar_matrix[i].to(device), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=self.number_of_atoms, + ) + + sigma_i = self.generator.noise.sigma[i] + coordinates_broadcasting = torch.ones_like(composition0_known.X) + broadcast_sigmas_i = sigma_i * coordinates_broadcasting + + # Noise an example satisfying the constraints from t_0 to t_i + x_i_known = ( + self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( + composition0_known.X, broadcast_sigmas_i + ) + ) + + one_hot_a_i = class_index_to_onehot( + composition0_known.A, num_classes=self.num_classes + ) + a_i_known = self.atom_type_noiser.get_noisy_atom_types_sample( + one_hot_a_i, q_bar_matrices_i + ) + + noised_composition = AXL(A=a_i_known, X=x_i_known, L=composition0_known.L) + return noised_composition + + def _combine_noised_and_denoised_compositions( + self, noised_composition: AXL, denoised_composition: AXL + ) -> AXL: + + updated_x = denoised_composition.X.clone() + updated_a = denoised_composition.A.clone() + + updated_x[:, self.constraint_indices] = noised_composition.X[ + :, self.constraint_indices + ] + updated_a[:, self.constraint_indices] = noised_composition.A[ + :, self.constraint_indices + ] + + composition_i = AXL(A=updated_a, X=updated_x, L=denoised_composition.L) + return composition_i def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of positional constraints. Args: number_of_samples : number of samples to draw. @@ -98,48 +142,36 @@ def sample( """ assert unit_cell.size() == ( number_of_samples, - self.spatial_dimension, - self.spatial_dimension, + self.generator.spatial_dimension, + self.generator.spatial_dimension, ), ( "Unit cell passed to sample should be of size (number of sample, spatial dimension, spatial dimension" + f"Got {unit_cell.size()}" ) - # Initialize a configuration that satisfy the constraint, but is otherwise random. - # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. - composition0_known = self.initialize(number_of_samples, device) - # this is an AXL objet - - composition0_known = self._apply_constraint(composition0_known, device) - - composition_ip1 = self.initialize(number_of_samples, device) + composition_ip1 = self.generator.initialize(number_of_samples, device) forces = torch.zeros_like(composition_ip1.X) - coordinates_broadcasting = torch.ones( - number_of_samples, self.number_of_atoms, self.spatial_dimension - ).to(device) - - for i in tqdm(range(self.number_of_discretization_steps - 1, -1, -1)): - sigma_i = self.noise.sigma[i] - broadcast_sigmas_i = sigma_i * coordinates_broadcasting - # Noise an example satisfying the constraints from t_0 to t_i - x_i_known = ( - self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( - composition0_known.X, broadcast_sigmas_i - ) + for i in tqdm(range(self.generator.number_of_discretization_steps - 1, -1, -1)): + + # Noise from t_0 to t_i + noised_composition_i = self._get_noised_known_composition( + i, number_of_samples, device ) + # Denoise from t_{i+1} to t_i - composition_i = self.predictor_step( + denoised_composition_i = self.generator.predictor_step( composition_ip1, i + 1, unit_cell, forces ) - # Combine the known and unknown - x_i = composition_i.X - x_i[:, self.constraint_mask] = x_i_known[:, self.constraint_mask] - composition_i = AXL(A=composition_i.A, X=x_i, L=composition_i.L) + composition_i = self._combine_noised_and_denoised_compositions( + noised_composition_i, denoised_composition_i + ) - for _ in range(self.number_of_corrector_steps): - composition_i = self.corrector_step(composition_i, i, unit_cell, forces) + for _ in range(self.generator.number_of_corrector_steps): + composition_i = self.generator.corrector_step( + composition_i, i, unit_cell, forces + ) composition_ip1 = composition_i diff --git a/tests/generators/test_constrained_langevin_generator.py b/tests/generators/test_constrained_langevin_generator.py index 59f2bb6d..0c898e5d 100644 --- a/tests/generators/test_constrained_langevin_generator.py +++ b/tests/generators/test_constrained_langevin_generator.py @@ -2,92 +2,179 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell from tests.generators.test_langevin_generator import TestLangevinGenerator class TestConstrainedLangevinGenerator(TestLangevinGenerator): @pytest.fixture() - def constrained_relative_coordinates(self, number_of_atoms, spatial_dimension): - number_of_constraints = number_of_atoms // 2 - return torch.rand(number_of_constraints, spatial_dimension).numpy() - - @pytest.fixture() - def sampling_parameters( + def reference_composition( self, number_of_atoms, spatial_dimension, - number_of_samples, - cell_dimensions, - number_of_corrector_steps, - unit_cell_size, - constrained_relative_coordinates, - num_atom_types, + num_atomic_classes, + device, ): - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=cell_dimensions, - spatial_dimension=spatial_dimension, - constrained_relative_coordinates=constrained_relative_coordinates, - num_atom_types=num_atom_types, - ) - - return sampling_parameters - - @pytest.fixture() - def pc_generator(self, noise_parameters, sampling_parameters, axl_network): - generator = ConstrainedLangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - axl_network=axl_network, + return AXL( + A=torch.randint(0, num_atomic_classes, (number_of_atoms,)).to(device), + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( + device + ), # TODO placeholder ) - return generator - @pytest.fixture() - def axl( + def random_compositions( self, number_of_samples, number_of_atoms, spatial_dimension, - num_atom_types, + num_atomic_classes, device, ): return AXL( A=torch.randint( - 0, num_atom_types + 1, (number_of_samples, number_of_atoms) + 0, + num_atomic_classes, + ( + number_of_samples, + number_of_atoms, + ), ).to(device), - X=torch.rand(number_of_samples, number_of_atoms, spatial_dimension).to( - device - ), - L=torch.rand( - number_of_samples, spatial_dimension * (spatial_dimension - 1) - ).to( + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( device ), # TODO placeholder ) + @pytest.fixture() + def constrained_atom_indices(self, number_of_atoms, device): + number_of_constraints = number_of_atoms // 2 + return torch.randperm(number_of_atoms)[:number_of_constraints].to(device) + + @pytest.fixture() + def constrained_pc_generator( + self, pc_generator, reference_composition, constrained_atom_indices + ): + constrained_generator = ConstrainedPredictorCorrectorAXLGenerator( + generator=pc_generator, + reference_composition=reference_composition, + constrained_atom_indices=constrained_atom_indices, + ) + + return constrained_generator + + @pytest.fixture() + def constrained_samples( + self, constrained_pc_generator, number_of_samples, device, unit_cell_sample + ): + samples = constrained_pc_generator.sample( + number_of_samples, device, unit_cell_sample + ) + return samples + + def test_constraints( + self, + constrained_samples, + reference_composition, + constrained_atom_indices, + number_of_samples, + ): + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + + torch.testing.assert_close( + constrained_samples.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_samples.A[:, constrained_atom_indices], reference_a + ) + def test_apply_constraint( - self, pc_generator, axl, constrained_relative_coordinates, device + self, + constrained_pc_generator, + number_of_samples, + random_compositions, + reference_composition, + constrained_atom_indices, + device, ): - batch_size = axl.X.shape[0] - original_x = torch.clone(axl.X) - pc_generator._apply_constraint(axl, device) - number_of_constraints = len(constrained_relative_coordinates) + constrained_compositions = constrained_pc_generator._apply_constraint( + random_compositions, device + ) - constrained_x = einops.repeat( - torch.from_numpy(constrained_relative_coordinates).to(device), - "n d -> b n d", - b=batch_size, + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, ) - torch.testing.assert_close(axl.X[:, :number_of_constraints], constrained_x) torch.testing.assert_close( - axl.X[:, number_of_constraints:], original_x[:, number_of_constraints:] + constrained_compositions.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_compositions.A[:, constrained_atom_indices], reference_a + ) + + def test_combine_noised_and_denoised_compositions( + self, + constrained_pc_generator, + constrained_atom_indices, + number_of_samples, + number_of_atoms, + spatial_dimension, + device, + ) -> AXL: + + noised_mask = torch.zeros(number_of_atoms, dtype=torch.bool).to(device) + noised_mask[constrained_atom_indices] = True + + noised_compositions = AXL( + A=torch.zeros(number_of_samples, number_of_atoms).to(device), + X=torch.zeros(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + denoised_compositions = AXL( + A=torch.ones(number_of_samples, number_of_atoms).to(device), + X=torch.ones(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + combined_compositions = ( + constrained_pc_generator._combine_noised_and_denoised_compositions( + noised_compositions, denoised_compositions + ) ) + + assert (combined_compositions.X[:, noised_mask] == 0.0).all() + assert (combined_compositions.X[:, ~noised_mask] == 1.0).all() + assert (combined_compositions.A[:, noised_mask] == 0.0).all() + assert (combined_compositions.A[:, ~noised_mask] == 1.0).all() From f09d68c5188d890283c11392568a7a355e86c796 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:16:01 -0500 Subject: [PATCH 03/43] Modified code in experiments. Probably broken. --- .../analytic_score/repaint/repaint_with_analytic_score.py | 5 +++-- experiments/sampling_sota_model/repaint_with_sota_score.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py index 2ea37960..e75550c6 100644 --- a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py +++ b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py @@ -8,7 +8,8 @@ from diffusion_for_multi_scale_molecular_dynamics.analysis import \ PLOT_STYLE_PATH from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) + ConstrainedLangevinGeneratorParameters, + ConstrainedPredictorCorrectorAXLGenerator) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ @@ -98,7 +99,7 @@ record_samples=True, ) - position_generator = ConstrainedLangevinGenerator( + position_generator = ConstrainedPredictorCorrectorAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, diff --git a/experiments/sampling_sota_model/repaint_with_sota_score.py b/experiments/sampling_sota_model/repaint_with_sota_score.py index 8e8d7dad..0c45547b 100644 --- a/experiments/sampling_sota_model/repaint_with_sota_score.py +++ b/experiments/sampling_sota_model/repaint_with_sota_score.py @@ -11,7 +11,8 @@ from diffusion_for_multi_scale_molecular_dynamics.analysis import ( PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) + ConstrainedLangevinGeneratorParameters, + ConstrainedPredictorCorrectorAXLGenerator) from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ load_diffusion_model from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ @@ -84,7 +85,7 @@ record_samples=True, ) - position_generator = ConstrainedLangevinGenerator( + position_generator = ConstrainedPredictorCorrectorAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, From f9c35956855d08879f303be870a7b60fef35e520 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:18:54 -0500 Subject: [PATCH 04/43] Let's try to sample with constraints! --- .../constrained_langevin_generator.py | 3 ++ .../sample_diffusion.py | 54 +++++++++++++------ .../utils/ovito_utils.py | 13 +++++ 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 06dc9455..924010e7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,6 +30,9 @@ def __init__( """Init method.""" self.generator = generator + if hasattr(self.generator, "sample_trajectory_recorder"): + self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder + self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 1fcf2e9e..a296749c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -16,8 +16,12 @@ ElementTypes from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \ instantiate_generator +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.load_sampling_parameters import \ load_sampling_parameters from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import \ @@ -36,6 +40,8 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters +from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ + get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -51,6 +57,11 @@ def main(args: Optional[Any] = None): required=True, help="config file with sampling parameters in yaml format.", ) + parser.add_argument( + "--path_to_constraint_cif_file", required=False, + help="path to a cif file with constrained positions." + ) + parser.add_argument( "--checkpoint", required=True, help="path to checkpoint model to be loaded." ) @@ -101,12 +112,35 @@ def main(args: Optional[Any] = None): elements = hyper_params["elements"] oracle_parameters = create_energy_oracle_parameters(hyper_params["oracle"], elements) - create_samples_and_write_to_disk( + axl_network = get_axl_network(args.checkpoint) + + logger.info("Instantiate generator...") + raw_generator = instantiate_generator( + sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, + axl_network=axl_network, + ) + + if 'constrained_sampling' in hyper_params: + constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + cif_file_path = Path(args.path_to_constraint_cif_file) + assert cif_file_path.is_file(), "The constraint cif file does not exist." + + reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, + elements=elements, + device=device) + + generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, + reference_composition, + constrained_atom_indices) + else: + generator = raw_generator + + create_samples_and_write_to_disk( + generator=generator, sampling_parameters=sampling_parameters, oracle_parameters=oracle_parameters, device=device, - checkpoint_path=args.checkpoint, output_path=args.output, ) @@ -152,11 +186,10 @@ def get_axl_network(checkpoint_path: Union[str, Path]) -> ScoreNetwork: def create_samples_and_write_to_disk( - noise_parameters: NoiseParameters, + generator: LangevinGenerator, sampling_parameters: SamplingParameters, oracle_parameters: Union[OracleParameters, None], device: torch.device, - checkpoint_path: Union[str, Path], output_path: Union[str, Path], ): """Create Samples and write to disk. @@ -173,19 +206,10 @@ def create_samples_and_write_to_disk( Returns: None """ - axl_network = get_axl_network(checkpoint_path) - - logger.info("Instantiate generator...") - position_generator = instantiate_generator( - sampling_parameters=sampling_parameters, - noise_parameters=noise_parameters, - axl_network=axl_network, - ) - logger.info("Generating samples...") with torch.no_grad(): samples_batch = create_batch_of_samples( - generator=position_generator, + generator=generator, sampling_parameters=sampling_parameters, device=device, ) @@ -207,7 +231,7 @@ def create_samples_and_write_to_disk( if sampling_parameters.record_samples: logger.info("Writing sampling trajectories to disk...") - position_generator.sample_trajectory_recorder.write_to_pickle( + generator.sample_trajectory_recorder.write_to_pickle( output_directory / "trajectories.pt" ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py index b58722c6..5d6adb6e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py @@ -10,6 +10,7 @@ import numpy as np import ovito +import torch from ovito.io import import_file from ovito.modifiers import (AffineTransformationModifier, CombineDatasetsModifier, CreateBondsModifier) @@ -26,6 +27,18 @@ _cif_file_name_template = "diffusion_positions_step_{time_index}.cif" +def get_composition_from_cif_file(cif_file_path: Path, elements: list[str], device): + """Get composition from a cif file.""" + structure = Structure.from_file(cif_file_path) + element_types = ElementTypes(elements) + + a = torch.Tensor([element_types.get_element_id(s.name) for s in structure.species]).to(torch.int64).to(device) + x = torch.from_numpy(structure.frac_coords).to(torch.float32).to(device) + lattice = torch.from_numpy(structure.lattice.matrix).to(torch.float32).to(device) + composition = AXL(A=a, X=x, L=lattice) + return composition + + def create_cif_files( elements: list[str], visualization_artifacts_path: Path, From ab87c0e08360f2a437e1d1b6416542acf6210673 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:36:03 -0500 Subject: [PATCH 05/43] a bit more logging. --- .../sample_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index a296749c..f0a4c3fe 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -122,9 +122,12 @@ def main(args: Optional[Any] = None): ) if 'constrained_sampling' in hyper_params: + logger.info("Constrained Sampling is activated") constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." + logger.info(f"Constrained cif file is {cif_file_path}") + logger.info(f"Constrained atom indices are {constrained_atom_indices}") reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, elements=elements, From 2c2bf8d5125f6cf9893131ad646823641099ebfd Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:38:22 -0500 Subject: [PATCH 06/43] Fix bjork. --- .../sample_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index f0a4c3fe..e35e1b39 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -123,7 +123,8 @@ def main(args: Optional[Any] = None): if 'constrained_sampling' in hyper_params: logger.info("Constrained Sampling is activated") - constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + constraint_dict = hyper_params['constrained_sampling'] + constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." logger.info(f"Constrained cif file is {cif_file_path}") From 56289525f21e031da63f107a786ad7f6c1d46206 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:52:48 -0500 Subject: [PATCH 07/43] Fix bjork. --- .../sample_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index e35e1b39..2d575907 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -124,7 +124,7 @@ def main(args: Optional[Any] = None): if 'constrained_sampling' in hyper_params: logger.info("Constrained Sampling is activated") constraint_dict = hyper_params['constrained_sampling'] - constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']) + constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']).to(torch.int64) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." logger.info(f"Constrained cif file is {cif_file_path}") From 0b44e275a0200fff5f9eeb328d52a6f77dfd4594 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 21:32:19 -0500 Subject: [PATCH 08/43] Use a pickle for constraints. --- .../sample_diffusion.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 2d575907..713d5133 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,8 +40,6 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters -from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ - get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -58,8 +56,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_cif_file", required=False, - help="path to a cif file with constrained positions." + "--path_to_constraint_data_pickle", required=False, + help="path to a pickle that contains a reference compositions and fixed atom indices." ) parser.add_argument( @@ -121,18 +119,16 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if 'constrained_sampling' in hyper_params: + if args.path_to_constraint_data_pickle: logger.info("Constrained Sampling is activated") - constraint_dict = hyper_params['constrained_sampling'] - constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']).to(torch.int64) - cif_file_path = Path(args.path_to_constraint_cif_file) - assert cif_file_path.is_file(), "The constraint cif file does not exist." - logger.info(f"Constrained cif file is {cif_file_path}") + constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) + assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." + + constraint_data = torch.load(constraint_data_pickle_path) + constrained_atom_indices = constraint_data["constrained_atom_indices"] logger.info(f"Constrained atom indices are {constrained_atom_indices}") - reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, - elements=elements, - device=device) + reference_composition = constraint_data["reference_composition"] generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From b7d6b67e09477448743a003800a0baab3471768e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 07:40:36 -0500 Subject: [PATCH 09/43] Combine noised and denoised during corrector steps. --- .../generators/constrained_langevin_generator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 924010e7..0c597c44 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -132,7 +132,7 @@ def sample( ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of atomic constraints. Args: number_of_samples : number of samples to draw. @@ -172,9 +172,12 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - composition_i = self.generator.corrector_step( + corrected_composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) + composition_i = self._combine_noised_and_denoised_compositions( + corrected_composition_i, denoised_composition_i + ) composition_ip1 = composition_i From b2d77b4e71004ff154b4678d15e03cb540624e13 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 08:08:58 -0500 Subject: [PATCH 10/43] Turn off repaint in corrector step. --- .../generators/constrained_langevin_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 0c597c44..127cfd4f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -172,12 +172,9 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - corrected_composition_i = self.generator.corrector_step( + composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) - composition_i = self._combine_noised_and_denoised_compositions( - corrected_composition_i, denoised_composition_i - ) composition_ip1 = composition_i From bdc3952ae73680133915b84b56a090a25692aeea Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 13:53:47 -0500 Subject: [PATCH 11/43] Finish comments with something more definitive. --- .../sample_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 26260536..1fcf2e9e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -211,6 +211,8 @@ def create_samples_and_write_to_disk( output_directory / "trajectories.pt" ) + logger.info("Done!") + if __name__ == "__main__": main() From 7b5a9f8640b539d8a3af24f0411b18fcdfe62730 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:15:16 -0500 Subject: [PATCH 12/43] Cleaner repaint. --- .../constrained_langevin_generator.py | 190 ++++++++++------- .../test_constrained_langevin_generator.py | 195 +++++++++++++----- 2 files changed, 252 insertions(+), 133 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 99db71ce..06dc9455 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -1,91 +1,135 @@ -from dataclasses import dataclass - -import numpy as np +import einops import torch from tqdm import tqdm from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ - PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ - ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noisers.atom_types_noiser import \ + AtomTypesNoiser from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ RelativeCoordinatesNoiser +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot -@dataclass(kw_only=True) -class ConstrainedLangevinGeneratorParameters(PredictorCorrectorSamplingParameters): - """Hyper-parameters for diffusion sampling with the predictor-corrector algorithm.""" - - algorithm: str = "constrained_langevin" - constrained_relative_coordinates: ( - np.ndarray - ) # the positions that must be satisfied at the end of sampling. - +class ConstrainedPredictorCorrectorAXLGenerator: + """Constrained Predictor Corrector AXL Generator. -class ConstrainedLangevinGenerator(LangevinGenerator): - """Constrained Annealed Langevin Dynamics Generator. - - This generator implements a basic version of the inpainting algorithm presented in the - paper + This class constrains the input PC generator following a basic version of the inpainting algorithm + presented in the paper "RePaint: Inpainting using Denoising Diffusion Probabilistic Models". """ def __init__( self, - noise_parameters: NoiseParameters, - sampling_parameters: ConstrainedLangevinGeneratorParameters, - axl_network: ScoreNetwork, + generator: LangevinGenerator, + reference_composition: AXL, + constrained_atom_indices: torch.Tensor, ): """Init method.""" - super().__init__(noise_parameters, sampling_parameters, axl_network) + self.generator = generator + + self.number_of_atoms = self.generator.number_of_atoms + self.num_classes = self.generator.num_classes - self.constraint_relative_coordinates = torch.from_numpy( - sampling_parameters.constrained_relative_coordinates - ) # TODO constraint the atom type as well + self.reference_composition = reference_composition + self.constraint_indices = constrained_atom_indices assert ( - len(self.constraint_relative_coordinates.shape) == 2 + len(self.reference_composition.X.shape) == 2 ), "The constrained relative coordinates have the wrong shape" - number_of_constraints, spatial_dimensions = ( - self.constraint_relative_coordinates.shape - ) - assert ( - number_of_constraints <= self.number_of_atoms - ), "There are more constrained positions than atoms!" assert ( - spatial_dimensions <= self.spatial_dimension - ), "The spatial dimension of the constrained positions is inconsistent" + len(self.reference_composition.A.shape) == 1 + ), "The constrained atom types have the wrong shape" - # Without loss of generality, we impose that the first positions are constrained. - # This should have no consequence for a permutation equivariant model. - self.constraint_mask = torch.zeros(self.number_of_atoms, dtype=bool) - self.constraint_mask[:number_of_constraints] = True + assert ( + len(constrained_atom_indices.shape) == 1 + ), "The constrained_atom_indices array has the wrong shape" self.relative_coordinates_noiser = RelativeCoordinatesNoiser() + self.atom_type_noiser = AtomTypesNoiser() def _apply_constraint(self, composition: AXL, device: torch.device) -> AXL: """This method applies the coordinate constraint on the input configuration.""" - x = composition.X - x[:, self.constraint_mask] = self.constraint_relative_coordinates.to(device) - updated_axl = AXL( - A=composition.A, - X=x, + constrained_x = composition.X.clone() + constrained_x[:, self.constraint_indices] = self.reference_composition.X[ + self.constraint_indices + ].to(device) + + constrained_a = composition.A.clone() + constrained_a[:, self.constraint_indices] = self.reference_composition.A[ + self.constraint_indices + ].to(device) + + constrained_composition = AXL( + A=constrained_a, + X=constrained_x, L=composition.L, ) - return updated_axl + return constrained_composition + + def _get_noised_known_composition( + self, i: int, number_of_samples: int, device: torch.device + ) -> AXL: + """This method applies the noise to the known composition.""" + # Initialize compositions that satisfies the constraint, but is otherwise random. + # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. + composition0_known = self.generator.initialize(number_of_samples, device) + composition0_known = self._apply_constraint(composition0_known, device) + + q_bar_matrices_i = einops.repeat( + self.generator.noise.q_bar_matrix[i].to(device), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=self.number_of_atoms, + ) + + sigma_i = self.generator.noise.sigma[i] + coordinates_broadcasting = torch.ones_like(composition0_known.X) + broadcast_sigmas_i = sigma_i * coordinates_broadcasting + + # Noise an example satisfying the constraints from t_0 to t_i + x_i_known = ( + self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( + composition0_known.X, broadcast_sigmas_i + ) + ) + + one_hot_a_i = class_index_to_onehot( + composition0_known.A, num_classes=self.num_classes + ) + a_i_known = self.atom_type_noiser.get_noisy_atom_types_sample( + one_hot_a_i, q_bar_matrices_i + ) + + noised_composition = AXL(A=a_i_known, X=x_i_known, L=composition0_known.L) + return noised_composition + + def _combine_noised_and_denoised_compositions( + self, noised_composition: AXL, denoised_composition: AXL + ) -> AXL: + + updated_x = denoised_composition.X.clone() + updated_a = denoised_composition.A.clone() + + updated_x[:, self.constraint_indices] = noised_composition.X[ + :, self.constraint_indices + ] + updated_a[:, self.constraint_indices] = noised_composition.A[ + :, self.constraint_indices + ] + + composition_i = AXL(A=updated_a, X=updated_x, L=denoised_composition.L) + return composition_i def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of positional constraints. Args: number_of_samples : number of samples to draw. @@ -98,48 +142,36 @@ def sample( """ assert unit_cell.size() == ( number_of_samples, - self.spatial_dimension, - self.spatial_dimension, + self.generator.spatial_dimension, + self.generator.spatial_dimension, ), ( "Unit cell passed to sample should be of size (number of sample, spatial dimension, spatial dimension" + f"Got {unit_cell.size()}" ) - # Initialize a configuration that satisfy the constraint, but is otherwise random. - # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. - composition0_known = self.initialize(number_of_samples, device) - # this is an AXL objet - - composition0_known = self._apply_constraint(composition0_known, device) - - composition_ip1 = self.initialize(number_of_samples, device) + composition_ip1 = self.generator.initialize(number_of_samples, device) forces = torch.zeros_like(composition_ip1.X) - coordinates_broadcasting = torch.ones( - number_of_samples, self.number_of_atoms, self.spatial_dimension - ).to(device) - - for i in tqdm(range(self.number_of_discretization_steps - 1, -1, -1)): - sigma_i = self.noise.sigma[i] - broadcast_sigmas_i = sigma_i * coordinates_broadcasting - # Noise an example satisfying the constraints from t_0 to t_i - x_i_known = ( - self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( - composition0_known.X, broadcast_sigmas_i - ) + for i in tqdm(range(self.generator.number_of_discretization_steps - 1, -1, -1)): + + # Noise from t_0 to t_i + noised_composition_i = self._get_noised_known_composition( + i, number_of_samples, device ) + # Denoise from t_{i+1} to t_i - composition_i = self.predictor_step( + denoised_composition_i = self.generator.predictor_step( composition_ip1, i + 1, unit_cell, forces ) - # Combine the known and unknown - x_i = composition_i.X - x_i[:, self.constraint_mask] = x_i_known[:, self.constraint_mask] - composition_i = AXL(A=composition_i.A, X=x_i, L=composition_i.L) + composition_i = self._combine_noised_and_denoised_compositions( + noised_composition_i, denoised_composition_i + ) - for _ in range(self.number_of_corrector_steps): - composition_i = self.corrector_step(composition_i, i, unit_cell, forces) + for _ in range(self.generator.number_of_corrector_steps): + composition_i = self.generator.corrector_step( + composition_i, i, unit_cell, forces + ) composition_ip1 = composition_i diff --git a/tests/generators/test_constrained_langevin_generator.py b/tests/generators/test_constrained_langevin_generator.py index 59f2bb6d..0c898e5d 100644 --- a/tests/generators/test_constrained_langevin_generator.py +++ b/tests/generators/test_constrained_langevin_generator.py @@ -2,92 +2,179 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell from tests.generators.test_langevin_generator import TestLangevinGenerator class TestConstrainedLangevinGenerator(TestLangevinGenerator): @pytest.fixture() - def constrained_relative_coordinates(self, number_of_atoms, spatial_dimension): - number_of_constraints = number_of_atoms // 2 - return torch.rand(number_of_constraints, spatial_dimension).numpy() - - @pytest.fixture() - def sampling_parameters( + def reference_composition( self, number_of_atoms, spatial_dimension, - number_of_samples, - cell_dimensions, - number_of_corrector_steps, - unit_cell_size, - constrained_relative_coordinates, - num_atom_types, + num_atomic_classes, + device, ): - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=cell_dimensions, - spatial_dimension=spatial_dimension, - constrained_relative_coordinates=constrained_relative_coordinates, - num_atom_types=num_atom_types, - ) - - return sampling_parameters - - @pytest.fixture() - def pc_generator(self, noise_parameters, sampling_parameters, axl_network): - generator = ConstrainedLangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - axl_network=axl_network, + return AXL( + A=torch.randint(0, num_atomic_classes, (number_of_atoms,)).to(device), + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( + device + ), # TODO placeholder ) - return generator - @pytest.fixture() - def axl( + def random_compositions( self, number_of_samples, number_of_atoms, spatial_dimension, - num_atom_types, + num_atomic_classes, device, ): return AXL( A=torch.randint( - 0, num_atom_types + 1, (number_of_samples, number_of_atoms) + 0, + num_atomic_classes, + ( + number_of_samples, + number_of_atoms, + ), ).to(device), - X=torch.rand(number_of_samples, number_of_atoms, spatial_dimension).to( - device - ), - L=torch.rand( - number_of_samples, spatial_dimension * (spatial_dimension - 1) - ).to( + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( device ), # TODO placeholder ) + @pytest.fixture() + def constrained_atom_indices(self, number_of_atoms, device): + number_of_constraints = number_of_atoms // 2 + return torch.randperm(number_of_atoms)[:number_of_constraints].to(device) + + @pytest.fixture() + def constrained_pc_generator( + self, pc_generator, reference_composition, constrained_atom_indices + ): + constrained_generator = ConstrainedPredictorCorrectorAXLGenerator( + generator=pc_generator, + reference_composition=reference_composition, + constrained_atom_indices=constrained_atom_indices, + ) + + return constrained_generator + + @pytest.fixture() + def constrained_samples( + self, constrained_pc_generator, number_of_samples, device, unit_cell_sample + ): + samples = constrained_pc_generator.sample( + number_of_samples, device, unit_cell_sample + ) + return samples + + def test_constraints( + self, + constrained_samples, + reference_composition, + constrained_atom_indices, + number_of_samples, + ): + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + + torch.testing.assert_close( + constrained_samples.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_samples.A[:, constrained_atom_indices], reference_a + ) + def test_apply_constraint( - self, pc_generator, axl, constrained_relative_coordinates, device + self, + constrained_pc_generator, + number_of_samples, + random_compositions, + reference_composition, + constrained_atom_indices, + device, ): - batch_size = axl.X.shape[0] - original_x = torch.clone(axl.X) - pc_generator._apply_constraint(axl, device) - number_of_constraints = len(constrained_relative_coordinates) + constrained_compositions = constrained_pc_generator._apply_constraint( + random_compositions, device + ) - constrained_x = einops.repeat( - torch.from_numpy(constrained_relative_coordinates).to(device), - "n d -> b n d", - b=batch_size, + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, ) - torch.testing.assert_close(axl.X[:, :number_of_constraints], constrained_x) torch.testing.assert_close( - axl.X[:, number_of_constraints:], original_x[:, number_of_constraints:] + constrained_compositions.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_compositions.A[:, constrained_atom_indices], reference_a + ) + + def test_combine_noised_and_denoised_compositions( + self, + constrained_pc_generator, + constrained_atom_indices, + number_of_samples, + number_of_atoms, + spatial_dimension, + device, + ) -> AXL: + + noised_mask = torch.zeros(number_of_atoms, dtype=torch.bool).to(device) + noised_mask[constrained_atom_indices] = True + + noised_compositions = AXL( + A=torch.zeros(number_of_samples, number_of_atoms).to(device), + X=torch.zeros(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + denoised_compositions = AXL( + A=torch.ones(number_of_samples, number_of_atoms).to(device), + X=torch.ones(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + combined_compositions = ( + constrained_pc_generator._combine_noised_and_denoised_compositions( + noised_compositions, denoised_compositions + ) ) + + assert (combined_compositions.X[:, noised_mask] == 0.0).all() + assert (combined_compositions.X[:, ~noised_mask] == 1.0).all() + assert (combined_compositions.A[:, noised_mask] == 0.0).all() + assert (combined_compositions.A[:, ~noised_mask] == 1.0).all() From 381ecc6153070c82a03f070e72f07107894cab45 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:16:01 -0500 Subject: [PATCH 13/43] Modified code in experiments. Probably broken. --- .../analytic_score/repaint/repaint_with_analytic_score.py | 5 +++-- experiments/sampling_sota_model/repaint_with_sota_score.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py index 2ea37960..e75550c6 100644 --- a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py +++ b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py @@ -8,7 +8,8 @@ from diffusion_for_multi_scale_molecular_dynamics.analysis import \ PLOT_STYLE_PATH from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) + ConstrainedLangevinGeneratorParameters, + ConstrainedPredictorCorrectorAXLGenerator) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ @@ -98,7 +99,7 @@ record_samples=True, ) - position_generator = ConstrainedLangevinGenerator( + position_generator = ConstrainedPredictorCorrectorAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, diff --git a/experiments/sampling_sota_model/repaint_with_sota_score.py b/experiments/sampling_sota_model/repaint_with_sota_score.py index 8e8d7dad..0c45547b 100644 --- a/experiments/sampling_sota_model/repaint_with_sota_score.py +++ b/experiments/sampling_sota_model/repaint_with_sota_score.py @@ -11,7 +11,8 @@ from diffusion_for_multi_scale_molecular_dynamics.analysis import ( PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) + ConstrainedLangevinGeneratorParameters, + ConstrainedPredictorCorrectorAXLGenerator) from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ load_diffusion_model from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ @@ -84,7 +85,7 @@ record_samples=True, ) - position_generator = ConstrainedLangevinGenerator( + position_generator = ConstrainedPredictorCorrectorAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, From b4a5a25f1cd16a8ffc3cbfceac5be767d3864d79 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:18:54 -0500 Subject: [PATCH 14/43] Let's try to sample with constraints! --- .../constrained_langevin_generator.py | 3 ++ .../sample_diffusion.py | 54 +++++++++++++------ .../utils/ovito_utils.py | 13 +++++ 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 06dc9455..924010e7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,6 +30,9 @@ def __init__( """Init method.""" self.generator = generator + if hasattr(self.generator, "sample_trajectory_recorder"): + self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder + self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 1fcf2e9e..a296749c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -16,8 +16,12 @@ ElementTypes from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \ instantiate_generator +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.load_sampling_parameters import \ load_sampling_parameters from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import \ @@ -36,6 +40,8 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters +from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ + get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -51,6 +57,11 @@ def main(args: Optional[Any] = None): required=True, help="config file with sampling parameters in yaml format.", ) + parser.add_argument( + "--path_to_constraint_cif_file", required=False, + help="path to a cif file with constrained positions." + ) + parser.add_argument( "--checkpoint", required=True, help="path to checkpoint model to be loaded." ) @@ -101,12 +112,35 @@ def main(args: Optional[Any] = None): elements = hyper_params["elements"] oracle_parameters = create_energy_oracle_parameters(hyper_params["oracle"], elements) - create_samples_and_write_to_disk( + axl_network = get_axl_network(args.checkpoint) + + logger.info("Instantiate generator...") + raw_generator = instantiate_generator( + sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, + axl_network=axl_network, + ) + + if 'constrained_sampling' in hyper_params: + constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + cif_file_path = Path(args.path_to_constraint_cif_file) + assert cif_file_path.is_file(), "The constraint cif file does not exist." + + reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, + elements=elements, + device=device) + + generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, + reference_composition, + constrained_atom_indices) + else: + generator = raw_generator + + create_samples_and_write_to_disk( + generator=generator, sampling_parameters=sampling_parameters, oracle_parameters=oracle_parameters, device=device, - checkpoint_path=args.checkpoint, output_path=args.output, ) @@ -152,11 +186,10 @@ def get_axl_network(checkpoint_path: Union[str, Path]) -> ScoreNetwork: def create_samples_and_write_to_disk( - noise_parameters: NoiseParameters, + generator: LangevinGenerator, sampling_parameters: SamplingParameters, oracle_parameters: Union[OracleParameters, None], device: torch.device, - checkpoint_path: Union[str, Path], output_path: Union[str, Path], ): """Create Samples and write to disk. @@ -173,19 +206,10 @@ def create_samples_and_write_to_disk( Returns: None """ - axl_network = get_axl_network(checkpoint_path) - - logger.info("Instantiate generator...") - position_generator = instantiate_generator( - sampling_parameters=sampling_parameters, - noise_parameters=noise_parameters, - axl_network=axl_network, - ) - logger.info("Generating samples...") with torch.no_grad(): samples_batch = create_batch_of_samples( - generator=position_generator, + generator=generator, sampling_parameters=sampling_parameters, device=device, ) @@ -207,7 +231,7 @@ def create_samples_and_write_to_disk( if sampling_parameters.record_samples: logger.info("Writing sampling trajectories to disk...") - position_generator.sample_trajectory_recorder.write_to_pickle( + generator.sample_trajectory_recorder.write_to_pickle( output_directory / "trajectories.pt" ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py index b58722c6..5d6adb6e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py @@ -10,6 +10,7 @@ import numpy as np import ovito +import torch from ovito.io import import_file from ovito.modifiers import (AffineTransformationModifier, CombineDatasetsModifier, CreateBondsModifier) @@ -26,6 +27,18 @@ _cif_file_name_template = "diffusion_positions_step_{time_index}.cif" +def get_composition_from_cif_file(cif_file_path: Path, elements: list[str], device): + """Get composition from a cif file.""" + structure = Structure.from_file(cif_file_path) + element_types = ElementTypes(elements) + + a = torch.Tensor([element_types.get_element_id(s.name) for s in structure.species]).to(torch.int64).to(device) + x = torch.from_numpy(structure.frac_coords).to(torch.float32).to(device) + lattice = torch.from_numpy(structure.lattice.matrix).to(torch.float32).to(device) + composition = AXL(A=a, X=x, L=lattice) + return composition + + def create_cif_files( elements: list[str], visualization_artifacts_path: Path, From 642c8dcf34aef1e390d1c1ccc94e13191940b168 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:36:03 -0500 Subject: [PATCH 15/43] a bit more logging. --- .../sample_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index a296749c..f0a4c3fe 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -122,9 +122,12 @@ def main(args: Optional[Any] = None): ) if 'constrained_sampling' in hyper_params: + logger.info("Constrained Sampling is activated") constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." + logger.info(f"Constrained cif file is {cif_file_path}") + logger.info(f"Constrained atom indices are {constrained_atom_indices}") reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, elements=elements, From 9d747e0f1526d8a435edd606d006c9c605313973 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:38:22 -0500 Subject: [PATCH 16/43] Fix bjork. --- .../sample_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index f0a4c3fe..e35e1b39 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -123,7 +123,8 @@ def main(args: Optional[Any] = None): if 'constrained_sampling' in hyper_params: logger.info("Constrained Sampling is activated") - constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + constraint_dict = hyper_params['constrained_sampling'] + constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." logger.info(f"Constrained cif file is {cif_file_path}") From 06c3efaccc255568ef964a18151b5da237bd02b2 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:52:48 -0500 Subject: [PATCH 17/43] Fix bjork. --- .../sample_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index e35e1b39..2d575907 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -124,7 +124,7 @@ def main(args: Optional[Any] = None): if 'constrained_sampling' in hyper_params: logger.info("Constrained Sampling is activated") constraint_dict = hyper_params['constrained_sampling'] - constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']) + constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']).to(torch.int64) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." logger.info(f"Constrained cif file is {cif_file_path}") From 8184606377028c6d71326f38c2d5ede400775585 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 21:32:19 -0500 Subject: [PATCH 18/43] Use a pickle for constraints. --- .../sample_diffusion.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 2d575907..713d5133 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,8 +40,6 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters -from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ - get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -58,8 +56,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_cif_file", required=False, - help="path to a cif file with constrained positions." + "--path_to_constraint_data_pickle", required=False, + help="path to a pickle that contains a reference compositions and fixed atom indices." ) parser.add_argument( @@ -121,18 +119,16 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if 'constrained_sampling' in hyper_params: + if args.path_to_constraint_data_pickle: logger.info("Constrained Sampling is activated") - constraint_dict = hyper_params['constrained_sampling'] - constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']).to(torch.int64) - cif_file_path = Path(args.path_to_constraint_cif_file) - assert cif_file_path.is_file(), "The constraint cif file does not exist." - logger.info(f"Constrained cif file is {cif_file_path}") + constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) + assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." + + constraint_data = torch.load(constraint_data_pickle_path) + constrained_atom_indices = constraint_data["constrained_atom_indices"] logger.info(f"Constrained atom indices are {constrained_atom_indices}") - reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, - elements=elements, - device=device) + reference_composition = constraint_data["reference_composition"] generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From 74a3990b099efe6d520acd494337d94fc69d3e5c Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 07:40:36 -0500 Subject: [PATCH 19/43] Combine noised and denoised during corrector steps. --- .../generators/constrained_langevin_generator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 924010e7..0c597c44 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -132,7 +132,7 @@ def sample( ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of atomic constraints. Args: number_of_samples : number of samples to draw. @@ -172,9 +172,12 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - composition_i = self.generator.corrector_step( + corrected_composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) + composition_i = self._combine_noised_and_denoised_compositions( + corrected_composition_i, denoised_composition_i + ) composition_ip1 = composition_i From 061ea289a5c94b25ac3565c6b236bf7128dad96b Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 08:08:58 -0500 Subject: [PATCH 20/43] Turn off repaint in corrector step. --- .../generators/constrained_langevin_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 0c597c44..127cfd4f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -172,12 +172,9 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - corrected_composition_i = self.generator.corrector_step( + composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) - composition_i = self._combine_noised_and_denoised_compositions( - corrected_composition_i, denoised_composition_i - ) composition_ip1 = composition_i From a7fcf7119057b126be24bb7b30f0b5f585270b9f Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:15:16 -0500 Subject: [PATCH 21/43] Cleaner repaint. --- .../generators/constrained_langevin_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 127cfd4f..06dc9455 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,9 +30,6 @@ def __init__( """Init method.""" self.generator = generator - if hasattr(self.generator, "sample_trajectory_recorder"): - self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder - self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes @@ -132,7 +129,7 @@ def sample( ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of atomic constraints. + This method draws samples, imposing the satisfaction of positional constraints. Args: number_of_samples : number of samples to draw. From 4598d3f7ae1b96cd1a0f34cc6c001e698148496a Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:18:54 -0500 Subject: [PATCH 22/43] Let's try to sample with constraints! --- .../constrained_langevin_generator.py | 3 +++ .../sample_diffusion.py | 22 +++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 06dc9455..924010e7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,6 +30,9 @@ def __init__( """Init method.""" self.generator = generator + if hasattr(self.generator, "sample_trajectory_recorder"): + self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder + self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 713d5133..a296749c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,6 +40,8 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters +from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ + get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -56,8 +58,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_data_pickle", required=False, - help="path to a pickle that contains a reference compositions and fixed atom indices." + "--path_to_constraint_cif_file", required=False, + help="path to a cif file with constrained positions." ) parser.add_argument( @@ -119,16 +121,14 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if args.path_to_constraint_data_pickle: - logger.info("Constrained Sampling is activated") - constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) - assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." + if 'constrained_sampling' in hyper_params: + constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + cif_file_path = Path(args.path_to_constraint_cif_file) + assert cif_file_path.is_file(), "The constraint cif file does not exist." - constraint_data = torch.load(constraint_data_pickle_path) - constrained_atom_indices = constraint_data["constrained_atom_indices"] - logger.info(f"Constrained atom indices are {constrained_atom_indices}") - - reference_composition = constraint_data["reference_composition"] + reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, + elements=elements, + device=device) generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From f076a4b0d84acf2a0a63aa7868f4ea9c258cec30 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 11:53:47 -0500 Subject: [PATCH 23/43] Fix rebase bjork. --- .../sample_diffusion.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index a296749c..713d5133 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,8 +40,6 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters -from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ - get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -58,8 +56,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_cif_file", required=False, - help="path to a cif file with constrained positions." + "--path_to_constraint_data_pickle", required=False, + help="path to a pickle that contains a reference compositions and fixed atom indices." ) parser.add_argument( @@ -121,14 +119,16 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if 'constrained_sampling' in hyper_params: - constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) - cif_file_path = Path(args.path_to_constraint_cif_file) - assert cif_file_path.is_file(), "The constraint cif file does not exist." + if args.path_to_constraint_data_pickle: + logger.info("Constrained Sampling is activated") + constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) + assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." - reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, - elements=elements, - device=device) + constraint_data = torch.load(constraint_data_pickle_path) + constrained_atom_indices = constraint_data["constrained_atom_indices"] + logger.info(f"Constrained atom indices are {constrained_atom_indices}") + + reference_composition = constraint_data["reference_composition"] generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From 6b088c063286ea47f946561ba5a5d55971f4b669 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 11:58:44 -0500 Subject: [PATCH 24/43] Remove needless function. --- .../utils/ovito_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py index 5d6adb6e..b58722c6 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py @@ -10,7 +10,6 @@ import numpy as np import ovito -import torch from ovito.io import import_file from ovito.modifiers import (AffineTransformationModifier, CombineDatasetsModifier, CreateBondsModifier) @@ -27,18 +26,6 @@ _cif_file_name_template = "diffusion_positions_step_{time_index}.cif" -def get_composition_from_cif_file(cif_file_path: Path, elements: list[str], device): - """Get composition from a cif file.""" - structure = Structure.from_file(cif_file_path) - element_types = ElementTypes(elements) - - a = torch.Tensor([element_types.get_element_id(s.name) for s in structure.species]).to(torch.int64).to(device) - x = torch.from_numpy(structure.frac_coords).to(torch.float32).to(device) - lattice = torch.from_numpy(structure.lattice.matrix).to(torch.float32).to(device) - composition = AXL(A=a, X=x, L=lattice) - return composition - - def create_cif_files( elements: list[str], visualization_artifacts_path: Path, From f15578b58700d232ea436d1204a62527e0134e4f Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 13:53:47 -0500 Subject: [PATCH 25/43] Finish comments with something more definitive. --- .../sample_diffusion.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 26260536..1fcf2e9e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -211,6 +211,8 @@ def create_samples_and_write_to_disk( output_directory / "trajectories.pt" ) + logger.info("Done!") + if __name__ == "__main__": main() From 079e1088284d4283565dd6fa6a426ed50f90c2ec Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:15:16 -0500 Subject: [PATCH 26/43] Cleaner repaint. --- .../constrained_langevin_generator.py | 190 ++++++++++------- .../test_constrained_langevin_generator.py | 195 +++++++++++++----- 2 files changed, 252 insertions(+), 133 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 99db71ce..06dc9455 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -1,91 +1,135 @@ -from dataclasses import dataclass - -import numpy as np +import einops import torch from tqdm import tqdm from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ - PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ - ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noisers.atom_types_noiser import \ + AtomTypesNoiser from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ RelativeCoordinatesNoiser +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot -@dataclass(kw_only=True) -class ConstrainedLangevinGeneratorParameters(PredictorCorrectorSamplingParameters): - """Hyper-parameters for diffusion sampling with the predictor-corrector algorithm.""" - - algorithm: str = "constrained_langevin" - constrained_relative_coordinates: ( - np.ndarray - ) # the positions that must be satisfied at the end of sampling. - +class ConstrainedPredictorCorrectorAXLGenerator: + """Constrained Predictor Corrector AXL Generator. -class ConstrainedLangevinGenerator(LangevinGenerator): - """Constrained Annealed Langevin Dynamics Generator. - - This generator implements a basic version of the inpainting algorithm presented in the - paper + This class constrains the input PC generator following a basic version of the inpainting algorithm + presented in the paper "RePaint: Inpainting using Denoising Diffusion Probabilistic Models". """ def __init__( self, - noise_parameters: NoiseParameters, - sampling_parameters: ConstrainedLangevinGeneratorParameters, - axl_network: ScoreNetwork, + generator: LangevinGenerator, + reference_composition: AXL, + constrained_atom_indices: torch.Tensor, ): """Init method.""" - super().__init__(noise_parameters, sampling_parameters, axl_network) + self.generator = generator + + self.number_of_atoms = self.generator.number_of_atoms + self.num_classes = self.generator.num_classes - self.constraint_relative_coordinates = torch.from_numpy( - sampling_parameters.constrained_relative_coordinates - ) # TODO constraint the atom type as well + self.reference_composition = reference_composition + self.constraint_indices = constrained_atom_indices assert ( - len(self.constraint_relative_coordinates.shape) == 2 + len(self.reference_composition.X.shape) == 2 ), "The constrained relative coordinates have the wrong shape" - number_of_constraints, spatial_dimensions = ( - self.constraint_relative_coordinates.shape - ) - assert ( - number_of_constraints <= self.number_of_atoms - ), "There are more constrained positions than atoms!" assert ( - spatial_dimensions <= self.spatial_dimension - ), "The spatial dimension of the constrained positions is inconsistent" + len(self.reference_composition.A.shape) == 1 + ), "The constrained atom types have the wrong shape" - # Without loss of generality, we impose that the first positions are constrained. - # This should have no consequence for a permutation equivariant model. - self.constraint_mask = torch.zeros(self.number_of_atoms, dtype=bool) - self.constraint_mask[:number_of_constraints] = True + assert ( + len(constrained_atom_indices.shape) == 1 + ), "The constrained_atom_indices array has the wrong shape" self.relative_coordinates_noiser = RelativeCoordinatesNoiser() + self.atom_type_noiser = AtomTypesNoiser() def _apply_constraint(self, composition: AXL, device: torch.device) -> AXL: """This method applies the coordinate constraint on the input configuration.""" - x = composition.X - x[:, self.constraint_mask] = self.constraint_relative_coordinates.to(device) - updated_axl = AXL( - A=composition.A, - X=x, + constrained_x = composition.X.clone() + constrained_x[:, self.constraint_indices] = self.reference_composition.X[ + self.constraint_indices + ].to(device) + + constrained_a = composition.A.clone() + constrained_a[:, self.constraint_indices] = self.reference_composition.A[ + self.constraint_indices + ].to(device) + + constrained_composition = AXL( + A=constrained_a, + X=constrained_x, L=composition.L, ) - return updated_axl + return constrained_composition + + def _get_noised_known_composition( + self, i: int, number_of_samples: int, device: torch.device + ) -> AXL: + """This method applies the noise to the known composition.""" + # Initialize compositions that satisfies the constraint, but is otherwise random. + # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. + composition0_known = self.generator.initialize(number_of_samples, device) + composition0_known = self._apply_constraint(composition0_known, device) + + q_bar_matrices_i = einops.repeat( + self.generator.noise.q_bar_matrix[i].to(device), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=self.number_of_atoms, + ) + + sigma_i = self.generator.noise.sigma[i] + coordinates_broadcasting = torch.ones_like(composition0_known.X) + broadcast_sigmas_i = sigma_i * coordinates_broadcasting + + # Noise an example satisfying the constraints from t_0 to t_i + x_i_known = ( + self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( + composition0_known.X, broadcast_sigmas_i + ) + ) + + one_hot_a_i = class_index_to_onehot( + composition0_known.A, num_classes=self.num_classes + ) + a_i_known = self.atom_type_noiser.get_noisy_atom_types_sample( + one_hot_a_i, q_bar_matrices_i + ) + + noised_composition = AXL(A=a_i_known, X=x_i_known, L=composition0_known.L) + return noised_composition + + def _combine_noised_and_denoised_compositions( + self, noised_composition: AXL, denoised_composition: AXL + ) -> AXL: + + updated_x = denoised_composition.X.clone() + updated_a = denoised_composition.A.clone() + + updated_x[:, self.constraint_indices] = noised_composition.X[ + :, self.constraint_indices + ] + updated_a[:, self.constraint_indices] = noised_composition.A[ + :, self.constraint_indices + ] + + composition_i = AXL(A=updated_a, X=updated_x, L=denoised_composition.L) + return composition_i def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of positional constraints. Args: number_of_samples : number of samples to draw. @@ -98,48 +142,36 @@ def sample( """ assert unit_cell.size() == ( number_of_samples, - self.spatial_dimension, - self.spatial_dimension, + self.generator.spatial_dimension, + self.generator.spatial_dimension, ), ( "Unit cell passed to sample should be of size (number of sample, spatial dimension, spatial dimension" + f"Got {unit_cell.size()}" ) - # Initialize a configuration that satisfy the constraint, but is otherwise random. - # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. - composition0_known = self.initialize(number_of_samples, device) - # this is an AXL objet - - composition0_known = self._apply_constraint(composition0_known, device) - - composition_ip1 = self.initialize(number_of_samples, device) + composition_ip1 = self.generator.initialize(number_of_samples, device) forces = torch.zeros_like(composition_ip1.X) - coordinates_broadcasting = torch.ones( - number_of_samples, self.number_of_atoms, self.spatial_dimension - ).to(device) - - for i in tqdm(range(self.number_of_discretization_steps - 1, -1, -1)): - sigma_i = self.noise.sigma[i] - broadcast_sigmas_i = sigma_i * coordinates_broadcasting - # Noise an example satisfying the constraints from t_0 to t_i - x_i_known = ( - self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( - composition0_known.X, broadcast_sigmas_i - ) + for i in tqdm(range(self.generator.number_of_discretization_steps - 1, -1, -1)): + + # Noise from t_0 to t_i + noised_composition_i = self._get_noised_known_composition( + i, number_of_samples, device ) + # Denoise from t_{i+1} to t_i - composition_i = self.predictor_step( + denoised_composition_i = self.generator.predictor_step( composition_ip1, i + 1, unit_cell, forces ) - # Combine the known and unknown - x_i = composition_i.X - x_i[:, self.constraint_mask] = x_i_known[:, self.constraint_mask] - composition_i = AXL(A=composition_i.A, X=x_i, L=composition_i.L) + composition_i = self._combine_noised_and_denoised_compositions( + noised_composition_i, denoised_composition_i + ) - for _ in range(self.number_of_corrector_steps): - composition_i = self.corrector_step(composition_i, i, unit_cell, forces) + for _ in range(self.generator.number_of_corrector_steps): + composition_i = self.generator.corrector_step( + composition_i, i, unit_cell, forces + ) composition_ip1 = composition_i diff --git a/tests/generators/test_constrained_langevin_generator.py b/tests/generators/test_constrained_langevin_generator.py index 59f2bb6d..0c898e5d 100644 --- a/tests/generators/test_constrained_langevin_generator.py +++ b/tests/generators/test_constrained_langevin_generator.py @@ -2,92 +2,179 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell from tests.generators.test_langevin_generator import TestLangevinGenerator class TestConstrainedLangevinGenerator(TestLangevinGenerator): @pytest.fixture() - def constrained_relative_coordinates(self, number_of_atoms, spatial_dimension): - number_of_constraints = number_of_atoms // 2 - return torch.rand(number_of_constraints, spatial_dimension).numpy() - - @pytest.fixture() - def sampling_parameters( + def reference_composition( self, number_of_atoms, spatial_dimension, - number_of_samples, - cell_dimensions, - number_of_corrector_steps, - unit_cell_size, - constrained_relative_coordinates, - num_atom_types, + num_atomic_classes, + device, ): - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=cell_dimensions, - spatial_dimension=spatial_dimension, - constrained_relative_coordinates=constrained_relative_coordinates, - num_atom_types=num_atom_types, - ) - - return sampling_parameters - - @pytest.fixture() - def pc_generator(self, noise_parameters, sampling_parameters, axl_network): - generator = ConstrainedLangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - axl_network=axl_network, + return AXL( + A=torch.randint(0, num_atomic_classes, (number_of_atoms,)).to(device), + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( + device + ), # TODO placeholder ) - return generator - @pytest.fixture() - def axl( + def random_compositions( self, number_of_samples, number_of_atoms, spatial_dimension, - num_atom_types, + num_atomic_classes, device, ): return AXL( A=torch.randint( - 0, num_atom_types + 1, (number_of_samples, number_of_atoms) + 0, + num_atomic_classes, + ( + number_of_samples, + number_of_atoms, + ), ).to(device), - X=torch.rand(number_of_samples, number_of_atoms, spatial_dimension).to( - device - ), - L=torch.rand( - number_of_samples, spatial_dimension * (spatial_dimension - 1) - ).to( + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( device ), # TODO placeholder ) + @pytest.fixture() + def constrained_atom_indices(self, number_of_atoms, device): + number_of_constraints = number_of_atoms // 2 + return torch.randperm(number_of_atoms)[:number_of_constraints].to(device) + + @pytest.fixture() + def constrained_pc_generator( + self, pc_generator, reference_composition, constrained_atom_indices + ): + constrained_generator = ConstrainedPredictorCorrectorAXLGenerator( + generator=pc_generator, + reference_composition=reference_composition, + constrained_atom_indices=constrained_atom_indices, + ) + + return constrained_generator + + @pytest.fixture() + def constrained_samples( + self, constrained_pc_generator, number_of_samples, device, unit_cell_sample + ): + samples = constrained_pc_generator.sample( + number_of_samples, device, unit_cell_sample + ) + return samples + + def test_constraints( + self, + constrained_samples, + reference_composition, + constrained_atom_indices, + number_of_samples, + ): + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + + torch.testing.assert_close( + constrained_samples.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_samples.A[:, constrained_atom_indices], reference_a + ) + def test_apply_constraint( - self, pc_generator, axl, constrained_relative_coordinates, device + self, + constrained_pc_generator, + number_of_samples, + random_compositions, + reference_composition, + constrained_atom_indices, + device, ): - batch_size = axl.X.shape[0] - original_x = torch.clone(axl.X) - pc_generator._apply_constraint(axl, device) - number_of_constraints = len(constrained_relative_coordinates) + constrained_compositions = constrained_pc_generator._apply_constraint( + random_compositions, device + ) - constrained_x = einops.repeat( - torch.from_numpy(constrained_relative_coordinates).to(device), - "n d -> b n d", - b=batch_size, + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, ) - torch.testing.assert_close(axl.X[:, :number_of_constraints], constrained_x) torch.testing.assert_close( - axl.X[:, number_of_constraints:], original_x[:, number_of_constraints:] + constrained_compositions.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_compositions.A[:, constrained_atom_indices], reference_a + ) + + def test_combine_noised_and_denoised_compositions( + self, + constrained_pc_generator, + constrained_atom_indices, + number_of_samples, + number_of_atoms, + spatial_dimension, + device, + ) -> AXL: + + noised_mask = torch.zeros(number_of_atoms, dtype=torch.bool).to(device) + noised_mask[constrained_atom_indices] = True + + noised_compositions = AXL( + A=torch.zeros(number_of_samples, number_of_atoms).to(device), + X=torch.zeros(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + denoised_compositions = AXL( + A=torch.ones(number_of_samples, number_of_atoms).to(device), + X=torch.ones(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + combined_compositions = ( + constrained_pc_generator._combine_noised_and_denoised_compositions( + noised_compositions, denoised_compositions + ) ) + + assert (combined_compositions.X[:, noised_mask] == 0.0).all() + assert (combined_compositions.X[:, ~noised_mask] == 1.0).all() + assert (combined_compositions.A[:, noised_mask] == 0.0).all() + assert (combined_compositions.A[:, ~noised_mask] == 1.0).all() From a2a62316276f392dc6f89ac4d78fc2ad1fa760ac Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:16:01 -0500 Subject: [PATCH 27/43] Modified code in experiments. Probably broken. --- .../analytic_score/repaint/repaint_with_analytic_score.py | 5 +++-- experiments/sampling_sota_model/repaint_with_sota_score.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py index 2ea37960..e75550c6 100644 --- a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py +++ b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py @@ -8,7 +8,8 @@ from diffusion_for_multi_scale_molecular_dynamics.analysis import \ PLOT_STYLE_PATH from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) + ConstrainedLangevinGeneratorParameters, + ConstrainedPredictorCorrectorAXLGenerator) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ @@ -98,7 +99,7 @@ record_samples=True, ) - position_generator = ConstrainedLangevinGenerator( + position_generator = ConstrainedPredictorCorrectorAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, diff --git a/experiments/sampling_sota_model/repaint_with_sota_score.py b/experiments/sampling_sota_model/repaint_with_sota_score.py index 8e8d7dad..0c45547b 100644 --- a/experiments/sampling_sota_model/repaint_with_sota_score.py +++ b/experiments/sampling_sota_model/repaint_with_sota_score.py @@ -11,7 +11,8 @@ from diffusion_for_multi_scale_molecular_dynamics.analysis import ( PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) + ConstrainedLangevinGeneratorParameters, + ConstrainedPredictorCorrectorAXLGenerator) from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ load_diffusion_model from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ @@ -84,7 +85,7 @@ record_samples=True, ) - position_generator = ConstrainedLangevinGenerator( + position_generator = ConstrainedPredictorCorrectorAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, sigma_normalized_score_network=sigma_normalized_score_network, From bb9aed79c05d56084cb4485917a5cf51275bbb08 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:18:54 -0500 Subject: [PATCH 28/43] Let's try to sample with constraints! --- .../constrained_langevin_generator.py | 3 ++ .../sample_diffusion.py | 54 +++++++++++++------ .../utils/ovito_utils.py | 13 +++++ 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 06dc9455..924010e7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,6 +30,9 @@ def __init__( """Init method.""" self.generator = generator + if hasattr(self.generator, "sample_trajectory_recorder"): + self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder + self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 1fcf2e9e..a296749c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -16,8 +16,12 @@ ElementTypes from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \ instantiate_generator +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.load_sampling_parameters import \ load_sampling_parameters from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import \ @@ -36,6 +40,8 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters +from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ + get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -51,6 +57,11 @@ def main(args: Optional[Any] = None): required=True, help="config file with sampling parameters in yaml format.", ) + parser.add_argument( + "--path_to_constraint_cif_file", required=False, + help="path to a cif file with constrained positions." + ) + parser.add_argument( "--checkpoint", required=True, help="path to checkpoint model to be loaded." ) @@ -101,12 +112,35 @@ def main(args: Optional[Any] = None): elements = hyper_params["elements"] oracle_parameters = create_energy_oracle_parameters(hyper_params["oracle"], elements) - create_samples_and_write_to_disk( + axl_network = get_axl_network(args.checkpoint) + + logger.info("Instantiate generator...") + raw_generator = instantiate_generator( + sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, + axl_network=axl_network, + ) + + if 'constrained_sampling' in hyper_params: + constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + cif_file_path = Path(args.path_to_constraint_cif_file) + assert cif_file_path.is_file(), "The constraint cif file does not exist." + + reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, + elements=elements, + device=device) + + generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, + reference_composition, + constrained_atom_indices) + else: + generator = raw_generator + + create_samples_and_write_to_disk( + generator=generator, sampling_parameters=sampling_parameters, oracle_parameters=oracle_parameters, device=device, - checkpoint_path=args.checkpoint, output_path=args.output, ) @@ -152,11 +186,10 @@ def get_axl_network(checkpoint_path: Union[str, Path]) -> ScoreNetwork: def create_samples_and_write_to_disk( - noise_parameters: NoiseParameters, + generator: LangevinGenerator, sampling_parameters: SamplingParameters, oracle_parameters: Union[OracleParameters, None], device: torch.device, - checkpoint_path: Union[str, Path], output_path: Union[str, Path], ): """Create Samples and write to disk. @@ -173,19 +206,10 @@ def create_samples_and_write_to_disk( Returns: None """ - axl_network = get_axl_network(checkpoint_path) - - logger.info("Instantiate generator...") - position_generator = instantiate_generator( - sampling_parameters=sampling_parameters, - noise_parameters=noise_parameters, - axl_network=axl_network, - ) - logger.info("Generating samples...") with torch.no_grad(): samples_batch = create_batch_of_samples( - generator=position_generator, + generator=generator, sampling_parameters=sampling_parameters, device=device, ) @@ -207,7 +231,7 @@ def create_samples_and_write_to_disk( if sampling_parameters.record_samples: logger.info("Writing sampling trajectories to disk...") - position_generator.sample_trajectory_recorder.write_to_pickle( + generator.sample_trajectory_recorder.write_to_pickle( output_directory / "trajectories.pt" ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py index b58722c6..5d6adb6e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py @@ -10,6 +10,7 @@ import numpy as np import ovito +import torch from ovito.io import import_file from ovito.modifiers import (AffineTransformationModifier, CombineDatasetsModifier, CreateBondsModifier) @@ -26,6 +27,18 @@ _cif_file_name_template = "diffusion_positions_step_{time_index}.cif" +def get_composition_from_cif_file(cif_file_path: Path, elements: list[str], device): + """Get composition from a cif file.""" + structure = Structure.from_file(cif_file_path) + element_types = ElementTypes(elements) + + a = torch.Tensor([element_types.get_element_id(s.name) for s in structure.species]).to(torch.int64).to(device) + x = torch.from_numpy(structure.frac_coords).to(torch.float32).to(device) + lattice = torch.from_numpy(structure.lattice.matrix).to(torch.float32).to(device) + composition = AXL(A=a, X=x, L=lattice) + return composition + + def create_cif_files( elements: list[str], visualization_artifacts_path: Path, From 43303b124a6c033df7ec975ee299d436887983f7 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:36:03 -0500 Subject: [PATCH 29/43] a bit more logging. --- .../sample_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index a296749c..f0a4c3fe 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -122,9 +122,12 @@ def main(args: Optional[Any] = None): ) if 'constrained_sampling' in hyper_params: + logger.info("Constrained Sampling is activated") constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." + logger.info(f"Constrained cif file is {cif_file_path}") + logger.info(f"Constrained atom indices are {constrained_atom_indices}") reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, elements=elements, From dfe3cd82ffbc060464318cf0c4013c1ff50aa165 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:38:22 -0500 Subject: [PATCH 30/43] Fix bjork. --- .../sample_diffusion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index f0a4c3fe..e35e1b39 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -123,7 +123,8 @@ def main(args: Optional[Any] = None): if 'constrained_sampling' in hyper_params: logger.info("Constrained Sampling is activated") - constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + constraint_dict = hyper_params['constrained_sampling'] + constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." logger.info(f"Constrained cif file is {cif_file_path}") From c4cf40e734e49f454c1c27d40e7c3c4585d12f37 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:52:48 -0500 Subject: [PATCH 31/43] Fix bjork. --- .../sample_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index e35e1b39..2d575907 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -124,7 +124,7 @@ def main(args: Optional[Any] = None): if 'constrained_sampling' in hyper_params: logger.info("Constrained Sampling is activated") constraint_dict = hyper_params['constrained_sampling'] - constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']) + constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']).to(torch.int64) cif_file_path = Path(args.path_to_constraint_cif_file) assert cif_file_path.is_file(), "The constraint cif file does not exist." logger.info(f"Constrained cif file is {cif_file_path}") From 16bc27d1dcff4f6544c1fc6cbdca7cdf683ce15c Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 21:32:19 -0500 Subject: [PATCH 32/43] Use a pickle for constraints. --- .../sample_diffusion.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 2d575907..713d5133 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,8 +40,6 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters -from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ - get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -58,8 +56,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_cif_file", required=False, - help="path to a cif file with constrained positions." + "--path_to_constraint_data_pickle", required=False, + help="path to a pickle that contains a reference compositions and fixed atom indices." ) parser.add_argument( @@ -121,18 +119,16 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if 'constrained_sampling' in hyper_params: + if args.path_to_constraint_data_pickle: logger.info("Constrained Sampling is activated") - constraint_dict = hyper_params['constrained_sampling'] - constrained_atom_indices = torch.Tensor(constraint_dict['constrained_atom_indices']).to(torch.int64) - cif_file_path = Path(args.path_to_constraint_cif_file) - assert cif_file_path.is_file(), "The constraint cif file does not exist." - logger.info(f"Constrained cif file is {cif_file_path}") + constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) + assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." + + constraint_data = torch.load(constraint_data_pickle_path) + constrained_atom_indices = constraint_data["constrained_atom_indices"] logger.info(f"Constrained atom indices are {constrained_atom_indices}") - reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, - elements=elements, - device=device) + reference_composition = constraint_data["reference_composition"] generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From c745b04130994a2c19673e6b53f0c5e5945f7415 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 07:40:36 -0500 Subject: [PATCH 33/43] Combine noised and denoised during corrector steps. --- .../generators/constrained_langevin_generator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 924010e7..0c597c44 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -132,7 +132,7 @@ def sample( ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of atomic constraints. Args: number_of_samples : number of samples to draw. @@ -172,9 +172,12 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - composition_i = self.generator.corrector_step( + corrected_composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) + composition_i = self._combine_noised_and_denoised_compositions( + corrected_composition_i, denoised_composition_i + ) composition_ip1 = composition_i From 8d57ed86db28e222258341e2b5cff8cca748b620 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 08:08:58 -0500 Subject: [PATCH 34/43] Turn off repaint in corrector step. --- .../generators/constrained_langevin_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 0c597c44..127cfd4f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -172,12 +172,9 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - corrected_composition_i = self.generator.corrector_step( + composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) - composition_i = self._combine_noised_and_denoised_compositions( - corrected_composition_i, denoised_composition_i - ) composition_ip1 = composition_i From 918c50bebef918e4428b4f91ce3fba53960a198a Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 07:40:36 -0500 Subject: [PATCH 35/43] Combine noised and denoised during corrector steps. --- .../generators/constrained_langevin_generator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 127cfd4f..0c597c44 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -172,9 +172,12 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - composition_i = self.generator.corrector_step( + corrected_composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) + composition_i = self._combine_noised_and_denoised_compositions( + corrected_composition_i, denoised_composition_i + ) composition_ip1 = composition_i From a71c4cf8929e9ed74b49e6fcd29aefd0b69aee59 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 08:08:58 -0500 Subject: [PATCH 36/43] Turn off repaint in corrector step. --- .../generators/constrained_langevin_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 0c597c44..127cfd4f 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -172,12 +172,9 @@ def sample( ) for _ in range(self.generator.number_of_corrector_steps): - corrected_composition_i = self.generator.corrector_step( + composition_i = self.generator.corrector_step( composition_i, i, unit_cell, forces ) - composition_i = self._combine_noised_and_denoised_compositions( - corrected_composition_i, denoised_composition_i - ) composition_ip1 = composition_i From c1d2ad123d674969ccd2174e54164d18ebc62309 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 16:15:16 -0500 Subject: [PATCH 37/43] Cleaner repaint. --- .../generators/constrained_langevin_generator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 127cfd4f..06dc9455 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,9 +30,6 @@ def __init__( """Init method.""" self.generator = generator - if hasattr(self.generator, "sample_trajectory_recorder"): - self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder - self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes @@ -132,7 +129,7 @@ def sample( ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of atomic constraints. + This method draws samples, imposing the satisfaction of positional constraints. Args: number_of_samples : number of samples to draw. From 891416a68f6c9c869c785ef68ef1f4ca56723841 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Mon, 2 Dec 2024 19:18:54 -0500 Subject: [PATCH 38/43] Let's try to sample with constraints! --- .../constrained_langevin_generator.py | 3 +++ .../sample_diffusion.py | 22 +++++++++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 06dc9455..924010e7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -30,6 +30,9 @@ def __init__( """Init method.""" self.generator = generator + if hasattr(self.generator, "sample_trajectory_recorder"): + self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder + self.number_of_atoms = self.generator.number_of_atoms self.num_classes = self.generator.num_classes diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 713d5133..a296749c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,6 +40,8 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters +from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ + get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -56,8 +58,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_data_pickle", required=False, - help="path to a pickle that contains a reference compositions and fixed atom indices." + "--path_to_constraint_cif_file", required=False, + help="path to a cif file with constrained positions." ) parser.add_argument( @@ -119,16 +121,14 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if args.path_to_constraint_data_pickle: - logger.info("Constrained Sampling is activated") - constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) - assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." + if 'constrained_sampling' in hyper_params: + constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) + cif_file_path = Path(args.path_to_constraint_cif_file) + assert cif_file_path.is_file(), "The constraint cif file does not exist." - constraint_data = torch.load(constraint_data_pickle_path) - constrained_atom_indices = constraint_data["constrained_atom_indices"] - logger.info(f"Constrained atom indices are {constrained_atom_indices}") - - reference_composition = constraint_data["reference_composition"] + reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, + elements=elements, + device=device) generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From a2bb4be79f7a13083c065434c1377be9aca743bb Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 11:53:47 -0500 Subject: [PATCH 39/43] Fix rebase bjork. --- .../sample_diffusion.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index a296749c..713d5133 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -40,8 +40,6 @@ get_git_hash, setup_console_logger) from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ load_and_backup_hyperparameters -from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ - get_composition_from_cif_file logger = logging.getLogger(__name__) @@ -58,8 +56,8 @@ def main(args: Optional[Any] = None): help="config file with sampling parameters in yaml format.", ) parser.add_argument( - "--path_to_constraint_cif_file", required=False, - help="path to a cif file with constrained positions." + "--path_to_constraint_data_pickle", required=False, + help="path to a pickle that contains a reference compositions and fixed atom indices." ) parser.add_argument( @@ -121,14 +119,16 @@ def main(args: Optional[Any] = None): axl_network=axl_network, ) - if 'constrained_sampling' in hyper_params: - constrained_atom_indices = torch.Tensor(hyper_params['constrained_atom_indices']) - cif_file_path = Path(args.path_to_constraint_cif_file) - assert cif_file_path.is_file(), "The constraint cif file does not exist." + if args.path_to_constraint_data_pickle: + logger.info("Constrained Sampling is activated") + constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) + assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." - reference_composition = get_composition_from_cif_file(cif_file_path=cif_file_path, - elements=elements, - device=device) + constraint_data = torch.load(constraint_data_pickle_path) + constrained_atom_indices = constraint_data["constrained_atom_indices"] + logger.info(f"Constrained atom indices are {constrained_atom_indices}") + + reference_composition = constraint_data["reference_composition"] generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, reference_composition, From c687fe250a2227ccd18173bab7288c077ee3a70e Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 11:58:44 -0500 Subject: [PATCH 40/43] Remove needless function. --- .../utils/ovito_utils.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py index 5d6adb6e..b58722c6 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py @@ -10,7 +10,6 @@ import numpy as np import ovito -import torch from ovito.io import import_file from ovito.modifiers import (AffineTransformationModifier, CombineDatasetsModifier, CreateBondsModifier) @@ -27,18 +26,6 @@ _cif_file_name_template = "diffusion_positions_step_{time_index}.cif" -def get_composition_from_cif_file(cif_file_path: Path, elements: list[str], device): - """Get composition from a cif file.""" - structure = Structure.from_file(cif_file_path) - element_types = ElementTypes(elements) - - a = torch.Tensor([element_types.get_element_id(s.name) for s in structure.species]).to(torch.int64).to(device) - x = torch.from_numpy(structure.frac_coords).to(torch.float32).to(device) - lattice = torch.from_numpy(structure.lattice.matrix).to(torch.float32).to(device) - composition = AXL(A=a, X=x, L=lattice) - return composition - - def create_cif_files( elements: list[str], visualization_artifacts_path: Path, From 99bf118202407cd276b759f1be98e7557ca9cb90 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 13:46:34 -0500 Subject: [PATCH 41/43] Remove broken experiment. --- .../repaint/repaint_with_analytic_score.py | 149 ------------------ 1 file changed, 149 deletions(-) delete mode 100644 experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py diff --git a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py deleted file mode 100644 index e75550c6..00000000 --- a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py +++ /dev/null @@ -1,149 +0,0 @@ -import logging - -import matplotlib.pyplot as plt -import numpy as np -import torch - -from diffusion_for_multi_scale_molecular_dynamics import ANALYSIS_RESULTS_DIR -from diffusion_for_multi_scale_molecular_dynamics.analysis import \ - PLOT_STYLE_PATH -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGeneratorParameters, - ConstrainedPredictorCorrectorAXLGenerator) -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( - AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ - setup_analysis_logger -from diffusion_for_multi_scale_molecular_dynamics.utils.structure_utils import \ - create_structure -from experiments.analysis.analytic_score.utils import ( - get_samples_harmonic_energy, get_silicon_supercell, get_unit_cells) - -logger = logging.getLogger(__name__) -setup_analysis_logger() - -repaint_dir = ANALYSIS_RESULTS_DIR / "ANALYTIC_SCORE/REPAINT" -repaint_dir.mkdir(exist_ok=True) - -plt.style.use(PLOT_STYLE_PATH) - -if torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") - -kmax = 1 -supercell_factor = 1 -variance_parameter = 0.001 / supercell_factor -number_of_samples = 1000 -total_time_steps = 100 -number_of_corrector_steps = 1 - -acell = 5.43 # Angstroms. - -constrained_relative_coordinates = np.array( - [[0.5, 0.5, 0.25], [0.5, 0.5, 0.5], [0.5, 0.5, 0.75]], dtype=np.float32 -) - -translation = np.array([0.125, 0.125, 0.125]).astype(np.float32) -if __name__ == "__main__": - logger.info("Setting up parameters") - - equilibrium_relative_coordinates = get_silicon_supercell( - supercell_factor=supercell_factor - ).astype(np.float32) - # Translate to avoid atoms right on the cell boundary - equilibrium_relative_coordinates = equilibrium_relative_coordinates + translation - - number_of_atoms, spatial_dimension = equilibrium_relative_coordinates.shape - - logger.info("Creating samples from the exact distribution") - inverse_covariance = ( - torch.diag(torch.ones(number_of_atoms * spatial_dimension)) / variance_parameter - ).to(device) - inverse_covariance = inverse_covariance.reshape( - number_of_atoms, spatial_dimension, number_of_atoms, spatial_dimension - ) - - unit_cells = get_unit_cells( - acell=acell, - spatial_dimension=spatial_dimension, - number_of_samples=number_of_samples, - ) - - noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, sigma_min=0.001, sigma_max=0.5 - ) - - score_network_parameters = AnalyticalScoreNetworkParameters( - number_of_atoms=number_of_atoms, - spatial_dimension=spatial_dimension, - kmax=kmax, - equilibrium_relative_coordinates=torch.from_numpy( - equilibrium_relative_coordinates - ).to(device), - variance_parameter=variance_parameter, - ) - - sigma_normalized_score_network = AnalyticalScoreNetwork(score_network_parameters) - - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - spatial_dimension=spatial_dimension, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=3 * [acell], - constrained_relative_coordinates=constrained_relative_coordinates, - record_samples=True, - ) - - position_generator = ConstrainedPredictorCorrectorAXLGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, - ) - - logger.info("Drawing constrained samples") - samples = position_generator.sample( - number_of_samples=number_of_samples, device=device, unit_cell=unit_cells - ).detach() - - position_generator.sample_trajectory_recorder.write_to_pickle( - repaint_dir / "repaint_trajectories.pkl" - ) - - logger.info("Computing harmonic energies") - sampled_harmonic_energies = get_samples_harmonic_energy( - equilibrium_relative_coordinates, inverse_covariance, samples - ) - - with open(repaint_dir / "harmonic_energies.pt", "wb") as fd: - torch.save(sampled_harmonic_energies, fd) - - logger.info("Creating CIF files") - - lattice_basis_vectors = np.diag([acell, acell, acell]) - species = number_of_atoms * ["Si"] - - relative_coordinates = equilibrium_relative_coordinates - equilibrium_structure = create_structure( - lattice_basis_vectors, relative_coordinates, species - ) - equilibrium_structure.to_file(str(repaint_dir / "equilibrium_positions.cif")) - - relative_coordinates[: len(constrained_relative_coordinates)] = ( - constrained_relative_coordinates - ) - forced_structure = create_structure( - lattice_basis_vectors, relative_coordinates, species - ) - forced_structure.to_file(str(repaint_dir / "forced_constraint_positions.cif")) - - samples_dir = repaint_dir / "samples" - samples_dir.mkdir(exist_ok=True) - - for idx, sample in enumerate(samples.cpu().numpy()): - structure = create_structure(lattice_basis_vectors, sample, species) - structure.to_file(str(samples_dir / f"sample_{idx}.cif")) From cfb6ae93cecc0b2c5e69b8858cab8b7b6dc4cee7 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 13:47:17 -0500 Subject: [PATCH 42/43] Remove broken experiment. --- .../repaint_with_sota_score.py | 146 ------------------ 1 file changed, 146 deletions(-) delete mode 100644 experiments/sampling_sota_model/repaint_with_sota_score.py diff --git a/experiments/sampling_sota_model/repaint_with_sota_score.py b/experiments/sampling_sota_model/repaint_with_sota_score.py deleted file mode 100644 index 0c45547b..00000000 --- a/experiments/sampling_sota_model/repaint_with_sota_score.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -import yaml - -from diffusion_for_multi_scale_molecular_dynamics import DATA_DIR -from diffusion_for_multi_scale_molecular_dynamics.analysis import ( - PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGeneratorParameters, - ConstrainedPredictorCorrectorAXLGenerator) -from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ - load_diffusion_model -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ - get_energy_and_forces_from_lammps -from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ - setup_analysis_logger - -logger = logging.getLogger(__name__) -setup_analysis_logger() - - -experiments_dir = Path("/home/mila/r/rousseab/experiments/") -model_dir = experiments_dir / "checkpoints/sota_model/" -state_dict_path = model_dir / "last_model-epoch=199-step=019600_state_dict.ckpt" -config_path = model_dir / "config_backup.yaml" - -repaint_dir = Path("/home/mila/r/rousseab/experiments/draw_sota_samples/repaint") -repaint_dir.mkdir(exist_ok=True) - -plt.style.use(PLOT_STYLE_PATH) - -device = torch.device("cuda") - -number_of_samples = 1000 -total_time_steps = 100 -number_of_corrector_steps = 1 - -acell = 5.43 # Angstroms. -box = np.diag([acell, acell, acell]) - -number_of_atoms, spatial_dimension = 8, 3 -atom_types = np.ones(number_of_atoms, dtype=int) - -constrained_relative_coordinates = np.array( - [[0.5, 0.5, 0.25], [0.5, 0.5, 0.5], [0.5, 0.5, 0.75]], dtype=np.float32 -) - -if __name__ == "__main__": - logger.info("Setting up parameters") - - unit_cells = torch.Tensor(box).repeat(number_of_samples, 1, 1).to(device) - - noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, sigma_min=0.001, sigma_max=0.5 - ) - - logger.info("Loading state dict") - with open(str(state_dict_path), "rb") as fd: - state_dict = torch.load(fd) - - with open(str(config_path), "r") as fd: - hyper_params = yaml.load(fd, Loader=yaml.FullLoader) - logger.info("Instantiate model") - pl_model = load_diffusion_model(hyper_params) - pl_model.load_state_dict(state_dict=state_dict) - pl_model.to(device) - pl_model.eval() - - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - spatial_dimension=spatial_dimension, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=3 * [acell], - constrained_relative_coordinates=constrained_relative_coordinates, - record_samples=True, - ) - - position_generator = ConstrainedPredictorCorrectorAXLGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, - ) - - logger.info("Drawing constrained samples") - - with torch.no_grad(): - samples = position_generator.sample( - number_of_samples=number_of_samples, device=device, unit_cell=unit_cells - ) - - batch_relative_positions = samples.cpu().numpy() - batch_positions = np.dot(batch_relative_positions, box) - - position_generator.sample_trajectory_recorder.write_to_pickle( - repaint_dir / "repaint_trajectories.pkl" - ) - - logger.info("Compute energy from Oracle") - list_energy = [] - - lammps_work_directory = repaint_dir / "samples" - lammps_work_directory.mkdir(exist_ok=True) - - for idx, positions in enumerate(batch_positions): - energy, forces = get_energy_and_forces_from_lammps( - positions, - box, - atom_types, - tmp_work_dir=str(lammps_work_directory), - pair_coeff_dir=DATA_DIR, - ) - list_energy.append(energy) - src = os.path.join(lammps_work_directory, "dump.yaml") - dst = os.path.join(lammps_work_directory, f"dump_{idx}.yaml") - os.rename(src, dst) - - energies = np.array(list_energy) - - with open(repaint_dir / "energies.pt", "wb") as fd: - torch.save(torch.from_numpy(energies), fd) - - logger.info("Plotting energy distributions") - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - fig.suptitle("Energy Distribution for Repaint Structures,") - - common_params = dict(density=True, bins=50, histtype="stepfilled", alpha=0.25) - - ax1 = fig.add_subplot(111) - ax1.hist(energies, **common_params, label="Sampled Energies", color="red") - - ax1.set_xlabel("Energy (eV)") - ax1.set_ylabel("Density") - ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=12) - fig.tight_layout() - fig.savefig(repaint_dir / f"energy_samples_repaint_{number_of_atoms}_atoms.png") - plt.close(fig) From 5f858e6a708e233c4a605288605d822819fe5564 Mon Sep 17 00:00:00 2001 From: Bruno Rousseau Date: Tue, 3 Dec 2024 14:32:26 -0500 Subject: [PATCH 43/43] test constrained sampling. --- tests/test_sample_diffusion.py | 59 +++++++++++++++++++++++++++++----- 1 file changed, 51 insertions(+), 8 deletions(-) diff --git a/tests/test_sample_diffusion.py b/tests/test_sample_diffusion.py index 5f3560db..fe6fc3be 100644 --- a/tests/test_sample_diffusion.py +++ b/tests/test_sample_diffusion.py @@ -15,8 +15,8 @@ OptimizerParameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import \ MLPScoreNetworkParameters -from diffusion_for_multi_scale_molecular_dynamics.namespace import \ - AXL_COMPOSITION +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, AXL_COMPOSITION) from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters @@ -46,6 +46,19 @@ def cell_dimensions(): return [5.1, 6.2, 7.3] +@pytest.fixture() +def reference_composition(num_atom_types, number_of_atoms, spatial_dimension): + a = torch.randint(0, num_atom_types, (number_of_atoms,)) + x = torch.rand(number_of_atoms, spatial_dimension) + lat = torch.rand(spatial_dimension, spatial_dimension) + return AXL(A=a, X=x, L=lat) + + +@pytest.fixture() +def constrained_atom_indices(number_of_atoms): + return torch.sort(torch.randperm(number_of_atoms)[:number_of_atoms // 2]).values + + @pytest.fixture(params=[True, False]) def record_samples(request): return request.param @@ -77,7 +90,7 @@ def sampling_parameters( @pytest.fixture() -def axl_network(number_of_atoms, noise_parameters, num_atom_types): +def axl_network(number_of_atoms, noise_parameters, num_atom_types, device): score_network_parameters = MLPScoreNetworkParameters( number_of_atoms=number_of_atoms, num_atom_types=num_atom_types, @@ -97,7 +110,7 @@ def axl_network(number_of_atoms, noise_parameters, num_atom_types): diffusion_sampling_parameters=None, ) - model = AXLDiffusionLightningModel(diffusion_params) + model = AXLDiffusionLightningModel(diffusion_params).to(device) return model.axl_network @@ -116,6 +129,20 @@ def config_path(tmp_path, noise_parameters, sampling_parameters): return config_path +@pytest.fixture(params=[True, False]) +def apply_constraint(request): + return request.param + + +@pytest.fixture() +def constraint_data_pickle_path(tmp_path, reference_composition, constrained_atom_indices): + path_to_pickle = tmp_path / "pickle_path.pkl" + data = dict(reference_composition=reference_composition, + constrained_atom_indices=constrained_atom_indices) + torch.save(data, path_to_pickle) + return path_to_pickle + + @pytest.fixture() def checkpoint_path(tmp_path): path_to_checkpoint = tmp_path / "fake_checkpoint.pt" @@ -131,20 +158,24 @@ def output_path(tmp_path): @pytest.fixture() -def args(config_path, checkpoint_path, output_path): +def args(config_path, checkpoint_path, output_path, constraint_data_pickle_path, apply_constraint, device): """Input arguments for main.""" input_args = [ f"--config={config_path}", f"--checkpoint={checkpoint_path}", f"--output={output_path}", - "--device=cpu", + f"--device={device}", ] + if apply_constraint: + input_args.append(f"--path_to_constraint_data_pickle={constraint_data_pickle_path}") + return input_args def test_sample_diffusion( mocker, + device, args, axl_network, output_path, @@ -152,6 +183,9 @@ def test_sample_diffusion( number_of_atoms, spatial_dimension, record_samples, + apply_constraint, + reference_composition, + constrained_atom_indices ): mocker.patch( "diffusion_for_multi_scale_molecular_dynamics.sample_diffusion.get_axl_network", @@ -162,14 +196,23 @@ def test_sample_diffusion( assert (output_path / "samples.pt").exists() samples = torch.load(output_path / "samples.pt") - assert samples[AXL_COMPOSITION].X.shape == ( + compositions = samples[AXL_COMPOSITION] + + assert compositions.X.shape == ( number_of_samples, number_of_atoms, spatial_dimension, ) - assert samples[AXL_COMPOSITION].A.shape == ( + assert compositions.A.shape == ( number_of_samples, number_of_atoms, ) assert (output_path / "trajectories.pt").exists() == record_samples + + if apply_constraint: + reference_x = reference_composition.X[constrained_atom_indices].to(device) + reference_a = reference_composition.A[constrained_atom_indices].to(device) + + assert (compositions.X[:, constrained_atom_indices] == reference_x).all() + assert (compositions.A[:, constrained_atom_indices] == reference_a).all()