From 9556c3ec5c5d11e66e3082b99adf8b46e0952bae Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Sun, 7 Jul 2024 14:20:38 -0400 Subject: [PATCH] Use return type `typing_extensions.Self` for class methods (#179) * use return type typing_extensions.Self for all class methods * improve make_graphs doc str * bump pre-commit hooks --- .pre-commit-config.yaml | 4 ++-- chgnet/data/dataset.py | 4 +++- chgnet/graph/converter.py | 5 +++-- chgnet/graph/crystalgraph.py | 11 +++++++---- chgnet/model/dynamics.py | 7 +++---- chgnet/model/model.py | 14 ++++++++------ chgnet/trainer/trainer.py | 7 +++++-- examples/make_graphs.py | 9 +++++---- pyproject.toml | 1 + 9 files changed, 37 insertions(+), 25 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 68f35ae0..28f80341 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.10 + rev: v0.5.1 hooks: - id: ruff args: [--fix] @@ -48,7 +48,7 @@ repos: - svelte - repo: https://github.com/pre-commit/mirrors-eslint - rev: v9.5.0 + rev: v9.6.0 hooks: - id: eslint types: [file] diff --git a/chgnet/data/dataset.py b/chgnet/data/dataset.py index 01296479..2bf74ee2 100644 --- a/chgnet/data/dataset.py +++ b/chgnet/data/dataset.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + from chgnet import TrainTask warnings.filterwarnings("ignore") @@ -97,7 +99,7 @@ def from_vasp( save_path: str | None = None, graph_converter: CrystalGraphConverter | None = None, shuffle: bool = True, - ) -> StructureData: + ) -> Self: """Parse VASP output files into structures and labels and feed into the dataset. Args: diff --git a/chgnet/graph/converter.py b/chgnet/graph/converter.py index 12a07441..3fedf3c5 100644 --- a/chgnet/graph/converter.py +++ b/chgnet/graph/converter.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from pymatgen.core import Structure + from typing_extensions import Self try: from chgnet.graph.cygraph import make_graph @@ -285,6 +286,6 @@ def as_dict(self) -> dict[str, str | float]: } @classmethod - def from_dict(cls, dct: dict) -> CrystalGraphConverter: + def from_dict(cls, dct: dict) -> Self: """Create converter from dictionary.""" - return CrystalGraphConverter(**dct) + return cls(**dct) diff --git a/chgnet/graph/crystalgraph.py b/chgnet/graph/crystalgraph.py index 566df036..637b359a 100644 --- a/chgnet/graph/crystalgraph.py +++ b/chgnet/graph/crystalgraph.py @@ -1,11 +1,14 @@ from __future__ import annotations import os -from typing import Any +from typing import TYPE_CHECKING, Any import torch from torch import Tensor +if TYPE_CHECKING: + from typing_extensions import Self + datatype = torch.float32 @@ -152,7 +155,7 @@ def save(self, fname: str | None = None, save_dir: str = ".") -> str: return save_name @classmethod - def from_file(cls, file_name: str) -> CrystalGraph: + def from_file(cls, file_name: str) -> Self: """Load a crystal graph from a file. Args: @@ -164,9 +167,9 @@ def from_file(cls, file_name: str) -> CrystalGraph: return torch.load(file_name) @classmethod - def from_dict(cls, dic: dict[str, Any]) -> CrystalGraph: + def from_dict(cls, dic: dict[str, Any]) -> Self: """Load a CrystalGraph from a dictionary.""" - return CrystalGraph(**dic) + return cls(**dic) def __repr__(self) -> str: """String representation of the graph.""" diff --git a/chgnet/model/dynamics.py b/chgnet/model/dynamics.py index 5e1eeb91..a7855e94 100644 --- a/chgnet/model/dynamics.py +++ b/chgnet/model/dynamics.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from ase.io import Trajectory from ase.optimize.optimize import Optimizer + from typing_extensions import Self # We would like to thank M3GNet develop team for this module # source: https://github.com/materialsvirtuallab/m3gnet @@ -94,11 +95,9 @@ def __init__( print(f"CHGNet will run on {self.device}") @classmethod - def from_file( - cls, path: str, use_device: str | None = None, **kwargs - ) -> CHGNetCalculator: + def from_file(cls, path: str, use_device: str | None = None, **kwargs) -> Self: """Load a user's CHGNet model and initialize the Calculator.""" - return CHGNetCalculator( + return cls( model=CHGNet.from_file(path), use_device=use_device, **kwargs, diff --git a/chgnet/model/model.py b/chgnet/model/model.py index bfa4ac09..d2edb92a 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -25,6 +25,8 @@ from chgnet.utils import determine_device if TYPE_CHECKING: + from typing_extensions import Self + from chgnet import PredTask module_dir = os.path.dirname(os.path.abspath(__file__)) @@ -661,17 +663,17 @@ def todict(self) -> dict: return {"model_name": type(self).__name__, "model_args": self.model_args} @classmethod - def from_dict(cls, dct: dict, **kwargs) -> CHGNet: + def from_dict(cls, dct: dict, **kwargs) -> Self: """Build a CHGNet from a saved dictionary.""" - chgnet = CHGNet(**dct["model_args"], **kwargs) + chgnet = cls(**dct["model_args"], **kwargs) chgnet.load_state_dict(dct["state_dict"]) return chgnet @classmethod - def from_file(cls, path: str, **kwargs) -> CHGNet: + def from_file(cls, path: str, **kwargs) -> Self: """Build a CHGNet from a saved file.""" state = torch.load(path, map_location=torch.device("cpu")) - return CHGNet.from_dict(state["model"], **kwargs) + return cls.from_dict(state["model"], **kwargs) @classmethod def load( @@ -681,7 +683,7 @@ def load( use_device: str | None = None, check_cuda_mem: bool = False, verbose: bool = True, - ) -> CHGNet: + ) -> Self: """Load pretrained CHGNet model. Args: @@ -777,7 +779,7 @@ def from_graphs( angle_basis_expansion: nn.Module, *, compute_stress: bool = False, - ) -> BatchedGraph: + ) -> Self: """Featurize and assemble a list of graphs. Args: diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 82cb71d8..87ec99eb 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: from torch.utils.data import DataLoader + from typing_extensions import Self from chgnet import TrainTask @@ -645,14 +646,14 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None: ) @classmethod - def load(cls, path: str) -> Trainer: + def load(cls, path: str) -> Self: """Load trainer state_dict.""" state = torch.load(path, map_location=torch.device("cpu")) model = CHGNet.from_dict(state["model"]) print(f"Loaded model params = {sum(p.numel() for p in model.parameters()):,}") # drop model from trainer_args if present state["trainer_args"].pop("model", None) - trainer = Trainer(model=model, **state["trainer_args"]) + trainer = cls(model=model, **state["trainer_args"]) trainer.model.to(trainer.device) trainer.optimizer.load_state_dict(state["optimizer"]) trainer.scheduler.load_state_dict(state["scheduler"]) @@ -791,6 +792,8 @@ def forward( out["s_MAE_size"] = stress_target.shape[0] # Mag + print(f"{list(prediction)=}") + print(f"{list(targets)=}") if "m" in self.target_str: mag_preds, mag_targets = [], [] m_mae_size = 0 diff --git a/examples/make_graphs.py b/examples/make_graphs.py index ba209f43..6093c4f6 100644 --- a/examples/make_graphs.py +++ b/examples/make_graphs.py @@ -29,10 +29,11 @@ def make_graphs( """Make graphs from a StructureJsonData dataset. Args: - data (StructureJsonData): a StructureJsonData - graph_dir (str): a directory to save the graphs - train_ratio (float): train ratio - val_ratio (float): val ratio + data (StructureJsonData | StructureData): Input structures to convert to graphs. + graph_dir (str): a directory to save the graphs and labels. + train_ratio (float): train ratio. Default = 0.8 + val_ratio (float): val ratio. Default = 0.1. The test ratio is + 1 - train_ratio - val_ratio """ os.makedirs(graph_dir, exist_ok=True) random.shuffle(data.keys) diff --git a/pyproject.toml b/pyproject.toml index 2b20a1d0..3a46fdde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "nvidia-ml-py3>=7.352.0", "pymatgen>=2023.10.11", "torch>=1.11.0", + "typing-extensions>=4.12", ] classifiers = [ "Intended Audience :: Science/Research",