diff --git a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py deleted file mode 100644 index 2ea37960..00000000 --- a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py +++ /dev/null @@ -1,148 +0,0 @@ -import logging - -import matplotlib.pyplot as plt -import numpy as np -import torch - -from diffusion_for_multi_scale_molecular_dynamics import ANALYSIS_RESULTS_DIR -from diffusion_for_multi_scale_molecular_dynamics.analysis import \ - PLOT_STYLE_PATH -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( - AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ - setup_analysis_logger -from diffusion_for_multi_scale_molecular_dynamics.utils.structure_utils import \ - create_structure -from experiments.analysis.analytic_score.utils import ( - get_samples_harmonic_energy, get_silicon_supercell, get_unit_cells) - -logger = logging.getLogger(__name__) -setup_analysis_logger() - -repaint_dir = ANALYSIS_RESULTS_DIR / "ANALYTIC_SCORE/REPAINT" -repaint_dir.mkdir(exist_ok=True) - -plt.style.use(PLOT_STYLE_PATH) - -if torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") - -kmax = 1 -supercell_factor = 1 -variance_parameter = 0.001 / supercell_factor -number_of_samples = 1000 -total_time_steps = 100 -number_of_corrector_steps = 1 - -acell = 5.43 # Angstroms. - -constrained_relative_coordinates = np.array( - [[0.5, 0.5, 0.25], [0.5, 0.5, 0.5], [0.5, 0.5, 0.75]], dtype=np.float32 -) - -translation = np.array([0.125, 0.125, 0.125]).astype(np.float32) -if __name__ == "__main__": - logger.info("Setting up parameters") - - equilibrium_relative_coordinates = get_silicon_supercell( - supercell_factor=supercell_factor - ).astype(np.float32) - # Translate to avoid atoms right on the cell boundary - equilibrium_relative_coordinates = equilibrium_relative_coordinates + translation - - number_of_atoms, spatial_dimension = equilibrium_relative_coordinates.shape - - logger.info("Creating samples from the exact distribution") - inverse_covariance = ( - torch.diag(torch.ones(number_of_atoms * spatial_dimension)) / variance_parameter - ).to(device) - inverse_covariance = inverse_covariance.reshape( - number_of_atoms, spatial_dimension, number_of_atoms, spatial_dimension - ) - - unit_cells = get_unit_cells( - acell=acell, - spatial_dimension=spatial_dimension, - number_of_samples=number_of_samples, - ) - - noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, sigma_min=0.001, sigma_max=0.5 - ) - - score_network_parameters = AnalyticalScoreNetworkParameters( - number_of_atoms=number_of_atoms, - spatial_dimension=spatial_dimension, - kmax=kmax, - equilibrium_relative_coordinates=torch.from_numpy( - equilibrium_relative_coordinates - ).to(device), - variance_parameter=variance_parameter, - ) - - sigma_normalized_score_network = AnalyticalScoreNetwork(score_network_parameters) - - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - spatial_dimension=spatial_dimension, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=3 * [acell], - constrained_relative_coordinates=constrained_relative_coordinates, - record_samples=True, - ) - - position_generator = ConstrainedLangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, - ) - - logger.info("Drawing constrained samples") - samples = position_generator.sample( - number_of_samples=number_of_samples, device=device, unit_cell=unit_cells - ).detach() - - position_generator.sample_trajectory_recorder.write_to_pickle( - repaint_dir / "repaint_trajectories.pkl" - ) - - logger.info("Computing harmonic energies") - sampled_harmonic_energies = get_samples_harmonic_energy( - equilibrium_relative_coordinates, inverse_covariance, samples - ) - - with open(repaint_dir / "harmonic_energies.pt", "wb") as fd: - torch.save(sampled_harmonic_energies, fd) - - logger.info("Creating CIF files") - - lattice_basis_vectors = np.diag([acell, acell, acell]) - species = number_of_atoms * ["Si"] - - relative_coordinates = equilibrium_relative_coordinates - equilibrium_structure = create_structure( - lattice_basis_vectors, relative_coordinates, species - ) - equilibrium_structure.to_file(str(repaint_dir / "equilibrium_positions.cif")) - - relative_coordinates[: len(constrained_relative_coordinates)] = ( - constrained_relative_coordinates - ) - forced_structure = create_structure( - lattice_basis_vectors, relative_coordinates, species - ) - forced_structure.to_file(str(repaint_dir / "forced_constraint_positions.cif")) - - samples_dir = repaint_dir / "samples" - samples_dir.mkdir(exist_ok=True) - - for idx, sample in enumerate(samples.cpu().numpy()): - structure = create_structure(lattice_basis_vectors, sample, species) - structure.to_file(str(samples_dir / f"sample_{idx}.cif")) diff --git a/experiments/sampling_sota_model/repaint_with_sota_score.py b/experiments/sampling_sota_model/repaint_with_sota_score.py deleted file mode 100644 index 8e8d7dad..00000000 --- a/experiments/sampling_sota_model/repaint_with_sota_score.py +++ /dev/null @@ -1,145 +0,0 @@ -import logging -import os -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import torch -import yaml - -from diffusion_for_multi_scale_molecular_dynamics import DATA_DIR -from diffusion_for_multi_scale_molecular_dynamics.analysis import ( - PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) -from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ - load_diffusion_model -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ - get_energy_and_forces_from_lammps -from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ - setup_analysis_logger - -logger = logging.getLogger(__name__) -setup_analysis_logger() - - -experiments_dir = Path("/home/mila/r/rousseab/experiments/") -model_dir = experiments_dir / "checkpoints/sota_model/" -state_dict_path = model_dir / "last_model-epoch=199-step=019600_state_dict.ckpt" -config_path = model_dir / "config_backup.yaml" - -repaint_dir = Path("/home/mila/r/rousseab/experiments/draw_sota_samples/repaint") -repaint_dir.mkdir(exist_ok=True) - -plt.style.use(PLOT_STYLE_PATH) - -device = torch.device("cuda") - -number_of_samples = 1000 -total_time_steps = 100 -number_of_corrector_steps = 1 - -acell = 5.43 # Angstroms. -box = np.diag([acell, acell, acell]) - -number_of_atoms, spatial_dimension = 8, 3 -atom_types = np.ones(number_of_atoms, dtype=int) - -constrained_relative_coordinates = np.array( - [[0.5, 0.5, 0.25], [0.5, 0.5, 0.5], [0.5, 0.5, 0.75]], dtype=np.float32 -) - -if __name__ == "__main__": - logger.info("Setting up parameters") - - unit_cells = torch.Tensor(box).repeat(number_of_samples, 1, 1).to(device) - - noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, sigma_min=0.001, sigma_max=0.5 - ) - - logger.info("Loading state dict") - with open(str(state_dict_path), "rb") as fd: - state_dict = torch.load(fd) - - with open(str(config_path), "r") as fd: - hyper_params = yaml.load(fd, Loader=yaml.FullLoader) - logger.info("Instantiate model") - pl_model = load_diffusion_model(hyper_params) - pl_model.load_state_dict(state_dict=state_dict) - pl_model.to(device) - pl_model.eval() - - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - spatial_dimension=spatial_dimension, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=3 * [acell], - constrained_relative_coordinates=constrained_relative_coordinates, - record_samples=True, - ) - - position_generator = ConstrainedLangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, - ) - - logger.info("Drawing constrained samples") - - with torch.no_grad(): - samples = position_generator.sample( - number_of_samples=number_of_samples, device=device, unit_cell=unit_cells - ) - - batch_relative_positions = samples.cpu().numpy() - batch_positions = np.dot(batch_relative_positions, box) - - position_generator.sample_trajectory_recorder.write_to_pickle( - repaint_dir / "repaint_trajectories.pkl" - ) - - logger.info("Compute energy from Oracle") - list_energy = [] - - lammps_work_directory = repaint_dir / "samples" - lammps_work_directory.mkdir(exist_ok=True) - - for idx, positions in enumerate(batch_positions): - energy, forces = get_energy_and_forces_from_lammps( - positions, - box, - atom_types, - tmp_work_dir=str(lammps_work_directory), - pair_coeff_dir=DATA_DIR, - ) - list_energy.append(energy) - src = os.path.join(lammps_work_directory, "dump.yaml") - dst = os.path.join(lammps_work_directory, f"dump_{idx}.yaml") - os.rename(src, dst) - - energies = np.array(list_energy) - - with open(repaint_dir / "energies.pt", "wb") as fd: - torch.save(torch.from_numpy(energies), fd) - - logger.info("Plotting energy distributions") - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - fig.suptitle("Energy Distribution for Repaint Structures,") - - common_params = dict(density=True, bins=50, histtype="stepfilled", alpha=0.25) - - ax1 = fig.add_subplot(111) - ax1.hist(energies, **common_params, label="Sampled Energies", color="red") - - ax1.set_xlabel("Energy (eV)") - ax1.set_ylabel("Density") - ax1.legend(loc="upper right", fancybox=True, shadow=True, ncol=1, fontsize=12) - fig.tight_layout() - fig.savefig(repaint_dir / f"energy_samples_repaint_{number_of_atoms}_atoms.png") - plt.close(fig) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 99db71ce..924010e7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -1,91 +1,138 @@ -from dataclasses import dataclass - -import numpy as np +import einops import torch from tqdm import tqdm from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ - PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ - ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL -from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ - NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noisers.atom_types_noiser import \ + AtomTypesNoiser from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ RelativeCoordinatesNoiser +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot -@dataclass(kw_only=True) -class ConstrainedLangevinGeneratorParameters(PredictorCorrectorSamplingParameters): - """Hyper-parameters for diffusion sampling with the predictor-corrector algorithm.""" - - algorithm: str = "constrained_langevin" - constrained_relative_coordinates: ( - np.ndarray - ) # the positions that must be satisfied at the end of sampling. - +class ConstrainedPredictorCorrectorAXLGenerator: + """Constrained Predictor Corrector AXL Generator. -class ConstrainedLangevinGenerator(LangevinGenerator): - """Constrained Annealed Langevin Dynamics Generator. - - This generator implements a basic version of the inpainting algorithm presented in the - paper + This class constrains the input PC generator following a basic version of the inpainting algorithm + presented in the paper "RePaint: Inpainting using Denoising Diffusion Probabilistic Models". """ def __init__( self, - noise_parameters: NoiseParameters, - sampling_parameters: ConstrainedLangevinGeneratorParameters, - axl_network: ScoreNetwork, + generator: LangevinGenerator, + reference_composition: AXL, + constrained_atom_indices: torch.Tensor, ): """Init method.""" - super().__init__(noise_parameters, sampling_parameters, axl_network) + self.generator = generator + + if hasattr(self.generator, "sample_trajectory_recorder"): + self.sample_trajectory_recorder = self.generator.sample_trajectory_recorder - self.constraint_relative_coordinates = torch.from_numpy( - sampling_parameters.constrained_relative_coordinates - ) # TODO constraint the atom type as well + self.number_of_atoms = self.generator.number_of_atoms + self.num_classes = self.generator.num_classes + + self.reference_composition = reference_composition + self.constraint_indices = constrained_atom_indices assert ( - len(self.constraint_relative_coordinates.shape) == 2 + len(self.reference_composition.X.shape) == 2 ), "The constrained relative coordinates have the wrong shape" - number_of_constraints, spatial_dimensions = ( - self.constraint_relative_coordinates.shape - ) - assert ( - number_of_constraints <= self.number_of_atoms - ), "There are more constrained positions than atoms!" assert ( - spatial_dimensions <= self.spatial_dimension - ), "The spatial dimension of the constrained positions is inconsistent" + len(self.reference_composition.A.shape) == 1 + ), "The constrained atom types have the wrong shape" - # Without loss of generality, we impose that the first positions are constrained. - # This should have no consequence for a permutation equivariant model. - self.constraint_mask = torch.zeros(self.number_of_atoms, dtype=bool) - self.constraint_mask[:number_of_constraints] = True + assert ( + len(constrained_atom_indices.shape) == 1 + ), "The constrained_atom_indices array has the wrong shape" self.relative_coordinates_noiser = RelativeCoordinatesNoiser() + self.atom_type_noiser = AtomTypesNoiser() def _apply_constraint(self, composition: AXL, device: torch.device) -> AXL: """This method applies the coordinate constraint on the input configuration.""" - x = composition.X - x[:, self.constraint_mask] = self.constraint_relative_coordinates.to(device) - updated_axl = AXL( - A=composition.A, - X=x, + constrained_x = composition.X.clone() + constrained_x[:, self.constraint_indices] = self.reference_composition.X[ + self.constraint_indices + ].to(device) + + constrained_a = composition.A.clone() + constrained_a[:, self.constraint_indices] = self.reference_composition.A[ + self.constraint_indices + ].to(device) + + constrained_composition = AXL( + A=constrained_a, + X=constrained_x, L=composition.L, ) - return updated_axl + return constrained_composition + + def _get_noised_known_composition( + self, i: int, number_of_samples: int, device: torch.device + ) -> AXL: + """This method applies the noise to the known composition.""" + # Initialize compositions that satisfies the constraint, but is otherwise random. + # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. + composition0_known = self.generator.initialize(number_of_samples, device) + composition0_known = self._apply_constraint(composition0_known, device) + + q_bar_matrices_i = einops.repeat( + self.generator.noise.q_bar_matrix[i].to(device), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=self.number_of_atoms, + ) + + sigma_i = self.generator.noise.sigma[i] + coordinates_broadcasting = torch.ones_like(composition0_known.X) + broadcast_sigmas_i = sigma_i * coordinates_broadcasting + + # Noise an example satisfying the constraints from t_0 to t_i + x_i_known = ( + self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( + composition0_known.X, broadcast_sigmas_i + ) + ) + + one_hot_a_i = class_index_to_onehot( + composition0_known.A, num_classes=self.num_classes + ) + a_i_known = self.atom_type_noiser.get_noisy_atom_types_sample( + one_hot_a_i, q_bar_matrices_i + ) + + noised_composition = AXL(A=a_i_known, X=x_i_known, L=composition0_known.L) + return noised_composition + + def _combine_noised_and_denoised_compositions( + self, noised_composition: AXL, denoised_composition: AXL + ) -> AXL: + + updated_x = denoised_composition.X.clone() + updated_a = denoised_composition.A.clone() + + updated_x[:, self.constraint_indices] = noised_composition.X[ + :, self.constraint_indices + ] + updated_a[:, self.constraint_indices] = noised_composition.A[ + :, self.constraint_indices + ] + + composition_i = AXL(A=updated_a, X=updated_x, L=denoised_composition.L) + return composition_i def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor ) -> AXL: """Sample. - This method draws samples, imposing the satisfaction of positional constraints. + This method draws samples, imposing the satisfaction of positional constraints. Args: number_of_samples : number of samples to draw. @@ -98,48 +145,36 @@ def sample( """ assert unit_cell.size() == ( number_of_samples, - self.spatial_dimension, - self.spatial_dimension, + self.generator.spatial_dimension, + self.generator.spatial_dimension, ), ( "Unit cell passed to sample should be of size (number of sample, spatial dimension, spatial dimension" + f"Got {unit_cell.size()}" ) - # Initialize a configuration that satisfy the constraint, but is otherwise random. - # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. - composition0_known = self.initialize(number_of_samples, device) - # this is an AXL objet - - composition0_known = self._apply_constraint(composition0_known, device) - - composition_ip1 = self.initialize(number_of_samples, device) + composition_ip1 = self.generator.initialize(number_of_samples, device) forces = torch.zeros_like(composition_ip1.X) - coordinates_broadcasting = torch.ones( - number_of_samples, self.number_of_atoms, self.spatial_dimension - ).to(device) - - for i in tqdm(range(self.number_of_discretization_steps - 1, -1, -1)): - sigma_i = self.noise.sigma[i] - broadcast_sigmas_i = sigma_i * coordinates_broadcasting - # Noise an example satisfying the constraints from t_0 to t_i - x_i_known = ( - self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( - composition0_known.X, broadcast_sigmas_i - ) + for i in tqdm(range(self.generator.number_of_discretization_steps - 1, -1, -1)): + + # Noise from t_0 to t_i + noised_composition_i = self._get_noised_known_composition( + i, number_of_samples, device ) + # Denoise from t_{i+1} to t_i - composition_i = self.predictor_step( + denoised_composition_i = self.generator.predictor_step( composition_ip1, i + 1, unit_cell, forces ) - # Combine the known and unknown - x_i = composition_i.X - x_i[:, self.constraint_mask] = x_i_known[:, self.constraint_mask] - composition_i = AXL(A=composition_i.A, X=x_i, L=composition_i.L) + composition_i = self._combine_noised_and_denoised_compositions( + noised_composition_i, denoised_composition_i + ) - for _ in range(self.number_of_corrector_steps): - composition_i = self.corrector_step(composition_i, i, unit_cell, forces) + for _ in range(self.generator.number_of_corrector_steps): + composition_i = self.generator.corrector_step( + composition_i, i, unit_cell, forces + ) composition_ip1 = composition_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index 26260536..713d5133 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -16,8 +16,12 @@ ElementTypes from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ SamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \ instantiate_generator +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.load_sampling_parameters import \ load_sampling_parameters from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import \ @@ -51,6 +55,11 @@ def main(args: Optional[Any] = None): required=True, help="config file with sampling parameters in yaml format.", ) + parser.add_argument( + "--path_to_constraint_data_pickle", required=False, + help="path to a pickle that contains a reference compositions and fixed atom indices." + ) + parser.add_argument( "--checkpoint", required=True, help="path to checkpoint model to be loaded." ) @@ -101,12 +110,37 @@ def main(args: Optional[Any] = None): elements = hyper_params["elements"] oracle_parameters = create_energy_oracle_parameters(hyper_params["oracle"], elements) - create_samples_and_write_to_disk( + axl_network = get_axl_network(args.checkpoint) + + logger.info("Instantiate generator...") + raw_generator = instantiate_generator( + sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, + axl_network=axl_network, + ) + + if args.path_to_constraint_data_pickle: + logger.info("Constrained Sampling is activated") + constraint_data_pickle_path = Path(args.path_to_constraint_data_pickle) + assert constraint_data_pickle_path.is_file(), "The constraint data pickle does not exist." + + constraint_data = torch.load(constraint_data_pickle_path) + constrained_atom_indices = constraint_data["constrained_atom_indices"] + logger.info(f"Constrained atom indices are {constrained_atom_indices}") + + reference_composition = constraint_data["reference_composition"] + + generator = ConstrainedPredictorCorrectorAXLGenerator(raw_generator, + reference_composition, + constrained_atom_indices) + else: + generator = raw_generator + + create_samples_and_write_to_disk( + generator=generator, sampling_parameters=sampling_parameters, oracle_parameters=oracle_parameters, device=device, - checkpoint_path=args.checkpoint, output_path=args.output, ) @@ -152,11 +186,10 @@ def get_axl_network(checkpoint_path: Union[str, Path]) -> ScoreNetwork: def create_samples_and_write_to_disk( - noise_parameters: NoiseParameters, + generator: LangevinGenerator, sampling_parameters: SamplingParameters, oracle_parameters: Union[OracleParameters, None], device: torch.device, - checkpoint_path: Union[str, Path], output_path: Union[str, Path], ): """Create Samples and write to disk. @@ -173,19 +206,10 @@ def create_samples_and_write_to_disk( Returns: None """ - axl_network = get_axl_network(checkpoint_path) - - logger.info("Instantiate generator...") - position_generator = instantiate_generator( - sampling_parameters=sampling_parameters, - noise_parameters=noise_parameters, - axl_network=axl_network, - ) - logger.info("Generating samples...") with torch.no_grad(): samples_batch = create_batch_of_samples( - generator=position_generator, + generator=generator, sampling_parameters=sampling_parameters, device=device, ) @@ -207,10 +231,12 @@ def create_samples_and_write_to_disk( if sampling_parameters.record_samples: logger.info("Writing sampling trajectories to disk...") - position_generator.sample_trajectory_recorder.write_to_pickle( + generator.sample_trajectory_recorder.write_to_pickle( output_directory / "trajectories.pt" ) + logger.info("Done!") + if __name__ == "__main__": main() diff --git a/tests/generators/test_constrained_langevin_generator.py b/tests/generators/test_constrained_langevin_generator.py index 59f2bb6d..0c898e5d 100644 --- a/tests/generators/test_constrained_langevin_generator.py +++ b/tests/generators/test_constrained_langevin_generator.py @@ -2,92 +2,179 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( - ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import \ + ConstrainedPredictorCorrectorAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell from tests.generators.test_langevin_generator import TestLangevinGenerator class TestConstrainedLangevinGenerator(TestLangevinGenerator): @pytest.fixture() - def constrained_relative_coordinates(self, number_of_atoms, spatial_dimension): - number_of_constraints = number_of_atoms // 2 - return torch.rand(number_of_constraints, spatial_dimension).numpy() - - @pytest.fixture() - def sampling_parameters( + def reference_composition( self, number_of_atoms, spatial_dimension, - number_of_samples, - cell_dimensions, - number_of_corrector_steps, - unit_cell_size, - constrained_relative_coordinates, - num_atom_types, + num_atomic_classes, + device, ): - sampling_parameters = ConstrainedLangevinGeneratorParameters( - number_of_corrector_steps=number_of_corrector_steps, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=cell_dimensions, - spatial_dimension=spatial_dimension, - constrained_relative_coordinates=constrained_relative_coordinates, - num_atom_types=num_atom_types, - ) - - return sampling_parameters - - @pytest.fixture() - def pc_generator(self, noise_parameters, sampling_parameters, axl_network): - generator = ConstrainedLangevinGenerator( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - axl_network=axl_network, + return AXL( + A=torch.randint(0, num_atomic_classes, (number_of_atoms,)).to(device), + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( + device + ), # TODO placeholder ) - return generator - @pytest.fixture() - def axl( + def random_compositions( self, number_of_samples, number_of_atoms, spatial_dimension, - num_atom_types, + num_atomic_classes, device, ): return AXL( A=torch.randint( - 0, num_atom_types + 1, (number_of_samples, number_of_atoms) + 0, + num_atomic_classes, + ( + number_of_samples, + number_of_atoms, + ), ).to(device), - X=torch.rand(number_of_samples, number_of_atoms, spatial_dimension).to( - device - ), - L=torch.rand( - number_of_samples, spatial_dimension * (spatial_dimension - 1) - ).to( + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros(spatial_dimension * (spatial_dimension - 1)).to( device ), # TODO placeholder ) + @pytest.fixture() + def constrained_atom_indices(self, number_of_atoms, device): + number_of_constraints = number_of_atoms // 2 + return torch.randperm(number_of_atoms)[:number_of_constraints].to(device) + + @pytest.fixture() + def constrained_pc_generator( + self, pc_generator, reference_composition, constrained_atom_indices + ): + constrained_generator = ConstrainedPredictorCorrectorAXLGenerator( + generator=pc_generator, + reference_composition=reference_composition, + constrained_atom_indices=constrained_atom_indices, + ) + + return constrained_generator + + @pytest.fixture() + def constrained_samples( + self, constrained_pc_generator, number_of_samples, device, unit_cell_sample + ): + samples = constrained_pc_generator.sample( + number_of_samples, device, unit_cell_sample + ) + return samples + + def test_constraints( + self, + constrained_samples, + reference_composition, + constrained_atom_indices, + number_of_samples, + ): + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + + torch.testing.assert_close( + constrained_samples.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_samples.A[:, constrained_atom_indices], reference_a + ) + def test_apply_constraint( - self, pc_generator, axl, constrained_relative_coordinates, device + self, + constrained_pc_generator, + number_of_samples, + random_compositions, + reference_composition, + constrained_atom_indices, + device, ): - batch_size = axl.X.shape[0] - original_x = torch.clone(axl.X) - pc_generator._apply_constraint(axl, device) - number_of_constraints = len(constrained_relative_coordinates) + constrained_compositions = constrained_pc_generator._apply_constraint( + random_compositions, device + ) - constrained_x = einops.repeat( - torch.from_numpy(constrained_relative_coordinates).to(device), - "n d -> b n d", - b=batch_size, + reference_x = einops.repeat( + reference_composition.X[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, + ) + reference_a = einops.repeat( + reference_composition.A[constrained_atom_indices], + "... -> n ...", + n=number_of_samples, ) - torch.testing.assert_close(axl.X[:, :number_of_constraints], constrained_x) torch.testing.assert_close( - axl.X[:, number_of_constraints:], original_x[:, number_of_constraints:] + constrained_compositions.X[:, constrained_atom_indices], reference_x + ) + torch.testing.assert_close( + constrained_compositions.A[:, constrained_atom_indices], reference_a + ) + + def test_combine_noised_and_denoised_compositions( + self, + constrained_pc_generator, + constrained_atom_indices, + number_of_samples, + number_of_atoms, + spatial_dimension, + device, + ) -> AXL: + + noised_mask = torch.zeros(number_of_atoms, dtype=torch.bool).to(device) + noised_mask[constrained_atom_indices] = True + + noised_compositions = AXL( + A=torch.zeros(number_of_samples, number_of_atoms).to(device), + X=torch.zeros(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + denoised_compositions = AXL( + A=torch.ones(number_of_samples, number_of_atoms).to(device), + X=torch.ones(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=0.0, + ) + + combined_compositions = ( + constrained_pc_generator._combine_noised_and_denoised_compositions( + noised_compositions, denoised_compositions + ) ) + + assert (combined_compositions.X[:, noised_mask] == 0.0).all() + assert (combined_compositions.X[:, ~noised_mask] == 1.0).all() + assert (combined_compositions.A[:, noised_mask] == 0.0).all() + assert (combined_compositions.A[:, ~noised_mask] == 1.0).all() diff --git a/tests/test_sample_diffusion.py b/tests/test_sample_diffusion.py index 5f3560db..fe6fc3be 100644 --- a/tests/test_sample_diffusion.py +++ b/tests/test_sample_diffusion.py @@ -15,8 +15,8 @@ OptimizerParameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import \ MLPScoreNetworkParameters -from diffusion_for_multi_scale_molecular_dynamics.namespace import \ - AXL_COMPOSITION +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, AXL_COMPOSITION) from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters @@ -46,6 +46,19 @@ def cell_dimensions(): return [5.1, 6.2, 7.3] +@pytest.fixture() +def reference_composition(num_atom_types, number_of_atoms, spatial_dimension): + a = torch.randint(0, num_atom_types, (number_of_atoms,)) + x = torch.rand(number_of_atoms, spatial_dimension) + lat = torch.rand(spatial_dimension, spatial_dimension) + return AXL(A=a, X=x, L=lat) + + +@pytest.fixture() +def constrained_atom_indices(number_of_atoms): + return torch.sort(torch.randperm(number_of_atoms)[:number_of_atoms // 2]).values + + @pytest.fixture(params=[True, False]) def record_samples(request): return request.param @@ -77,7 +90,7 @@ def sampling_parameters( @pytest.fixture() -def axl_network(number_of_atoms, noise_parameters, num_atom_types): +def axl_network(number_of_atoms, noise_parameters, num_atom_types, device): score_network_parameters = MLPScoreNetworkParameters( number_of_atoms=number_of_atoms, num_atom_types=num_atom_types, @@ -97,7 +110,7 @@ def axl_network(number_of_atoms, noise_parameters, num_atom_types): diffusion_sampling_parameters=None, ) - model = AXLDiffusionLightningModel(diffusion_params) + model = AXLDiffusionLightningModel(diffusion_params).to(device) return model.axl_network @@ -116,6 +129,20 @@ def config_path(tmp_path, noise_parameters, sampling_parameters): return config_path +@pytest.fixture(params=[True, False]) +def apply_constraint(request): + return request.param + + +@pytest.fixture() +def constraint_data_pickle_path(tmp_path, reference_composition, constrained_atom_indices): + path_to_pickle = tmp_path / "pickle_path.pkl" + data = dict(reference_composition=reference_composition, + constrained_atom_indices=constrained_atom_indices) + torch.save(data, path_to_pickle) + return path_to_pickle + + @pytest.fixture() def checkpoint_path(tmp_path): path_to_checkpoint = tmp_path / "fake_checkpoint.pt" @@ -131,20 +158,24 @@ def output_path(tmp_path): @pytest.fixture() -def args(config_path, checkpoint_path, output_path): +def args(config_path, checkpoint_path, output_path, constraint_data_pickle_path, apply_constraint, device): """Input arguments for main.""" input_args = [ f"--config={config_path}", f"--checkpoint={checkpoint_path}", f"--output={output_path}", - "--device=cpu", + f"--device={device}", ] + if apply_constraint: + input_args.append(f"--path_to_constraint_data_pickle={constraint_data_pickle_path}") + return input_args def test_sample_diffusion( mocker, + device, args, axl_network, output_path, @@ -152,6 +183,9 @@ def test_sample_diffusion( number_of_atoms, spatial_dimension, record_samples, + apply_constraint, + reference_composition, + constrained_atom_indices ): mocker.patch( "diffusion_for_multi_scale_molecular_dynamics.sample_diffusion.get_axl_network", @@ -162,14 +196,23 @@ def test_sample_diffusion( assert (output_path / "samples.pt").exists() samples = torch.load(output_path / "samples.pt") - assert samples[AXL_COMPOSITION].X.shape == ( + compositions = samples[AXL_COMPOSITION] + + assert compositions.X.shape == ( number_of_samples, number_of_atoms, spatial_dimension, ) - assert samples[AXL_COMPOSITION].A.shape == ( + assert compositions.A.shape == ( number_of_samples, number_of_atoms, ) assert (output_path / "trajectories.pt").exists() == record_samples + + if apply_constraint: + reference_x = reference_composition.X[constrained_atom_indices].to(device) + reference_a = reference_composition.A[constrained_atom_indices].to(device) + + assert (compositions.X[:, constrained_atom_indices] == reference_x).all() + assert (compositions.A[:, constrained_atom_indices] == reference_a).all()