Skip to content

Commit

Permalink
Fix MLP error.
Browse files Browse the repository at this point in the history
  • Loading branch information
rousseab committed Jan 8, 2025
1 parent b7677ee commit cc5a2be
Showing 1 changed file with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,19 @@ def _forward_unchecked(
computed_scores : the scores computed by the model in an AXL namedtuple.
"""
if self.use_permutation_invariance:
# An equivariant vectorial score network takes the form
#
# s_{sym}(x) = 1/|G| \sum_{g \in G} g^{-1}.s(g.x)
#
# The atom type predictions need only be invariant since they are scalars.
#
list_model_outputs = []
for permutation, inverse_permutation in zip(self.perm_indices, self.inverse_perm_indices):
permuted_batch = self.get_permuted_batch(batch, permutation)
model_output = self._forward_unchecked_single_permutation(permuted_batch, conditional)
permuted_model_output = AXL(A=model_output.A[:, inverse_permutation],
permuted_model_output = AXL(A=model_output.A,
X=model_output.X[:, inverse_permutation],
L=model_output.X[:, inverse_permutation])
L=model_output.L)

list_model_outputs.append(permuted_model_output)

Expand Down

0 comments on commit cc5a2be

Please sign in to comment.