diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py index 14414082..2d46dcc3 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/adaptative_corrector.py @@ -66,16 +66,31 @@ def predictor_step( composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) + # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition + # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end + # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where + # the "first" time step is at index_i = T and the "last" time step is index_i = 1. + this_is_last_time_step = idx == 0 + one_atom_type_transition_per_step = ( + self.one_atom_type_transition_per_step and not this_is_last_time_step + ) + # atom types update - a_im1 = self.atom_types_update( + a_im1 = self._atom_types_update( model_predictions_i.A, composition_i.A, q_matrices_i, q_bar_matrices_i, q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, ) - # in this approach, there is no predictor step applied on the X component + if this_is_last_time_step: + assert (a_im1 != self.masked_atom_type_index).all(), \ + "There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input." + + # in the adaptative corrector approach, there is no predictor step applied on the X component composition_im1 = AXL( A=a_im1, X=composition_i.X, L=unit_cell ) # TODO : Deal with L correctly @@ -184,7 +199,7 @@ def corrector_step( ) sqrt_2eps_i = torch.sqrt(2 * eps_i) - corrected_x_i = self.relative_coordinates_update( + corrected_x_i = self._relative_coordinates_update( composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z ) @@ -193,12 +208,14 @@ def corrector_step( q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) # atom types update - corrected_a_i = self.atom_types_update( + corrected_a_i = self._atom_types_update( model_predictions_i.A, composition_i.A, q_matrices_i, q_bar_matrices_i, q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=self.one_atom_type_transition_per_step, ) else: corrected_a_i = composition_i.A