Skip to content

Commit

Permalink
bump pre-commit hooks and unignore ruff PT011 PT013
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Jun 5, 2024
1 parent fad20ef commit bd589e8
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 32 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.4.7
hooks:
- id: ruff
args: [--fix]
Expand All @@ -23,7 +23,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
stages: [commit, commit-msg]
Expand All @@ -46,7 +46,7 @@ repos:
- svelte

- repo: https://github.com/pre-commit/mirrors-eslint
rev: v9.3.0
rev: v9.4.0
hooks:
- id: eslint
types: [file]
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ ignore = [
"PLR", # pylint refactor
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PT006", # pytest-parametrize-names-wrong-type
"PT011", # pytest-raises-too-broad
"PT013", # pytest-incorrect-pytest-import
"PT019", # pytest-fixture-param-without-value
"PTH", # prefer Path to os.path
"S108",
"S301", # pickle can be unsafe
Expand Down
16 changes: 10 additions & 6 deletions tests/test_converter.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING, Literal

import pytest
from pymatgen.core import Lattice, Structure
from pytest import CaptureFixture

from chgnet.graph import CrystalGraph
from chgnet.graph.converter import CrystalGraphConverter

if TYPE_CHECKING:
from collections.abc import Generator

lattice = Lattice.cubic(4)
species = ["Na", "Cl"]
coords = [[0, 0, 0], [0.5, 0.5, 0.5]]
NaCl = Structure(lattice, species, coords)


@pytest.fixture()
def _set_make_graph() -> None:
def _set_make_graph() -> Generator[None, None, None]:
# fixture to force make_graph to be None and then restore it after test
from chgnet.graph import converter

Expand Down Expand Up @@ -63,7 +65,7 @@ def test_crystal_graph_converter_warns():

@pytest.mark.parametrize("on_isolated_atoms", ["ignore", "warn", "error"])
def test_crystal_graph_converter_forward(
on_isolated_atoms, capsys: CaptureFixture[str]
on_isolated_atoms, capsys: pytest.CaptureFixture[str]
):
atom_graph_cutoff = 5
converter = CrystalGraphConverter(
Expand All @@ -75,13 +77,15 @@ def test_crystal_graph_converter_forward(
strained.apply_strain(5)
graph_id = "strained"
err_msg = (
f"Structure {graph_id=} has 2 isolated atom(s) with "
f"Structure {graph_id=} has {len(NaCl)} isolated atom(s) with "
f"{atom_graph_cutoff=}. "
f"CHGNet calculation will likely go wrong"
)

if on_isolated_atoms == "error":
with pytest.raises(ValueError) as exc_info:
with pytest.raises(
ValueError, match=f"Structure {graph_id=} has {len(NaCl)} isolated atom"
) as exc_info:
converter.forward(strained, graph_id=graph_id)
assert err_msg in str(exc_info.value)
else:
Expand Down
4 changes: 1 addition & 3 deletions tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,9 @@ def test_angle_encoder(num_angular: int, learnable: bool) -> None:

@pytest.mark.parametrize("num_angular", [-2, 8])
def test_angle_encoder_num_angular(num_angular: int) -> None:
with pytest.raises(ValueError) as exc_info:
with pytest.raises(ValueError, match=f"{num_angular=} must be an odd integer"):
AngleEncoder(num_angular=num_angular)

assert f"{num_angular=} must be an odd integer" in str(exc_info.value)


@pytest.mark.parametrize("learnable", [True, False])
def test_bond_encoder_learnable(learnable: bool) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.core import Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pytest import MonkeyPatch, approx
from pytest import MonkeyPatch, approx # noqa: PT013

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter
Expand Down
19 changes: 9 additions & 10 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pytest
from pymatgen.core import Structure
from pytest import mark

from chgnet import ROOT
from chgnet.graph import CrystalGraphConverter
Expand All @@ -14,13 +13,13 @@
model = CHGNet.load()


@mark.parametrize("atom_fea_dim", [1, 64])
@mark.parametrize("bond_fea_dim", [1, 64])
@mark.parametrize("angle_fea_dim", [1, 64])
@mark.parametrize("num_radial", [1, 9])
@mark.parametrize("num_angular", [1, 9])
@mark.parametrize("n_conv", [1, 4])
@mark.parametrize("composition_model", ["MPtrj", "MPtrj_e", "MPF"])
@pytest.mark.parametrize("atom_fea_dim", [1, 64])
@pytest.mark.parametrize("bond_fea_dim", [1, 64])
@pytest.mark.parametrize("angle_fea_dim", [1, 64])
@pytest.mark.parametrize("num_radial", [1, 9])
@pytest.mark.parametrize("num_angular", [1, 9])
@pytest.mark.parametrize("n_conv", [1, 4])
@pytest.mark.parametrize("composition_model", ["MPtrj", "MPtrj_e", "MPF"])
def test_model(
atom_fea_dim: int,
bond_fea_dim: int,
Expand Down Expand Up @@ -118,8 +117,8 @@ def test_predict_structure() -> None:
assert out["atom_fea"].shape == (8, 64)


@mark.parametrize("axis", [[0, 0, 1], [1, 1, 0], [-2, 3, 1]])
@mark.parametrize("rotation_angle", [5, 30, 45, 120])
@pytest.mark.parametrize("axis", [[0, 0, 1], [1, 1, 0], [-2, 3, 1]])
@pytest.mark.parametrize("rotation_angle", [5, 30, 45, 120])
def test_predict_structure_rotated(rotation_angle: float, axis: list) -> None:
from pymatgen.transformations.standard_transformations import RotationTransformation

Expand Down
12 changes: 6 additions & 6 deletions tests/test_relaxation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
from ase.filters import ExpCellFilter, Filter, FrechetCellFilter
from pymatgen.core import Structure
from pytest import approx, mark, param

from chgnet.graph import CrystalGraphConverter
from chgnet.model import CHGNet, StructOptimizer
Expand Down Expand Up @@ -54,19 +53,20 @@ def test_relaxation(
assert len(traj) == 2 if algorithm == "legacy" else 4

# make sure final structure is more relaxed than initial one
assert traj.energies[-1] == approx(-58.94209, rel=1e-4)
assert traj.energies[-1] == pytest.approx(-58.94209, rel=1e-4)


no_cuda = mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
# skip in macos-14 M1 CI due to OOM error (TODO investigate if
# PYTORCH_MPS_HIGH_WATERMARK_RATIO can fix)
no_mps = mark.skipif(
no_mps = pytest.mark.skipif(
not torch.backends.mps.is_available() or "CI" in os.environ, reason="No MPS device"
)


@mark.parametrize(
"use_device", ["cpu", param("cuda", marks=no_cuda), param("mps", marks=no_mps)]
@pytest.mark.parametrize(
"use_device",
["cpu", pytest.param("cuda", marks=no_cuda), pytest.param("mps", marks=no_mps)],
)
def test_structure_optimizer_passes_kwargs_to_model(use_device: str) -> None:
relaxer = StructOptimizer(use_device=use_device)
Expand Down

0 comments on commit bd589e8

Please sign in to comment.