Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change default CHGnet.load(check_cuda_mem: bool) to False #164

Merged
merged 4 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.7
rev: v0.4.8
hooks:
- id: ruff
args: [--fix]
Expand Down
4 changes: 2 additions & 2 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
model: CHGNet | None = None,
*,
use_device: str | None = None,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
**kwargs,
Expand All @@ -73,7 +73,7 @@ def __init__(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
Default = False
stress_weight (float): the conversion factor to convert GPa to eV/A^3.
Default = 1/160.21
on_isolated_atoms ('ignore' | 'warn' | 'error'): how to handle Structures
Expand Down
4 changes: 2 additions & 2 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def load(
*,
model_name: str = "0.3.0",
use_device: str | None = None,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
verbose: bool = True,
) -> CHGNet:
"""Load pretrained CHGNet model.
Expand All @@ -692,7 +692,7 @@ def load(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
Default = False
verbose (bool): whether to print model device information
Default = True
Raises:
Expand Down
4 changes: 2 additions & 2 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
torch_seed: int | None = None,
data_seed: int | None = None,
use_device: str | None = None,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
**kwargs,
) -> None:
"""Initialize all hyper-parameters for trainer.
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
automatically selected based on the available options.
Default = None
check_cuda_mem (bool): Whether to use cuda with most available memory
Default = True
Default = False
**kwargs (dict): additional hyper-params for optimizer, scheduler, etc.
"""
# Store trainer args for reproducibility
Expand Down
3 changes: 2 additions & 1 deletion chgnet/utils/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
def determine_device(
use_device: str | None = None,
*,
check_cuda_mem: bool = True,
check_cuda_mem: bool = False,
) -> str:
"""Determine the device to use for torch model.

Args:
use_device (str): User specify device name
check_cuda_mem (bool): Whether to return cuda with most available memory
Default = False

Returns:
device (str): device name to be passed to model.to(device)
Expand Down
48 changes: 24 additions & 24 deletions examples/QueryMPtrj.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@ opt_task_types = [
optimization_task_ids = {}
for doc in material_ids:
material_id = doc.material_id
tmp = mpr.materials.get_data_by_id(material_id)
mp_doc = mpr.materials.get_data_by_id(material_id)

for task_id, task_type in tmp.calc_types.items():
for task_id, task_type in mp_doc.calc_types.items():
if task_type in opt_task_types:
optimization_task_ids[material_id.string].append(task_id)
```

### Query Materials Project Thermodoc entry and the relaxation tasks
### Query Materials Project `ThermoDoc` entry and the relaxation tasks

The thermodoc entry is the entry you normally see on the MP website
The `ThermoDoc` entry is the entry you normally see on the MP website

```python
# ThermoDoc: Query MP main entries
main_entry = mpr.get_entry_by_material_id(material_id=material_id)[0]
# Query one relaxation task
taskdoc = mpr.tasks.get_data_by_id(task_id, fields=["input", "output", "calcs_reversed", 'task_id', "run_type"])
task_doc = mpr.tasks.get_data_by_id(task_id, fields=["input", "output", "calcs_reversed", 'task_id', "run_type"])
```

## Filtering the data
Expand All @@ -61,58 +61,58 @@ This is done in two steps:
Check whether a task is compatible with the Materials Project main entry by comparing its DFT settings
and converged results with those of the MP main entry.

- Note this step can no longer work for the current MP data, since a lot of `thermodoc` entry (main entry) have changed to `r2SCAN`
- Note that this step no longer works for the current MP data, since many `ThermoDoc` entries (main entries) have changed to `r2SCAN`

```python
def calc_type_equal(
taskdoc,
task_doc,
main_entry,
trjdata
trj_data
) -> bool:
# Check the LDAU of task
try:
is_hubbard = taskdoc.calcs_reversed[0].input['parameters']['LDAU']
is_hubbard = task_doc.calcs_reversed[0].input['parameters']['LDAU']
except:
is_hubbard = taskdoc.calcs_reversed[0].input['incar']['LDAU']
is_hubbard = task_doc.calcs_reversed[0].input['incar']['LDAU']

# Make sure we don't include both GGA and GGA+U for the same mp_id
if main_entry.parameters['is_hubbard'] != is_hubbard:
print(f'{main_entry.entry_id}, {taskdoc.task_id} is_hubbard= {is_hubbard}')
trjdata.exception[taskdoc.task_id] = f'is_hubbard inconsistent task is_hubbard={is_hubbard}'
print(f'{main_entry.entry_id}, {task_doc.task_id} is_hubbard= {is_hubbard}')
trj_data.exception[task_doc.task_id] = f'is_hubbard inconsistent task is_hubbard={is_hubbard}'
return False
elif is_hubbard == True:
# If the task is calculated with GGA+U
# Make sure the +U values are the same for each element
composition = taskdoc.output.structure.composition
composition = task_doc.output.structure.composition
hubbards = {element.symbol: U for element, U in
zip(composition.elements,
taskdoc.calcs_reversed[0].input['incar']['LDAUU'])}
task_doc.calcs_reversed[0].input['incar']['LDAUU'])}
if main_entry.parameters['hubbards'] != hubbards:
thermo_hubbards = main_entry.parameters['hubbards']
trjdata.exception[taskdoc.task_id] = f'hubbards inconsistent task hubbards={hubbards}, thermo hubbards={thermo_hubbards}'
trj_data.exception[task_doc.task_id] = f'hubbards inconsistent task hubbards={hubbards}, thermo hubbards={thermo_hubbards}'
return False
else:
# Check the energy convergence of the task wrt. the main entry
return check_energy_convergence(
taskdoc,
task_doc,
main_entry.uncorrected_energy_per_atom,
trjdata=trjdata
trj_data=trj_data
)
else:
# Check energy convergence for pure GGA tasks
check_energy_convergence(
taskdoc,
task_doc,
main_entry.uncorrected_energy_per_atom,
trjdata=trjdata
trj_data=trj_data
)

