Skip to content

Commit

Permalink
Use return type typing_extensions.Self for class methods (#179)
Browse files Browse the repository at this point in the history
* use return type typing_extensions.Self for all class methods

* improve make_graphs doc str

* bump pre-commit hooks
  • Loading branch information
janosh authored Jul 7, 2024
1 parent b819ef5 commit 9556c3e
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 25 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.10
rev: v0.5.1
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -48,7 +48,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.5.0
rev: v9.6.0
hooks:
- id: eslint
types: [file]
Expand Down
4 changes: 3 additions & 1 deletion chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from typing_extensions import Self

from chgnet import TrainTask

warnings.filterwarnings("ignore")
Expand Down Expand Up @@ -97,7 +99,7 @@ def from_vasp(
save_path: str | None = None,
graph_converter: CrystalGraphConverter | None = None,
shuffle: bool = True,
) -> StructureData:
) -> Self:
"""Parse VASP output files into structures and labels and feed into the dataset.
Args:
Expand Down
5 changes: 3 additions & 2 deletions chgnet/graph/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

if TYPE_CHECKING:
from pymatgen.core import Structure
from typing_extensions import Self

try:
from chgnet.graph.cygraph import make_graph
Expand Down Expand Up @@ -285,6 +286,6 @@ def as_dict(self) -> dict[str, str | float]:
}

@classmethod
def from_dict(cls, dct: dict) -> CrystalGraphConverter:
def from_dict(cls, dct: dict) -> Self:
"""Create converter from dictionary."""
return CrystalGraphConverter(**dct)
return cls(**dct)
11 changes: 7 additions & 4 deletions chgnet/graph/crystalgraph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import os
from typing import Any
from typing import TYPE_CHECKING, Any

import torch
from torch import Tensor

if TYPE_CHECKING:
from typing_extensions import Self

datatype = torch.float32


Expand Down Expand Up @@ -152,7 +155,7 @@ def save(self, fname: str | None = None, save_dir: str = ".") -> str:
return save_name

@classmethod
def from_file(cls, file_name: str) -> CrystalGraph:
def from_file(cls, file_name: str) -> Self:
"""Load a crystal graph from a file.
Args:
Expand All @@ -164,9 +167,9 @@ def from_file(cls, file_name: str) -> CrystalGraph:
return torch.load(file_name)

@classmethod
def from_dict(cls, dic: dict[str, Any]) -> CrystalGraph:
def from_dict(cls, dic: dict[str, Any]) -> Self:
"""Load a CrystalGraph from a dictionary."""
return CrystalGraph(**dic)
return cls(**dic)

def __repr__(self) -> str:
"""String representation of the graph."""
Expand Down
7 changes: 3 additions & 4 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
if TYPE_CHECKING:
from ase.io import Trajectory
from ase.optimize.optimize import Optimizer
from typing_extensions import Self

# We would like to thank M3GNet develop team for this module
# source: https://github.com/materialsvirtuallab/m3gnet
Expand Down Expand Up @@ -94,11 +95,9 @@ def __init__(
print(f"CHGNet will run on {self.device}")

@classmethod
def from_file(
cls, path: str, use_device: str | None = None, **kwargs
) -> CHGNetCalculator:
def from_file(cls, path: str, use_device: str | None = None, **kwargs) -> Self:
"""Load a user's CHGNet model and initialize the Calculator."""
return CHGNetCalculator(
return cls(
model=CHGNet.from_file(path),
use_device=use_device,
**kwargs,
Expand Down
14 changes: 8 additions & 6 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from chgnet.utils import determine_device

if TYPE_CHECKING:
from typing_extensions import Self

from chgnet import PredTask

module_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down Expand Up @@ -661,17 +663,17 @@ def todict(self) -> dict:
return {"model_name": type(self).__name__, "model_args": self.model_args}

@classmethod
def from_dict(cls, dct: dict, **kwargs) -> CHGNet:
def from_dict(cls, dct: dict, **kwargs) -> Self:
"""Build a CHGNet from a saved dictionary."""
chgnet = CHGNet(**dct["model_args"], **kwargs)
chgnet = cls(**dct["model_args"], **kwargs)
chgnet.load_state_dict(dct["state_dict"])
return chgnet

@classmethod
def from_file(cls, path: str, **kwargs) -> CHGNet:
def from_file(cls, path: str, **kwargs) -> Self:
"""Build a CHGNet from a saved file."""
state = torch.load(path, map_location=torch.device("cpu"))
return CHGNet.from_dict(state["model"], **kwargs)
return cls.from_dict(state["model"], **kwargs)

@classmethod
def load(
Expand All @@ -681,7 +683,7 @@ def load(
use_device: str | None = None,
check_cuda_mem: bool = False,
verbose: bool = True,
) -> CHGNet:
) -> Self:
"""Load pretrained CHGNet model.
Args:
Expand Down Expand Up @@ -777,7 +779,7 @@ def from_graphs(
angle_basis_expansion: nn.Module,
*,
compute_stress: bool = False,
) -> BatchedGraph:
) -> Self:
"""Featurize and assemble a list of graphs.
Args:
Expand Down
7 changes: 5 additions & 2 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

if TYPE_CHECKING:
from torch.utils.data import DataLoader
from typing_extensions import Self

from chgnet import TrainTask

Expand Down Expand Up @@ -645,14 +646,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None:
)

@classmethod
def load(cls, path: str) -> Trainer:
def load(cls, path: str) -> Self:
"""Load trainer state_dict."""
state = torch.load(path, map_location=torch.device("cpu"))
model = CHGNet.from_dict(state["model"])
print(f"Loaded model params = {sum(p.numel() for p in model.parameters()):,}")
# drop model from trainer_args if present
state["trainer_args"].pop("model", None)
trainer = Trainer(model=model, **state["trainer_args"])
trainer = cls(model=model, **state["trainer_args"])
trainer.model.to(trainer.device)
trainer.optimizer.load_state_dict(state["optimizer"])
trainer.scheduler.load_state_dict(state["scheduler"])
Expand Down Expand Up @@ -791,6 +792,8 @@ def forward(
out["s_MAE_size"] = stress_target.shape[0]

# Mag
print(f"{list(prediction)=}")
print(f"{list(targets)=}")
if "m" in self.target_str:
mag_preds, mag_targets = [], []
m_mae_size = 0
Expand Down
9 changes: 5 additions & 4 deletions examples/make_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def make_graphs(
"""Make graphs from a StructureJsonData dataset.
Args:
data (StructureJsonData): a StructureJsonData
graph_dir (str): a directory to save the graphs
train_ratio (float): train ratio
val_ratio (float): val ratio
data (StructureJsonData | StructureData): Input structures to convert to graphs.
graph_dir (str): a directory to save the graphs and labels.
train_ratio (float): train ratio. Default = 0.8
val_ratio (float): val ratio. Default = 0.1. The test ratio is
1 - train_ratio - val_ratio
"""
os.makedirs(graph_dir, exist_ok=True)
random.shuffle(data.keys)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"nvidia-ml-py3>=7.352.0",
"pymatgen>=2023.10.11",
"torch>=1.11.0",
"typing-extensions>=4.12",
]
classifiers = [
"Intended Audience :: Science/Research",
Expand Down

0 comments on commit 9556c3e

Please sign in to comment.