Skip to content

Commit

Permalink
fixes to reflect updates in langevin generator
Browse files Browse the repository at this point in the history
  • Loading branch information
Simon Blackburn authored and Simon Blackburn committed Dec 2, 2024
1 parent accde32 commit 4e30870
Showing 1 changed file with 21 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,31 @@ def predictor_step(
composition_i, t_i, sigma_i, unit_cell, cartesian_forces
)

# Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition
# cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end
# of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where
# the "first" time step is at index_i = T and the "last" time step is index_i = 1.
this_is_last_time_step = idx == 0
one_atom_type_transition_per_step = (
self.one_atom_type_transition_per_step and not this_is_last_time_step
)

# atom types update
a_im1 = self.atom_types_update(
a_im1 = self._atom_types_update(
model_predictions_i.A,
composition_i.A,
q_matrices_i,
q_bar_matrices_i,
q_bar_tm1_matrices_i,
atom_type_greedy_sampling=self.atom_type_greedy_sampling,
one_atom_type_transition_per_step=one_atom_type_transition_per_step,
)

# in this approach, there is no predictor step applied on the X component
if this_is_last_time_step:
assert (a_im1 != self.masked_atom_type_index).all(), \
"There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input."

# in the adaptative corrector approach, there is no predictor step applied on the X component
composition_im1 = AXL(
A=a_im1, X=composition_i.X, L=unit_cell
) # TODO : Deal with L correctly
Expand Down Expand Up @@ -184,7 +199,7 @@ def corrector_step(
)
sqrt_2eps_i = torch.sqrt(2 * eps_i)

corrected_x_i = self.relative_coordinates_update(
corrected_x_i = self._relative_coordinates_update(
composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i, z=z
)

Expand All @@ -193,12 +208,14 @@ def corrector_step(
q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X)
q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X)
# atom types update
corrected_a_i = self.atom_types_update(
corrected_a_i = self._atom_types_update(
model_predictions_i.A,
composition_i.A,
q_matrices_i,
q_bar_matrices_i,
q_bar_tm1_matrices_i,
atom_type_greedy_sampling=self.atom_type_greedy_sampling,
one_atom_type_transition_per_step=self.one_atom_type_transition_per_step,
)
else:
corrected_a_i = composition_i.A
Expand Down

0 comments on commit 4e30870

Please sign in to comment.