diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a5d41652..65f3f166 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: # Python - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.8.1 + rev: v0.8.2 hooks: - id: ruff args: ["--fix"] diff --git a/dpdata/stat.py b/dpdata/stat.py index 5ec39570..ed74c258 100644 --- a/dpdata/stat.py +++ b/dpdata/stat.py @@ -2,13 +2,14 @@ from abc import ABCMeta, abstractmethod from functools import lru_cache +from typing import Any import numpy as np from dpdata.system import LabeledSystem, MultiSystems -def mae(errors: np.ndarray) -> np.float64: +def mae(errors: np.ndarray) -> np.floating[Any]: """Compute the mean absolute error (MAE). Parameters @@ -18,13 +19,13 @@ def mae(errors: np.ndarray) -> np.float64: Returns ------- - np.float64 + floating[Any] mean absolute error (MAE) """ return np.mean(np.abs(errors)) -def rmse(errors: np.ndarray) -> np.float64: +def rmse(errors: np.ndarray) -> np.floating[Any]: """Compute the root mean squared error (RMSE). Parameters @@ -34,7 +35,7 @@ def rmse(errors: np.ndarray) -> np.float64: Returns ------- - np.float64 + floating[Any] root mean squared error (RMSE) """ return np.sqrt(np.mean(np.square(errors))) @@ -74,22 +75,22 @@ def f_errors(self) -> np.ndarray: """Force errors.""" @property - def e_mae(self) -> np.float64: + def e_mae(self) -> np.floating[Any]: """Energy MAE.""" return mae(self.e_errors) @property - def e_rmse(self) -> np.float64: + def e_rmse(self) -> np.floating[Any]: """Energy RMSE.""" return rmse(self.e_errors) @property - def f_mae(self) -> np.float64: + def f_mae(self) -> np.floating[Any]: """Force MAE.""" return mae(self.f_errors) @property - def f_rmse(self) -> np.float64: + def f_rmse(self) -> np.floating[Any]: """Force RMSE.""" return rmse(self.f_errors) diff --git a/dpdata/system.py b/dpdata/system.py index abe0a755..00172602 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1049,6 +1049,7 @@ def remove_atom_names(self, atom_names: str | list[str]): atom_idx = self.data["atom_types"] == idx removed_atom_idx.append(atom_idx) picked_atom_idx = ~np.any(removed_atom_idx, axis=0) + assert not isinstance(picked_atom_idx, np.bool_) new_sys = self.pick_atom_idx(picked_atom_idx) # let's remove atom_names # firstly, rearrange atom_names and put these atom_names in the end