Skip to content

Commit

Permalink
fix atom e0s being subtracted twice in mace calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
RokasEl committed May 15, 2024
1 parent 81f4f8c commit 279e562
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 31 deletions.
52 changes: 23 additions & 29 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,25 @@ def _create_result_tensors(
dipole = torch.zeros(num_models, 3, device=self.device)
dict_of_tensors.update({"dipole": dipole})
return dict_of_tensors


def _atoms_to_batch(self, atoms):
config = data.config_from_atoms(atoms, charges_key=self.charges_key)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max
)
],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader)).to(self.device)
return batch

def _prepare_batch(self, batch):

def _clone_batch(self, batch):
batch_clone = batch.clone()
if self.use_compile:
batch_clone["node_attrs"].requires_grad_(True)
Expand All @@ -211,32 +228,20 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
# call to base-class to set atoms attribute
Calculator.calculate(self, atoms)

# prepare data
config = data.config_from_atoms(atoms, charges_key=self.charges_key)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max
)
],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch_base = self._atoms_to_batch(atoms)

if self.model_type in ["MACE", "EnergyDipoleMACE"]:
batch = next(iter(data_loader)).to(self.device)
batch = self._clone_batch(batch_base)
node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"])
compute_stress = not self.use_compile
else:
compute_stress = False

batch_base = next(iter(data_loader)).to(self.device)
ret_tensors = self._create_result_tensors(
self.model_type, self.num_models, len(atoms)
)
for i, model in enumerate(self.models):
batch = self._prepare_batch(batch_base)
batch = self._clone_batch(batch_base)
out = model(
batch.to_dict(),
compute_stress=compute_stress,
Expand All @@ -259,7 +264,7 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
)
self.results["free_energy"] = self.results["energy"]
self.results["node_energy"] = (
torch.mean(ret_tensors["node_energy"] - node_e0, dim=0).cpu().numpy()
torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy()
)
self.results["forces"] = (
torch.mean(ret_tensors["forces"], dim=0).cpu().numpy()
Expand Down Expand Up @@ -321,18 +326,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
raise NotImplementedError("Only implemented for MACE models")
if num_layers == -1:
num_layers = int(self.models[0].num_interactions)
config = data.config_from_atoms(atoms, charges_key=self.charges_key)
data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=self.z_table, cutoff=self.r_max
)
],
batch_size=1,
shuffle=False,
drop_last=False,
)
batch = next(iter(data_loader)).to(self.device)
batch = self._atoms_to_batch(atoms)
descriptors = [model(batch.to_dict())["node_feats"] for model in self.models]
if invariants_only:
irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"]
Expand Down
13 changes: 11 additions & 2 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def fitting_configs_fixture():
Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3),
Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3),
]
fit_configs[0].info["REF_energy"] = 0.0
fit_configs[0].info["REF_energy"] = 1.0
fit_configs[0].info["config_type"] = "IsolatedAtom"
fit_configs[1].info["REF_energy"] = 0.0
fit_configs[1].info["REF_energy"] = -0.5
fit_configs[1].info["config_type"] = "IsolatedAtom"

np.random.seed(5)
Expand Down Expand Up @@ -370,6 +370,15 @@ def trained_committee_fixture(tmp_path_factory, fitting_configs):

return MACECalculator(_model_paths, device="cpu")

def test_calculator_node_energy(fitting_configs, trained_model):
for at in fitting_configs:
trained_model.calculate(at)
node_energies = trained_model.results["node_energy"]
batch = trained_model._atoms_to_batch(at)
node_e0 = trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach().numpy()
energy_via_nodes = np.sum(node_energies+node_e0)
energy = trained_model.results["energy"]
np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6)

def test_calculator_forces(fitting_configs, trained_model):
at = fitting_configs[2].copy()
Expand Down

0 comments on commit 279e562

Please sign in to comment.