From 5d02b46221e4fcf83bc00abc2fe6a1979d2283b8 Mon Sep 17 00:00:00 2001 From: Lester Hedges Date: Thu, 4 Jul 2024 14:10:52 +0100 Subject: [PATCH] Pass cell and pbc to AEVCalculator as args not kwargs. [closes #648] --- torchani/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchani/models.py b/torchani/models.py index 117cb4a9e..522e6d0d0 100644 --- a/torchani/models.py +++ b/torchani/models.py @@ -103,7 +103,7 @@ def forward(self, species_coordinates: Tuple[Tensor, Tensor], if species_coordinates[0].ge(self.aev_computer.num_species).any(): raise ValueError(f'Unknown species found in {species_coordinates[0]}') - species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species_aevs = self.aev_computer(species_coordinates, cell, pbc) species_energies = self.neural_networks(species_aevs) return self.energy_shifter(species_energies) @@ -135,7 +135,7 @@ def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor], """ if self.periodic_table_index: species_coordinates = self.species_converter(species_coordinates) - species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species, aevs = self.aev_computer(species_coordinates, cell, pbc) atomic_energies = self.neural_networks._atomic_energies((species, aevs)) self_energies = self.energy_shifter.self_energies.clone().to(species.device) self_energies = self_energies[species] @@ -236,7 +236,7 @@ def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor], """ if self.periodic_table_index: species_coordinates = self.species_converter(species_coordinates) - species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species, aevs = self.aev_computer(species_coordinates, cell, pbc) members_list = [] for nnp in self.neural_networks: members_list.append(nnp._atomic_energies((species, aevs)).unsqueeze(0)) @@ -322,7 +322,7 @@ def members_energies(self, species_coordinates: Tuple[Tensor, Tensor], """ if self.periodic_table_index: species_coordinates = self.species_converter(species_coordinates) - species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) + species, aevs = self.aev_computer(species_coordinates, cell, pbc) member_outputs = [] for nnp in self.neural_networks: unshifted_energies = nnp((species, aevs)).energies