Skip to content

Commit

Permalink
Pass cell and pbc to AEVCalculator as args not kwargs. [closes aiqm#648]
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Jul 4, 2024
1 parent 17204c6 commit 5d02b46
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchani/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def forward(self, species_coordinates: Tuple[Tensor, Tensor],
if species_coordinates[0].ge(self.aev_computer.num_species).any():
raise ValueError(f'Unknown species found in {species_coordinates[0]}')

species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_aevs = self.aev_computer(species_coordinates, cell, pbc)
species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies)

Expand Down Expand Up @@ -135,7 +135,7 @@ def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor],
"""
if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates)
species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species, aevs = self.aev_computer(species_coordinates, cell, pbc)
atomic_energies = self.neural_networks._atomic_energies((species, aevs))
self_energies = self.energy_shifter.self_energies.clone().to(species.device)
self_energies = self_energies[species]
Expand Down Expand Up @@ -236,7 +236,7 @@ def atomic_energies(self, species_coordinates: Tuple[Tensor, Tensor],
"""
if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates)
species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species, aevs = self.aev_computer(species_coordinates, cell, pbc)
members_list = []
for nnp in self.neural_networks:
members_list.append(nnp._atomic_energies((species, aevs)).unsqueeze(0))
Expand Down Expand Up @@ -322,7 +322,7 @@ def members_energies(self, species_coordinates: Tuple[Tensor, Tensor],
"""
if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates)
species, aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species, aevs = self.aev_computer(species_coordinates, cell, pbc)
member_outputs = []
for nnp in self.neural_networks:
unshifted_energies = nnp((species, aevs)).energies
Expand Down

0 comments on commit 5d02b46

Please sign in to comment.