def check_energy_convergence(
taskdoc,
task_doc,
relaxed_entry_uncorrected_energy_per_atom,
trjdata
trj_data
) -> bool:
task_energy = taskdoc.calcs_reversed[0].output['ionic_steps'][-1]['e_fr_energy']
n_atom = taskdoc.calcs_reversed[0].output['ionic_steps'][-1][
task_energy = task_doc.calcs_reversed[0].output['ionic_steps'][-1]['e_fr_energy']
n_atom = task_doc.calcs_reversed[0].output['ionic_steps'][-1][
'structure'].composition.num_atoms
e_per_atom = task_energy / n_atom
# This is the energy difference of the last frame of the task vs main_entry energy
Expand All @@ -125,7 +125,7 @@ def check_energy_convergence(
# The task is falsely relaxed, we will discard the whole task
# This step will filter out tasks that relaxed into different spin states
# that caused large energy discrepancies
trjdata.exception[taskdoc.task_id] =
trj_data.exception[task_doc.task_id] =
f'e_diff is too large, '
f'task last step energy_per_atom = {e_per_atom}, '
f'relaxed_entry_uncorrected_e_per_atom = {relaxed_entry_uncorrected_energy_per_atom}'
Expand Down
14 changes: 7 additions & 7 deletions examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
" # https://github.com/materialsproject/crystaltoolkit\n",
" # (only needed on Google Colab or if you didn't install these packages yet)\n",
" !git clone --depth 1 https://github.com/CederGroupHub/chgnet\n",
" !pip install './chgnet[examples]'\n"
" !pip install './chgnet[examples]'"
]
},
{
Expand All @@ -47,7 +47,7 @@
"outputs": [],
"source": [
"import numpy as np\n",
"from pymatgen.core import Structure\n"
"from pymatgen.core import Structure"
]
},
{
Expand All @@ -66,7 +66,7 @@
"\n",
" url = \"https://github.com/CederGroupHub/chgnet/raw/-/examples/mp-18767-LiMnO2.cif\"\n",
" cif = urlopen(url).read().decode(\"utf-8\")\n",
" structure = Structure.from_str(cif, fmt=\"cif\")\n"
" structure = Structure.from_str(cif, fmt=\"cif\")"
]
},
{
Expand Down Expand Up @@ -94,7 +94,7 @@
"# stretch the cell by a small amount\n",
"structure.scale_lattice(structure.volume * 1.1)\n",
"\n",
"print(f\"perturbed: {structure.get_space_group_info()}\")\n"
"print(f\"perturbed: {structure.get_space_group_info()}\")"
]
},
{
Expand Down Expand Up @@ -212,7 +212,7 @@
"\n",
"from chgnet.model import StructOptimizer\n",
"\n",
"trajectory = StructOptimizer().relax(structure)[\"trajectory\"]\n"
"trajectory = StructOptimizer().relax(structure)[\"trajectory\"]"
]
},
{
Expand All @@ -229,7 +229,7 @@
" np.linalg.norm(force, axis=1).mean() # mean of norm of force on each atom\n",
" for force in trajectory.forces\n",
"]\n",
"df_traj.index.name = \"step\"\n"
"df_traj.index.name = \"step\""
]
},
{
Expand All @@ -250,7 +250,7 @@
"mp_id = \"mp-18767\"\n",
"\n",
"dft_energy = -59.09\n",
"print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")\n"
"print(f\"{dft_energy=:.2f} eV (see https://materialsproject.org/materials/{mp_id})\")"
]
},
{
Expand Down
17 changes: 12 additions & 5 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import inspect

import numpy as np
import pytest
from pymatgen.core import Structure
Expand Down Expand Up @@ -249,10 +251,15 @@ def test_model_load_version_params(
with pytest.raises(ValueError, match=f"Unknown {model_name=}"):
CHGNet.load(model_name=model_name)

# # set CHGNET_DEVICE to "cuda" and test
monkeypatch.setenv("CHGNET_DEVICE", env_device := "foobar")
with pytest.raises(
RuntimeError,
match=f"Expected one of cpu, .+type at start of device string: {env_device}",
bad_env_device = "foobar"
err_msg = f"Expected one of cpu, .+type at start of device string: {bad_env_device}"
with ( # noqa: PT012
monkeypatch.context() as ctx,
pytest.raises(RuntimeError, match=err_msg),
):
ctx.setenv("CHGNET_DEVICE", bad_env_device)
CHGNet.load()

# check check_cuda_mem defaults to False
inspect_signature = inspect.signature(CHGNet.load)
assert inspect_signature.parameters["check_cuda_mem"].default is False
Loading