Skip to content

Commit

Permalink
Merge branch 'main' into orca-update
Browse files Browse the repository at this point in the history
  • Loading branch information
bernstei committed Aug 29, 2024
2 parents b456cd9 + 3705632 commit f34f0d6
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 11 deletions.
22 changes: 17 additions & 5 deletions .github/workflows/pytests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ jobs:
- name: Install dependencies from pip
run: |
echo "numpy<2" >> $PIP_CONSTRAINT
python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit flake8 pytest pytest-cov requests
python3 -m pip install wheel setuptools numpy scipy click matplotlib pyyaml spglib rdkit==2024.3.3 flake8 pytest pytest-cov requests
- name: Install latest ASE from pypi
run: |
echo PIP_CONSTRAINT $PIP_CONSTRAINT
python3 -m pip install ase
python3 -m pip install ase
echo -n "ASE VERSION "
python3 -c "import ase; print(ase.__file__, ase.__version__)"
Expand Down Expand Up @@ -105,15 +105,27 @@ jobs:
run: |
echo "search for torch version"
set +o pipefail
# echo "torch versions"
# python3 -m pip install torch==
# echo "torch versions to search"
# python3 -m pip install torch== 2>&1 | fgrep 'from versions' |
# sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac
# search for available torch version with +cpu support
for torch_version_test in $( python3 -m pip install torch== 2>&1 | fgrep 'from versions' |
sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac ); do
# for torch_version_test in $( python3 -m pip install torch== 2>&1 | fgrep 'from versions' |
# sed -e 's/.*from versions: //' -e 's/)//' -e 's/,[ ]*/\n/g' | tac ); do
wget https://pypi.org/pypi/torch/json -O torch_versions
for torch_version_test in $( python3 -c "import json; print(' '.join(json.load(open('torch_versions'))['releases'].keys()))" | sed 's/ /\n/g' | tac ); do
echo "check torch_version_test $torch_version_test"
set +e
python3 -m pip install --dry-run torch==${torch_version_test}+cpu \
-f https://download.pytorch.org/whl/torch_stable.html > /dev/null 2>&1
-f https://download.pytorch.org/whl/torch_stable.html 2>&1
search_stat=$?
echo "got search_stat $search_stat"
set -e
if [ $search_stat == 0 ]; then
echo "got valid +cpu version, exiting"
torch_version=${torch_version_test}
break
fi
Expand Down
78 changes: 78 additions & 0 deletions tests/calculators/test_ase_fileio_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os

import pytest

from ase.atoms import Atoms


########################
# test Vasp calculator

from tests.calculators.test_vasp import test_vasp_mark
@pytest.mark.skipif(test_vasp_mark, reason='Vasp testing env vars missing')
def test_vasp_cache_timing(tmp_path, monkeypatch):
from ase.calculators.vasp import Vasp as Vasp_ase
from wfl.calculators.vasp import Vasp as Vasp_wrap

config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True])
kwargs_ase = {'encut': 200, 'pp': os.environ['PYTEST_VASP_POTCAR_DIR']}
kwargs_wrapper = {'workdir': tmp_path}
# make sure 'pp' is relative to correct dir (see wfl.calculators.vasp)
if os.environ['PYTEST_VASP_POTCAR_DIR'].startswith('/'):
monkeypatch.setenv("VASP_PP_PATH", "/.")
else:
monkeypatch.setenv("VASP_PP_PATH", ".")
cache_timing(config, Vasp_ase, kwargs_ase, Vasp_wrap, kwargs_wrapper, tmp_path, monkeypatch)

########################
# test quantum espresso calculator
from tests.calculators.test_qe import espresso_avail, qe_pseudo
@pytest.mark.skipif(not espresso_avail, reason='qe testing env vars missing')
def test_qe_cache_timing(tmp_path, monkeypatch, qe_pseudo):
from ase.calculators.espresso import Espresso as Espresso_ASE
from wfl.calculators.espresso import Espresso as Espresso_wrap

config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True])

pspot = qe_pseudo
kwargs_ase = dict(
pseudopotentials=dict(Si=pspot.name),
pseudo_dir=pspot.parent,
input_data={"SYSTEM": {"ecutwfc": 40, "input_dft": "LDA",}},
kpts=(2, 3, 4),
conv_thr=0.0001,
workdir=tmp_path
)

kwargs_wrapper = {}
cache_timing(config, Espresso_ASE, kwargs_ase, Espresso_wrap, kwargs_wrapper, tmp_path, monkeypatch)


########################
# generic code used by all calculators

import time

from wfl.configset import ConfigSet, OutputSpec
from wfl.calculators import generic

def cache_timing(config, calc_ase, kwargs_ase, calc_wfl, kwargs_wrapper, rundir, monkeypatch):
(rundir / "run_calc_ase").mkdir()

calc = calc_ase(**kwargs_ase)
config.calc = calc

monkeypatch.chdir(rundir / "run_calc_ase")
t0 = time.time()
E = config.get_potential_energy()
ase_time = time.time() - t0

monkeypatch.chdir(rundir)
t0 = time.time()
_ = generic.calculate(inputs=ConfigSet(config), outputs=OutputSpec(),
calculator=calc_wfl(**kwargs_wrapper, **kwargs_ase))
wfl_time = time.time() - t0

