Skip to content

Commit

Permalink
remove n_atoms factor
Browse files Browse the repository at this point in the history
  • Loading branch information
JPDarby authored and bernstei committed Jun 5, 2024
1 parent 0842e7c commit a7b32da
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ def weighted_mean_squared_stress(ref: Batch, pred: TensorDict) -> torch.Tensor:
# energy: [n_graphs, ]
configs_weight = ref.weight.view(-1, 1, 1) # [n_graphs, ]
configs_stress_weight = ref.stress_weight.view(-1, 1, 1) # [n_graphs, ]
num_atoms = (ref.ptr[1:] - ref.ptr[:-1]).view(-1, 1, 1) # [n_graphs,]
return torch.mean(
configs_weight
* configs_stress_weight
* torch.square((ref["stress"] - pred["stress"]) / num_atoms)
* torch.square(ref["stress"] - pred["stress"])
) # []


Expand Down

0 comments on commit a7b32da

Please sign in to comment.