From cbbbf99a83dafb8bb18784b7677f3f9cf0aef8ed Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 16 Nov 2024 10:48:31 -0500 Subject: [PATCH 1/3] CHGNetCalculator add kwarg task: PredTask = "efsm" --- .pre-commit-config.yaml | 10 +++++----- chgnet/model/dynamics.py | 38 +++++++++++++++++++++++++------------- chgnet/model/model.py | 9 ++++++--- tests/test_md.py | 28 ++++++++++++++++++++++++++-- tests/test_relaxation.py | 4 +++- 5 files changed, 65 insertions(+), 24 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5f0a13d2..bc3acb2d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,10 @@ -default_stages: [commit] +default_stages: [pre-commit] default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.4 hooks: - id: ruff args: [--fix] @@ -28,11 +28,11 @@ repos: rev: v2.3.0 hooks: - id: codespell - stages: [commit, commit-msg] + stages: [pre-commit, commit-msg] args: [--check-filenames] - repo: https://github.com/kynan/nbstripout - rev: 0.7.1 + rev: 0.8.0 hooks: - id: nbstripout args: [--drop-empty-cells, --keep-output] @@ -48,7 +48,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.12.0 + rev: v9.15.0 hooks: - id: eslint types: [file] diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index b5b01f97..1372fea5 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -33,6 +33,8 @@ from ase.optimize.optimize import Optimizer from typing_extensions import Self + from chgnet import PredTask + # We would like to thank M3GNet develop team for this module # source: https://github.com/materialsvirtuallab/m3gnet @@ -59,7 +61,7 @@ def __init__( *, use_device: str | None = None, check_cuda_mem: bool = False, - stress_weight: float | None = 1 / 160.21766208, + stress_weight: float = units.GPa, # GPa to eV/A^3 on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn", return_site_energies: bool = False, **kwargs, @@ -124,6 +126,7 @@ def calculate( atoms: Atoms | None = None, properties: list | None = None, system_changes: list | None = None, + task: PredTask = "efsm", ) -> None: """Calculate various properties of the atoms using CHGNet. @@ -133,6 +136,8 @@ def calculate( Default is all properties. system_changes (list | None): The changes made to the system. Default is all changes. + task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm". + Default = "efsm" """ properties = properties or all_properties system_changes = system_changes or all_changes @@ -147,23 +152,30 @@ def calculate( graph = self.model.graph_converter(structure) model_prediction = self.model.predict_graph( graph.to(self.device), - task="efsm", + task=task, return_crystal_feas=True, return_site_energies=self.return_site_energies, ) # Convert Result - factor = 1 if not self.model.is_intensive else structure.composition.num_atoms - self.results.update( - energy=model_prediction["e"] * factor, - forces=model_prediction["f"], - free_energy=model_prediction["e"] * factor, - magmoms=model_prediction["m"], - stress=model_prediction["s"] * self.stress_weight, - crystal_fea=model_prediction["crystal_fea"], + extensive_factor = ( + 1 if not self.model.is_intensive else structure.composition.num_atoms + ) + key_map = dict( + e=("energy", extensive_factor), + f=("forces", 1), + m=("magmoms", 1), + s=("stress", self.stress_weight), ) + self.results |= { + long_key: model_prediction[key] * factor + for key, (long_key, factor) in key_map.items() + if key in model_prediction + } + self.results["free_energy"] = self.results["energy"] + self.results["crystal_fea"] = model_prediction["crystal_fea"] if self.return_site_energies: - self.results.update(energies=model_prediction["site_energies"]) + self.results["energies"] = model_prediction["site_energies"] class StructOptimizer: @@ -174,7 +186,7 @@ def __init__( model: CHGNet | CHGNetCalculator | None = None, optimizer_class: Optimizer | str | None = "FIRE", use_device: str | None = None, - stress_weight: float = 1 / 160.21766208, + stress_weight: float = units.GPa, on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn", ) -> None: """Provide a trained CHGNet model and an optimizer to relax crystal structures. @@ -773,7 +785,7 @@ def __init__( model: CHGNet | CHGNetCalculator | None = None, optimizer_class: Optimizer | str | None = "FIRE", use_device: str | None = None, - stress_weight: float = 1 / 160.21766208, + stress_weight: float = units.GPa, on_isolated_atoms: Literal["ignore", "warn", "error"] = "error", ) -> None: """Initialize a structure optimizer object for calculation of bulk modulus. diff --git a/chgnet/model/model.py b/chgnet/model/model.py index d42c61c9..c1bd58f8 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -4,12 +4,13 @@ import os from collections.abc import Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, get_args import torch from pymatgen.core import Structure from torch import Tensor, nn +from chgnet import PredTask from chgnet.graph import CrystalGraph, CrystalGraphConverter from chgnet.graph.crystalgraph import TORCH_DTYPE from chgnet.model.composition_model import AtomRef @@ -27,7 +28,6 @@ if TYPE_CHECKING: from typing_extensions import Self - from chgnet import PredTask module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -603,7 +603,7 @@ def predict_graph( Args: graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict. - task (str): can be 'e' 'ef', 'em', 'efs', 'efsm' + task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm' Default = "efsm" return_site_energies (bool): whether to return per-site energies. Default = False @@ -626,6 +626,9 @@ def predict_graph( raise TypeError( f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs" ) + valid_tasks = get_args(PredTask) + if task not in valid_tasks: + raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.") model_device = next(self.parameters()).device diff --git a/tests/test_md.py b/tests/test_md.py index ec62f632..f44c21eb 100644 --- a/tests/test_md.py +++ b/tests/test_md.py @@ -2,7 +2,7 @@ import os import pickle -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, get_args import numpy as np import pytest @@ -22,7 +22,7 @@ from chgnet.graph import CrystalGraphConverter from chgnet.model import StructOptimizer from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics -from chgnet.model.model import CHGNet +from chgnet.model.model import CHGNet, PredTask if TYPE_CHECKING: from pathlib import Path @@ -314,3 +314,27 @@ def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch): assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5) assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5) assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5) + + +@pytest.mark.parametrize("task", [*get_args(PredTask)]) +def test_calculator_task_valid(task: PredTask): + """Test that the task kwarg of CHGNetCalculator.calculate() works correctly.""" + key_map = dict(e="energy", f="forces", m="magmoms", s="stress") + calculator = CHGNetCalculator() + atoms = AseAtomsAdaptor.get_atoms(structure) + atoms.calc = calculator + + calculator.calculate(atoms=atoms, task=task) + + for key, prop in key_map.items(): + assert (prop in calculator.results) == (key in task) + + +def test_calculator_task_invalid(): + """Test that invalid task raises ValueError.""" + calculator = CHGNetCalculator() + atoms = AseAtomsAdaptor.get_atoms(structure) + atoms.calc = calculator + + with pytest.raises(ValueError, match="Invalid task='invalid'."): + calculator.calculate(atoms=atoms, task="invalid") diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index c23b675b..b5d39fff 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -50,7 +50,9 @@ def test_relaxation( assert {*traj.__dict__} == { *"atoms energies forces stresses magmoms atom_positions cells".split() } - assert len(traj) == 2 if algorithm == "legacy" else 4 + assert len(traj) == ( + 2 if algorithm == "legacy" else 4 + ), f"{len(traj)=}, {algorithm=}" # make sure final structure is more relaxed than initial one assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4) From 0cb93e610297659d49a1c87c9f1d701666c17332 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sat, 16 Nov 2024 10:56:56 -0500 Subject: [PATCH 2/3] fix site build: don't fetch MBD MetricsTable.svelte which no longer exists at that location --- site/.gitignore | 1 - site/package.json | 36 ++++++++++++++++++------------------ site/src/routes/+page.svelte | 5 +---- site/vite.config.ts | 10 ---------- 4 files changed, 19 insertions(+), 33 deletions(-) diff --git a/site/.gitignore b/site/.gitignore index 59078f29..bded1f72 100644 --- a/site/.gitignore +++ b/site/.gitignore @@ -5,4 +5,3 @@ node_modules .svelte-kit build src/routes/api/*.md -src/MetricsTable.svelte diff --git a/site/package.json b/site/package.json index 3474e4be..2f8156fc 100644 --- a/site/package.json +++ b/site/package.json @@ -15,28 +15,28 @@ "changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false" }, "devDependencies": { - "@sveltejs/adapter-static": "^3.0.2", - "@sveltejs/kit": "^2.5.17", - "@sveltejs/vite-plugin-svelte": "^3.1.1", - "eslint": "^9.5.0", - "eslint-plugin-svelte": "^2.41.0", + "@sveltejs/adapter-static": "^3.0.6", + "@sveltejs/kit": "^2.8.1", + "@sveltejs/vite-plugin-svelte": "^4.0.1", + "eslint": "^9.15.0", + "eslint-plugin-svelte": "^2.46.0", "hastscript": "^9.0.0", - "mdsvex": "^0.11.2", - "prettier": "^3.3.2", - "prettier-plugin-svelte": "^3.2.5", + "mdsvex": "^0.12.3", + "prettier": "^3.3.3", + "prettier-plugin-svelte": "^3.2.8", "rehype-autolink-headings": "^7.1.0", "rehype-slug": "^6.0.0", - "svelte": "^4.2.18", - "svelte-check": "^3.8.4", - "svelte-multiselect": "^10.3.0", - "svelte-preprocess": "^6.0.1", + "svelte": "^5.2.1", + "svelte-check": "^4.0.8", + "svelte-multiselect": "11.0.0-rc.1", + "svelte-preprocess": "^6.0.3", "svelte-toc": "^0.5.9", - "svelte-zoo": "^0.4.10", - "svelte2tsx": "^0.7.13", - "tslib": "^2.6.3", - "typescript": "^5.5.2", - "typescript-eslint": "^7.14.1", - "vite": "^5.3.1" + "svelte-zoo": "^0.4.13", + "svelte2tsx": "^0.7.25", + "tslib": "^2.8.1", + "typescript": "^5.6.3", + "typescript-eslint": "^8.14.0", + "vite": "^5.4.11" }, "prettier": { "semi": false, diff --git a/site/src/routes/+page.svelte b/site/src/routes/+page.svelte index 7e2c6975..201fe721 100644 --- a/site/src/routes/+page.svelte +++ b/site/src/routes/+page.svelte @@ -1,12 +1,9 @@
- - - +