Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CHGNetCalculator.calculate add kwarg task: PredTask = "efsm" #215

Merged
merged 3 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
default_stages: [commit]
default_stages: [pre-commit]

default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.4
hooks:
- id: ruff
args: [--fix]
Expand All @@ -28,11 +28,11 @@ repos:
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
stages: [pre-commit, commit-msg]
args: [--check-filenames]

- repo: https://github.com/kynan/nbstripout
rev: 0.7.1
rev: 0.8.0
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
Expand All @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.12.0
rev: v9.15.0
hooks:
- id: eslint
types: [file]
Expand Down
36 changes: 23 additions & 13 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from ase.optimize.optimize import Optimizer
from typing_extensions import Self

from chgnet import PredTask

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet

Expand All @@ -59,7 +61,7 @@ def __init__(
*,
use_device: str | None = None,
check_cuda_mem: bool = False,
stress_weight: float | None = 1 / 160.21766208,
stress_weight: float = units.GPa, # GPa to eV/A^3
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
return_site_energies: bool = False,
**kwargs,
Expand Down Expand Up @@ -124,6 +126,7 @@ def calculate(
atoms: Atoms | None = None,
properties: list | None = None,
system_changes: list | None = None,
task: PredTask = "efsm",
) -> None:
"""Calculate various properties of the atoms using CHGNet.

Expand All @@ -133,6 +136,8 @@ def calculate(
Default is all properties.
system_changes (list | None): The changes made to the system.
Default is all changes.
task (PredTask): The task to perform. One of "e", "ef", "em", "efs", "efsm".
Default = "efsm"
"""
properties = properties or all_properties
system_changes = system_changes or all_changes
Expand All @@ -147,23 +152,28 @@ def calculate(
graph = self.model.graph_converter(structure)
model_prediction = self.model.predict_graph(
graph.to(self.device),
task="efsm",
task=task,
return_crystal_feas=True,
return_site_energies=self.return_site_energies,
)

# Convert Result
factor = 1 if not self.model.is_intensive else structure.composition.num_atoms
self.results.update(
energy=model_prediction["e"] * factor,
forces=model_prediction["f"],
free_energy=model_prediction["e"] * factor,
magmoms=model_prediction["m"],
stress=model_prediction["s"] * self.stress_weight,
crystal_fea=model_prediction["crystal_fea"],
extensive_factor = len(structure) if self.model.is_intensive else 1
key_map = dict(
e=("energy", extensive_factor),
f=("forces", 1),
m=("magmoms", 1),
s=("stress", self.stress_weight),
)
self.results |= {
long_key: model_prediction[key] * factor
for key, (long_key, factor) in key_map.items()
if key in model_prediction
}
self.results["free_energy"] = self.results["energy"]
self.results["crystal_fea"] = model_prediction["crystal_fea"]
if self.return_site_energies:
self.results.update(energies=model_prediction["site_energies"])
self.results["energies"] = model_prediction["site_energies"]


class StructOptimizer:
Expand All @@ -174,7 +184,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
stress_weight: float = 1 / 160.21766208,
stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "warn",
) -> None:
"""Provide a trained CHGNet model and an optimizer to relax crystal structures.
Expand Down Expand Up @@ -773,7 +783,7 @@ def __init__(
model: CHGNet | CHGNetCalculator | None = None,
optimizer_class: Optimizer | str | None = "FIRE",
use_device: str | None = None,
stress_weight: float = 1 / 160.21766208,
stress_weight: float = units.GPa,
on_isolated_atoms: Literal["ignore", "warn", "error"] = "error",
) -> None:
"""Initialize a structure optimizer object for calculation of bulk modulus.
Expand Down
9 changes: 6 additions & 3 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import os
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, get_args

import torch
from pymatgen.core import Structure
from torch import Tensor, nn

from chgnet import PredTask
from chgnet.graph import CrystalGraph, CrystalGraphConverter
from chgnet.graph.crystalgraph import TORCH_DTYPE
from chgnet.model.composition_model import AtomRef
Expand All @@ -27,7 +28,6 @@
if TYPE_CHECKING:
from typing_extensions import Self

from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down Expand Up @@ -603,7 +603,7 @@ def predict_graph(

Args:
graph (CrystalGraph | Sequence[CrystalGraph]): CrystalGraph(s) to predict.
task (str): can be 'e' 'ef', 'em', 'efs', 'efsm'
task (PredTask): one of 'e', 'ef', 'em', 'efs', 'efsm'
Default = "efsm"
return_site_energies (bool): whether to return per-site energies.
Default = False
Expand All @@ -626,6 +626,9 @@ def predict_graph(
raise TypeError(
f"{type(graph)=} must be CrystalGraph or list of CrystalGraphs"
)
valid_tasks = get_args(PredTask)
if task not in valid_tasks:
raise ValueError(f"Invalid {task=}. Must be one of {valid_tasks}.")

model_device = next(self.parameters()).device

Expand Down
2 changes: 1 addition & 1 deletion chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ def forward(
for mag_pred, mag_target in zip(prediction["m"], targets["m"], strict=True):
# exclude structures without magmom labels
if self.allow_missing_labels:
if mag_target is not None and not np.isnan(mag_target).any():
if mag_target is not None and not torch.isnan(mag_target).any():
mag_preds.append(mag_pred)
mag_targets.append(mag_target)
m_mae_size += mag_target.shape[0]
Expand Down
1 change: 0 additions & 1 deletion site/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ node_modules
.svelte-kit
build
src/routes/api/*.md
src/MetricsTable.svelte
36 changes: 18 additions & 18 deletions site/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,28 @@
"changelog": "npx auto-changelog --package --output ../changelog.md --hide-credit --commit-limit false"
},
"devDependencies": {
"@sveltejs/adapter-static": "^3.0.2",
"@sveltejs/kit": "^2.5.17",
"@sveltejs/vite-plugin-svelte": "^3.1.1",
"eslint": "^9.5.0",
"eslint-plugin-svelte": "^2.41.0",
"@sveltejs/adapter-static": "^3.0.6",
"@sveltejs/kit": "^2.8.1",
"@sveltejs/vite-plugin-svelte": "^4.0.1",
"eslint": "^9.15.0",
"eslint-plugin-svelte": "^2.46.0",
"hastscript": "^9.0.0",
"mdsvex": "^0.11.2",
"prettier": "^3.3.2",
"prettier-plugin-svelte": "^3.2.5",
"mdsvex": "^0.12.3",
"prettier": "^3.3.3",
"prettier-plugin-svelte": "^3.2.8",
"rehype-autolink-headings": "^7.1.0",
"rehype-slug": "^6.0.0",
"svelte": "^4.2.18",
"svelte-check": "^3.8.4",
"svelte-multiselect": "^10.3.0",
"svelte-preprocess": "^6.0.1",
"svelte": "^5.2.1",
"svelte-check": "^4.0.8",
"svelte-multiselect": "11.0.0-rc.1",
"svelte-preprocess": "^6.0.3",
"svelte-toc": "^0.5.9",
"svelte-zoo": "^0.4.10",
"svelte2tsx": "^0.7.13",
"tslib": "^2.6.3",
"typescript": "^5.5.2",
"typescript-eslint": "^7.14.1",
"vite": "^5.3.1"
"svelte-zoo": "^0.4.13",
"svelte2tsx": "^0.7.25",
"tslib": "^2.8.1",
"typescript": "^5.6.3",
"typescript-eslint": "^8.14.0",
"vite": "^5.4.11"
},
"prettier": {
"semi": false,
Expand Down
5 changes: 1 addition & 4 deletions site/src/routes/+page.svelte
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
<script lang="ts">
import Readme from '$root/README.md'
import MetricsTable from '$src/MetricsTable.svelte'
</script>

<main>
<Readme>
<MetricsTable slot="metrics-table" />
</Readme>
<Readme />
</main>

<style>
Expand Down
10 changes: 0 additions & 10 deletions site/vite.config.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import { sveltekit } from '@sveltejs/kit/vite'
import * as fs from 'fs'
import type { UserConfig } from 'vite'

// fetch latest Matbench Discovery metrics table at build time and save to src/ dir
await fetch(
`https://github.com/janosh/matbench-discovery/raw/main/site/src/figs/metrics-table-uniq-protos.svelte`,
)
.then((res) => res.text())
.then((text) => {
fs.writeFileSync(`src/MetricsTable.svelte`, text)
})

export default {
plugins: [sveltekit()],

Expand Down
28 changes: 26 additions & 2 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import pickle
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Literal, get_args

import numpy as np
import pytest
Expand All @@ -22,7 +22,7 @@
from chgnet.graph import CrystalGraphConverter
from chgnet.model import StructOptimizer
from chgnet.model.dynamics import CHGNetCalculator, EquationOfState, MolecularDynamics
from chgnet.model.model import CHGNet
from chgnet.model.model import CHGNet, PredTask

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -314,3 +314,27 @@ def test_md_crystal_feas_log(tmp_path: Path, monkeypatch: MonkeyPatch):
assert crystal_feas[0][1] == approx(-1.4285042, abs=1e-5)
assert crystal_feas[10][0] == approx(-0.0020592688, abs=1e-5)
assert crystal_feas[10][1] == approx(-1.4284436, abs=1e-5)


@pytest.mark.parametrize("task", [*get_args(PredTask)])
def test_calculator_task_valid(task: PredTask):
"""Test that the task kwarg of CHGNetCalculator.calculate() works correctly."""
key_map = dict(e="energy", f="forces", m="magmoms", s="stress")
calculator = CHGNetCalculator()
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = calculator

calculator.calculate(atoms=atoms, task=task)

for key, prop in key_map.items():
assert (prop in calculator.results) == (key in task)


def test_calculator_task_invalid():
"""Test that invalid task raises ValueError."""
calculator = CHGNetCalculator()
atoms = AseAtomsAdaptor.get_atoms(structure)
atoms.calc = calculator

with pytest.raises(ValueError, match="Invalid task='invalid'."):
calculator.calculate(atoms=atoms, task="invalid")
4 changes: 3 additions & 1 deletion tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def test_relaxation(
assert {*traj.__dict__} == {
*"atoms energies forces stresses magmoms atom_positions cells".split()
}
assert len(traj) == 2 if algorithm == "legacy" else 4
assert len(traj) == (
2 if algorithm == "legacy" else 4
), f"{len(traj)=}, {algorithm=}"

# make sure final structure is more relaxed than initial one
assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
for param in chgnet.composition_model.parameters():
assert param.requires_grad is False
assert tmp_path.is_dir(), "Training dir was not created"
for target_str in ["e", "f", "s", "m"]:
assert ~np.isnan(trainer.training_history[target_str]["train"]).any()
assert ~np.isnan(trainer.training_history[target_str]["val"]).any()
for prop in "efsm":
assert ~np.isnan(trainer.training_history[prop]["train"]).any()
assert ~np.isnan(trainer.training_history[prop]["val"]).any()
output_files = [file.name for file in tmp_path.iterdir()]
for prefix in ("epoch", "bestE_", "bestF_"):
n_matches = sum(file.startswith(prefix) for file in output_files)
Expand Down
Loading