print("ASE", ase_time, "WFL", wfl_time)

assert wfl_time < ase_time * 1.25
3 changes: 1 addition & 2 deletions tests/calculators/test_calc_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def test_generic_autopara_defaults():
sys.stderr = sys.__stderr__
assert "num_inputs_per_python_subprocess=3" in l_stderr.getvalue()

@pytest.mark.xfail(reason="Waiting for update to work with ASE3.23")
def test_generic_DFT_autopara_defaults(tmp_path, monkeypatch):
ats = [Atoms('Al2', positions=[[0,0,0], [1,1,1]], cell=[10]*3, pbc=[True]*3) for _ in range(50)]

Expand All @@ -151,6 +150,6 @@ def test_generic_DFT_autopara_defaults(tmp_path, monkeypatch):
# try with a calculator that overrides an autopara default, namely a DFT calculator
# that sets num_inputs_per_python_subprocess=1
sys.stderr = l_stderr
at_proc = generic.calculate(ci, os, Espresso(calculator_exec="_DUMMY_EXEC_", pseudo_dir="_DUMMY_DIR_", workdir=tmp_path))
at_proc = generic.calculate(ci, os, Espresso(workdir=tmp_path))
sys.stderr = sys.__stderr__
assert "num_inputs_per_python_subprocess=1" in l_stderr.getvalue()
9 changes: 5 additions & 4 deletions tests/calculators/test_vasp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from wfl.calculators import generic
from wfl.configset import ConfigSet, OutputSpec

pytestmark = pytest.mark.skipif('ASE_VASP_COMMAND' not in os.environ or
'ASE_VASP_COMMAND_GAMMA' not in os.environ or
'PYTEST_VASP_POTCAR_DIR' not in os.environ,
reason='missing env var ASE_VASP_COMMAND or ASE_VASP_COMMAND_GAMMA or PYTEST_VASP_POTCAR_DIR')
test_vasp_mark = ('ASE_VASP_COMMAND' not in os.environ or
'ASE_VASP_COMMAND_GAMMA' not in os.environ or
'PYTEST_VASP_POTCAR_DIR' not in os.environ)
pytestmark = pytest.mark.skipif(test_vasp_mark, reason='missing env var ASE_VASP_COMMAND or ASE_VASP_COMMAND_GAMMA '
'or PYTEST_VASP_POTCAR_DIR')


def test_vasp_gamma(tmp_path, monkeypatch):
Expand Down
30 changes: 30 additions & 0 deletions tests/calculators/test_wrapped_calculator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from ase.atoms import Atoms
from wfl.configset import ConfigSet, OutputSpec
from wfl.calculators import generic

########################
# test a RuntimeWarning is raised when using the Espresso Calculator directly from ase
from tests.calculators.test_qe import espresso_avail, qe_pseudo
@pytest.mark.skipif(not espresso_avail, reason='qe testing env vars missing')
def test_wrapped_qe(tmp_path, qe_pseudo):
from ase.calculators.espresso import Espresso as Espresso_ASE
from wfl.calculators.espresso import Espresso as Espresso_wrap

config = Atoms('Si', positions=[[0, 0, 9]], cell=[2, 2, 2], pbc=[True, True, True])

pspot = qe_pseudo
kwargs = dict(
pseudopotentials=dict(Si=pspot.name),
pseudo_dir=pspot.parent,
input_data={"SYSTEM": {"ecutwfc": 40, "input_dft": "LDA",}},
kpts=(2, 3, 4),
conv_thr=0.0001,
workdir=tmp_path,
tstress=True,
tprnfor=True
)

direct_calc = (Espresso_ASE, [], kwargs)
kwargs_generic = dict(inputs=ConfigSet(config), outputs=OutputSpec(), calculator=direct_calc)
pytest.warns(RuntimeWarning, generic.calculate, **kwargs_generic)
14 changes: 14 additions & 0 deletions wfl/calculators/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ def calculate(*args, **kwargs):
if calculator is None:
calculator = args[2]

#check if calculator should be wrapped
if type(calculator) == tuple:
from ase.calculators.espresso import Espresso as ASE_Espresso
from ase.calculators.vasp.vasp import Vasp as ASE_Vasp
from ase.calculators.aims import Aims as ASE_Aims
from ase.calculators.castep import Castep as ASE_Castep
from ase.calculators.mopac import MOPAC as ASE_MOPAC
from ase.calculators.orca import ORCA as ASE_ORCA
wrapped_types = [ASE_Espresso, ASE_Vasp, ASE_Aims, ASE_Castep, ASE_MOPAC, ASE_ORCA]

calc = calculator[0]
if calc in wrapped_types:
warnings.warn(f"{calc} should be imported from wfl.calculators rather than ase. Using {calc} directly can lead to duplicated singlepoints", RuntimeWarning)

default_autopara_info = getattr(calculator, "wfl_generic_default_autopara_info", {"num_inputs_per_python_subprocess": 10})

return autoparallelize(_run_autopara_wrappable, *args, default_autopara_info=default_autopara_info, **kwargs)
Expand Down

0 comments on commit f34f0d6

Please sign in to comment.