From 279e5628a14e5b1f7be3bcf3888081112462e749 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rokas=20Elijo=C5=A1ius?= Date: Wed, 15 May 2024 17:51:05 +0100 Subject: [PATCH] fix atom e0s being subtracted twice in mace calculator --- mace/calculators/mace.py | 52 ++++++++++++++++++---------------------- tests/test_calculator.py | 13 ++++++++-- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index c7a543a1..76fecb96 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -191,8 +191,25 @@ def _create_result_tensors( dipole = torch.zeros(num_models, 3, device=self.device) dict_of_tensors.update({"dipole": dipole}) return dict_of_tensors + + + def _atoms_to_batch(self, atoms): + config = data.config_from_atoms(atoms, charges_key=self.charges_key) + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[ + data.AtomicData.from_config( + config, z_table=self.z_table, cutoff=self.r_max + ) + ], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)).to(self.device) + return batch - def _prepare_batch(self, batch): + + def _clone_batch(self, batch): batch_clone = batch.clone() if self.use_compile: batch_clone["node_attrs"].requires_grad_(True) @@ -211,32 +228,20 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): # call to base-class to set atoms attribute Calculator.calculate(self, atoms) - # prepare data - config = data.config_from_atoms(atoms, charges_key=self.charges_key) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max - ) - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) + batch_base = self._atoms_to_batch(atoms) if self.model_type in ["MACE", "EnergyDipoleMACE"]: - batch = next(iter(data_loader)).to(self.device) + batch = self._clone_batch(batch_base) node_e0 = self.models[0].atomic_energies_fn(batch["node_attrs"]) compute_stress = not self.use_compile else: compute_stress = False - batch_base = next(iter(data_loader)).to(self.device) ret_tensors = self._create_result_tensors( self.model_type, self.num_models, len(atoms) ) for i, model in enumerate(self.models): - batch = self._prepare_batch(batch_base) + batch = self._clone_batch(batch_base) out = model( batch.to_dict(), compute_stress=compute_stress, @@ -259,7 +264,7 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): ) self.results["free_energy"] = self.results["energy"] self.results["node_energy"] = ( - torch.mean(ret_tensors["node_energy"] - node_e0, dim=0).cpu().numpy() + torch.mean(ret_tensors["node_energy"], dim=0).cpu().numpy() ) self.results["forces"] = ( torch.mean(ret_tensors["forces"], dim=0).cpu().numpy() @@ -321,18 +326,7 @@ def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1): raise NotImplementedError("Only implemented for MACE models") if num_layers == -1: num_layers = int(self.models[0].num_interactions) - config = data.config_from_atoms(atoms, charges_key=self.charges_key) - data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=self.z_table, cutoff=self.r_max - ) - ], - batch_size=1, - shuffle=False, - drop_last=False, - ) - batch = next(iter(data_loader)).to(self.device) + batch = self._atoms_to_batch(atoms) descriptors = [model(batch.to_dict())["node_feats"] for model in self.models] if invariants_only: irreps_out = self.models[0].products[0].linear.__dict__["irreps_out"] diff --git a/tests/test_calculator.py b/tests/test_calculator.py index 790590df..96ee2be2 100644 --- a/tests/test_calculator.py +++ b/tests/test_calculator.py @@ -31,9 +31,9 @@ def fitting_configs_fixture(): Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), ] - fit_configs[0].info["REF_energy"] = 0.0 + fit_configs[0].info["REF_energy"] = 1.0 fit_configs[0].info["config_type"] = "IsolatedAtom" - fit_configs[1].info["REF_energy"] = 0.0 + fit_configs[1].info["REF_energy"] = -0.5 fit_configs[1].info["config_type"] = "IsolatedAtom" np.random.seed(5) @@ -370,6 +370,15 @@ def trained_committee_fixture(tmp_path_factory, fitting_configs): return MACECalculator(_model_paths, device="cpu") +def test_calculator_node_energy(fitting_configs, trained_model): + for at in fitting_configs: + trained_model.calculate(at) + node_energies = trained_model.results["node_energy"] + batch = trained_model._atoms_to_batch(at) + node_e0 = trained_model.models[0].atomic_energies_fn(batch["node_attrs"]).detach().numpy() + energy_via_nodes = np.sum(node_energies+node_e0) + energy = trained_model.results["energy"] + np.testing.assert_allclose(energy, energy_via_nodes, atol=1e-6) def test_calculator_forces(fitting_configs, trained_model): at = fitting_configs[2].copy()