Skip to content

Commit

Permalink
Update PET driver to work with current code (i-pi#334)
Browse files Browse the repository at this point in the history
* Update PET driver to work with current code

This makes the PET driver compatible with the
current iteration of the PET code. It also
adds some small amount of future-proofing by
allowing to pass through arbitrary keyword arguments.

* Update docstring to be more accurate
  • Loading branch information
sirmarcel authored May 15, 2024
1 parent eeaba4b commit 231d157
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions drivers/py/pes/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@
class PET_driver(Dummy_driver):
def __init__(self, args=None, verbose=False):
self.error_msg = """
The PET driver requires specification of a .json model file fitted with
the PET tools, and a template file that describes the chemical makeup of
the structure.
The PET driver requires (a) a path to the results/experiment_name folder emitted by pet_train
(b) a path to an ase.io.read-able file with a prototype structure
Example: python driver.py -m pet -u -o model.json,template.xyz
Other arguments to the pet.SingleStructCalculator class can be optionally
supplied in key=value form after the required arguments.
Example: python driver.py -m pet -u -o "path/to/results/name,template.xyz,device=cuda"
"""

super().__init__(args, verbose)
Expand All @@ -43,19 +45,25 @@ def check_arguments(self):
This loads the potential and atoms template in PET
"""
arglist = self.args
args = self.args

if len(args) >= 2:
self.model_path = args[0]
self.template = args[1]
kwargs = {}
if len(args) > 2:
for arg in args[2:]:
key, value = arg.split("=")
kwargs[key] = value

if len(arglist) == 2:
self.model_path = arglist[0]
self.template = arglist[1]
else:
sys.exit(self.error_msg)

self.template_ase = read(self.template)
self.template_ase.arrays["forces"] = np.zeros_like(self.template_ase.positions)
self.pet_calc = PETCalc(
self.model_path,
default_hypers_path=self.model_path + "/default_hypers.yaml",
**kwargs,
)

def __call__(self, cell, pos):
Expand Down

0 comments on commit 231d157

Please sign in to comment.