From cb7b61925969cbcb986444f6e1d56e53cc3a19a6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:36:15 -0400 Subject: [PATCH 01/12] chore: improve type annotations Signed-off-by: Jinzhe Zeng --- dpdata/ase_calculator.py | 11 +- dpdata/bond_order_system.py | 7 +- dpdata/data_type.py | 8 +- dpdata/driver.py | 9 +- dpdata/format.py | 3 +- dpdata/stat.py | 16 ++- dpdata/system.py | 205 +++++++++++++++++++++--------------- dpdata/utils.py | 13 ++- pyproject.toml | 6 ++ 9 files changed, 174 insertions(+), 104 deletions(-) diff --git a/dpdata/ase_calculator.py b/dpdata/ase_calculator.py index c05799789..3b5868c08 100644 --- a/dpdata/ase_calculator.py +++ b/dpdata/ase_calculator.py @@ -23,7 +23,10 @@ class DPDataCalculator(Calculator): dpdata driver """ - name = "dpdata" + @property + def name(self) -> str: + return "dpdata" + implemented_properties = ["energy", "free_energy", "forces", "virial", "stress"] def __init__(self, driver: Driver, **kwargs) -> None: @@ -48,10 +51,10 @@ def calculate( system_changes : List[str], optional unused, only for function signature compatibility, by default all_changes """ - if atoms is not None: - self.atoms = atoms.copy() + assert atoms is not None + atoms = atoms.copy() - system = dpdata.System(self.atoms, fmt="ase/structure") + system = dpdata.System(atoms, fmt="ase/structure") data = system.predict(driver=self.driver).data self.results["energy"] = data["energies"][0] diff --git a/dpdata/bond_order_system.py b/dpdata/bond_order_system.py index 1b6f903d4..cd8dca059 100644 --- a/dpdata/bond_order_system.py +++ b/dpdata/bond_order_system.py @@ -96,13 +96,14 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): mol = fmtobj.from_bond_order_system(file_name, **kwargs) self.from_rdkit_mol(mol) if hasattr(fmtobj.from_bond_order_system, "post_func"): - for post_f in fmtobj.from_bond_order_system.post_func: + for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self def to_fmt_obj(self, fmtobj, *args, **kwargs): from rdkit.Chem import Conformer + assert self.rdkit_mol is not None self.rdkit_mol.RemoveAllConformers() for ii in range(self.get_nframes()): conf = Conformer() @@ -145,9 +146,9 @@ def get_formal_charges(self): """Return the formal charges on each atom.""" return self.data["formal_charges"] - def copy(self): + def copy(self): # type: ignore new_mol = deepcopy(self.rdkit_mol) - self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol) + return self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol) def __add__(self, other): raise NotImplementedError( diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 64d4c5b1e..fdee615e2 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -1,7 +1,9 @@ from enum import Enum, unique -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING, Optional, Tuple, Union, Union import numpy as np +from dpdata.bond_order_system import BondOrderSystem +from dpdata.bond_order_system import BondOrderSystem from dpdata.plugin import Plugin @@ -50,7 +52,7 @@ def __init__( self, name: str, dtype: type, - shape: Tuple[int, Axis] = None, + shape: Optional[Tuple[Union[int, Axis], ...]] = None, required: bool = True, ) -> None: self.name = name @@ -60,6 +62,7 @@ def __init__( def real_shape(self, system: "System") -> Tuple[int]: """Returns expected real shape of a system.""" + assert self.shape is not None shape = [] for ii in self.shape: if ii is Axis.NFRAMES: @@ -70,6 +73,7 @@ def real_shape(self, system: "System") -> Tuple[int]: shape.append(system.get_natoms()) elif ii is Axis.NBONDS: # BondOrderSystem + assert isinstance(system, BondOrderSystem) shape.append(system.get_nbonds()) elif ii == -1: shape.append(AnyInt(-1)) diff --git a/dpdata/driver.py b/dpdata/driver.py index 81d9a9ede..56ad94384 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -1,12 +1,12 @@ """Driver plugin system.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable, List, Union +from typing import TYPE_CHECKING, Callable, List, Type, Union from .plugin import Plugin if TYPE_CHECKING: - import ase + import ase.calculators.calculator class Driver(ABC): @@ -43,7 +43,7 @@ def register(key: str) -> Callable: return Driver.__DriverPlugin.register(key) @staticmethod - def get_driver(key: str) -> "Driver": + def get_driver(key: str) -> Type["Driver"]: """Get a driver plugin. Parameters @@ -157,6 +157,7 @@ def label(self, data: dict) -> dict: dict labeled data with energies and forces """ + labeled_data = {} for ii, driver in enumerate(self.drivers): lb_data = driver.label(data.copy()) if ii == 0: @@ -199,7 +200,7 @@ def register(key: str) -> Callable: return Minimizer.__MinimizerPlugin.register(key) @staticmethod - def get_minimizer(key: str) -> "Minimizer": + def get_minimizer(key: str) -> Type["Minimizer"]: """Get a minimizer plugin. Parameters diff --git a/dpdata/format.py b/dpdata/format.py index cd77561a7..777c9780f 100644 --- a/dpdata/format.py +++ b/dpdata/format.py @@ -2,6 +2,7 @@ import os from abc import ABC +from typing import Callable from .plugin import Plugin @@ -163,7 +164,7 @@ def decorator(object): if not isinstance(func_name, (list, tuple, set)): object.post_func = (func_name,) else: - object.post_func = func_name + object.post_func = tuple(func_name) return object return decorator diff --git a/dpdata/stat.py b/dpdata/stat.py index 8de649829..62b10c468 100644 --- a/dpdata/stat.py +++ b/dpdata/stat.py @@ -1,4 +1,4 @@ -from abc import ABCMeta, abstractproperty +from abc import ABCMeta, abstractmethod, abstractmethod, abstractproperty from functools import lru_cache import numpy as np @@ -61,11 +61,13 @@ def __init__(self, system_1: SYSTEM_TYPE, system_2: SYSTEM_TYPE) -> None: self.system_1 = system_1 self.system_2 = system_2 - @abstractproperty + @property + @abstractmethod def e_errors(self) -> np.ndarray: """Energy errors.""" - @abstractproperty + @property + @abstractmethod def f_errors(self) -> np.ndarray: """Force errors.""" @@ -114,12 +116,16 @@ class Errors(ErrorsBase): @lru_cache() def e_errors(self) -> np.ndarray: """Energy errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) return self.system_1["energies"] - self.system_2["energies"] @property @lru_cache() def f_errors(self) -> np.ndarray: """Force errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) return (self.system_1["forces"] - self.system_2["forces"]).ravel() @@ -147,6 +153,8 @@ class MultiErrors(ErrorsBase): @lru_cache() def e_errors(self) -> np.ndarray: """Energy errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) errors = [] for nn in self.system_1.systems.keys(): ss1 = self.system_1[nn] @@ -158,6 +166,8 @@ def e_errors(self) -> np.ndarray: @lru_cache() def f_errors(self) -> np.ndarray: """Force errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) errors = [] for nn in self.system_1.systems.keys(): ss1 = self.system_1[nn] diff --git a/dpdata/system.py b/dpdata/system.py index 33b7e7cfa..198f023df 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,10 +1,16 @@ # %% import glob import hashlib +import numbers import os import warnings from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +import sys +from typing import TYPE_CHECKING, TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Tuple, Type, Union, overload +if sys.version_info < (3, 8): + from typing_extensions import TypedDict +else: + from typing import TypedDict import numpy as np @@ -13,6 +19,7 @@ # ensure all plugins are loaded! import dpdata.plugins +import dpdata.plugins.deepmd from dpdata.amber.mask import load_param_file, pick_by_amber_mask from dpdata.data_type import Axis, DataError, DataType, get_data_types from dpdata.driver import Driver, Minimizer @@ -26,6 +33,9 @@ utf8len, ) +if TYPE_CHECKING: + import parmed + def load_format(fmt): fmt = fmt.lower() @@ -64,11 +74,11 @@ class System: Attributes ---------- - DTYPES : tuple[DataType] + DTYPES : tuple[DataType, ...] data types of this class """ - DTYPES = ( + DTYPES: Tuple[DataType, ...] = ( DataType("atom_numbs", list, (Axis.NTYPES,)), DataType("atom_names", list, (Axis.NTYPES,)), DataType("atom_types", np.ndarray, (Axis.NATOMS,)), @@ -84,13 +94,14 @@ class System: def __init__( self, - file_name=None, - fmt="auto", - type_map=None, - begin=0, - step=1, - data=None, - convergence_check=True, + # some formats do not use string as input + file_name: Any=None, + fmt: str="auto", + type_map:Optional[list[str]]=None, + begin:int=0, + step:int=1, + data: Optional[Dict[str, Any]]=None, + convergence_check: bool=True, **kwargs, ): """Constructor. @@ -211,13 +222,13 @@ def check_data(self): post_funcs = Plugin() - def from_fmt(self, file_name, fmt="auto", **kwargs): + def from_fmt(self, file_name: Any, fmt: str="auto", **kwargs: Any): fmt = fmt.lower() if fmt == "auto": fmt = os.path.basename(file_name).split(".")[-1].lower() return self.from_fmt_obj(load_format(fmt), file_name, **kwargs) - def from_fmt_obj(self, fmtobj, file_name, **kwargs): + def from_fmt_obj(self, fmtobj: Format, file_name: Any, **kwargs: Any): data = fmtobj.from_system(file_name, **kwargs) if data: if isinstance(data, (list, tuple)): @@ -227,11 +238,11 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): self.data = {**self.data, **data} self.check_data() if hasattr(fmtobj.from_system, "post_func"): - for post_f in fmtobj.from_system.post_func: + for post_f in fmtobj.from_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self - def to(self, fmt: str, *args, **kwargs) -> "System": + def to(self, fmt: str, *args: Any, **kwargs: Any) -> "System": """Dump systems to the specific format. Parameters @@ -250,7 +261,7 @@ def to(self, fmt: str, *args, **kwargs) -> "System": """ return self.to_fmt_obj(load_format(fmt), *args, **kwargs) - def to_fmt_obj(self, fmtobj, *args, **kwargs): + def to_fmt_obj(self, fmtobj: Format, *args: Any, **kwargs: Any): return fmtobj.to_system(self.data, *args, **kwargs) def __repr__(self): @@ -268,13 +279,33 @@ def __str__(self): ret += "\n" + " ".join(map(str, self.get_atom_numbs())) return ret + @overload + def __getitem__(self, key: Union[int, slice, list, np.ndarray]) -> "System": + ... + @overload + def __getitem__(self, key: Literal["atom_names", "real_atom_names"]) -> List[str]: + ... + @overload + def __getitem__(self, key: Literal["atom_numbs"]) -> List[int]: + ... + @overload + def __getitem__(self, key: Literal["nopbc"]) -> bool: + ... + @overload + def __getitem__(self, key: Literal["orig", "coords", "energies", "forces", "virials"]) -> np.ndarray: + ... + @overload + def __getitem__(self, key: str) -> Any: + # other cases, for example customized data + ... + def __getitem__(self, key): """Returns proerty stored in System by key or by idx.""" if isinstance(key, (int, slice, list, np.ndarray)): return self.sub_system(key) return self.data[key] - def __len__(self): + def __len__(self) -> int: """Returns number of frames in the system.""" return self.get_nframes() @@ -293,13 +324,13 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") return self.__class__.from_dict({"data": self_copy.data}) - def dump(self, filename, indent=4): + def dump(self, filename: str, indent: int=4): """Dump .json or .yaml file.""" from monty.serialization import dumpfn dumpfn(self.as_dict(), filename, indent=indent) - def map_atom_types(self, type_map=None) -> np.ndarray: + def map_atom_types(self, type_map: Optional[Union[Dict[str, int], List[str]]]=None) -> np.ndarray: """Map the atom types of the system. Parameters @@ -338,7 +369,7 @@ def map_atom_types(self, type_map=None) -> np.ndarray: return new_atom_types @staticmethod - def load(filename): + def load(filename: str): """Rebuild System obj. from .json or .yaml file.""" from monty.serialization import loadfn @@ -347,7 +378,7 @@ def load(filename): @classmethod def from_dict(cls, data: dict): """Construct a System instance from a data dict.""" - from monty.serialization import MontyDecoder + from monty.serialization import MontyDecoder # type: ignore decoded = { k: MontyDecoder().process_decoded(v) @@ -356,7 +387,7 @@ def from_dict(cls, data: dict): } return cls(**decoded) - def as_dict(self): + def as_dict(self) -> dict: """Returns data dict of System instance.""" d = { "@module": self.__class__.__module__, @@ -365,23 +396,23 @@ def as_dict(self): } return d - def get_atom_names(self): + def get_atom_names(self) -> List[str]: """Returns name of atoms.""" return self.data["atom_names"] - def get_atom_types(self): + def get_atom_types(self) -> np.ndarray: """Returns type of atoms.""" return self.data["atom_types"] - def get_atom_numbs(self): + def get_atom_numbs(self) -> List[int]: """Returns number of atoms.""" return self.data["atom_numbs"] - def get_nframes(self): + def get_nframes(self) -> int: """Returns number of frames in the system.""" return len(self.data["cells"]) - def get_natoms(self): + def get_natoms(self) -> int: """Returns total number of atoms in the system.""" return len(self.data["atom_types"]) @@ -393,7 +424,7 @@ def copy(self): """Returns a copy of the system.""" return self.__class__.from_dict({"data": deepcopy(self.data)}) - def sub_system(self, f_idx): + def sub_system(self, f_idx: numbers.Integral) -> "System": """Construct a subsystem from the system. Parameters @@ -408,7 +439,7 @@ def sub_system(self, f_idx): """ tmp = self.__class__() # convert int to array_like - if isinstance(f_idx, (int, np.int64)): + if isinstance(f_idx, numbers.Integral): f_idx = np.array([f_idx]) for tt in self.DTYPES: if tt.name not in self.data: @@ -416,7 +447,7 @@ def sub_system(self, f_idx): continue if tt.shape is not None and Axis.NFRAMES in tt.shape: axis_nframes = tt.shape.index(Axis.NFRAMES) - new_shape = [slice(None) for _ in self.data[tt.name].shape] + new_shape: List[Union[slice, np.ndarray]] = [slice(None) for _ in self.data[tt.name].shape] new_shape[axis_nframes] = f_idx tmp.data[tt.name] = self.data[tt.name][tuple(new_shape)] else: @@ -424,7 +455,7 @@ def sub_system(self, f_idx): tmp.data[tt.name] = self.data[tt.name] return tmp - def append(self, system): + def append(self, system: "System") -> bool: """Append a system to this system. Parameters @@ -480,7 +511,7 @@ def append(self, system): self.data["nopbc"] = False return True - def convert_to_mixed_type(self, type_map=None): + def convert_to_mixed_type(self, type_map:Optional[List[str]]=None): """Convert the data dict to mixed type format structure, in order to append systems with different formula but the same number of atoms. Change the 'atom_names' to one placeholder type 'MIXED_TOKEN' and add 'real_atom_types' to store the real type @@ -506,7 +537,7 @@ def convert_to_mixed_type(self, type_map=None): self.data["atom_numbs"] = [natoms] self.data["atom_names"] = ["MIXED_TOKEN"] - def sort_atom_names(self, type_map=None): + def sort_atom_names(self, type_map:Optional[List[str]]=None): """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding to atom_names. If type_map is not given, atom_names will be sorted by alphabetical order. If type_map is given, atom_names will be type_map. @@ -518,7 +549,7 @@ def sort_atom_names(self, type_map=None): """ self.data = sort_atom_names(self.data, type_map=type_map) - def check_type_map(self, type_map): + def check_type_map(self, type_map:Optional[List[str]]): """Assign atom_names to type_map if type_map is given and different from atom_names. @@ -530,7 +561,7 @@ def check_type_map(self, type_map): if type_map is not None and type_map != self.data["atom_names"]: self.sort_atom_names(type_map=type_map) - def apply_type_map(self, type_map): + def apply_type_map(self, type_map: List[str]): """Customize the element symbol order and it should maintain order consistency in dpgen or deepmd-kit. It is especially recommended for multiple complexsystems with multiple elements. @@ -560,13 +591,13 @@ def sort_atom_types(self) -> np.ndarray: continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape = [slice(None) for _ in self.data[tt.name].shape] + new_shape: List[Union[slice, np.ndarray]] = [slice(None) for _ in self.data[tt.name].shape] new_shape[axis_natoms] = idx self.data[tt.name] = self.data[tt.name][tuple(new_shape)] return idx @property - def formula(self): + def formula(self) -> str: """Return the formula of this system, like C3H5O2.""" return "".join( [ @@ -578,7 +609,7 @@ def formula(self): ) @property - def uniq_formula(self): + def uniq_formula(self) -> str: """Return the uniq_formula of this system. The uniq_formula sort the elements in formula by names. Systems with the same uniq_formula can be append together. @@ -628,7 +659,7 @@ def short_name(self) -> str: return short_formula return self.formula_hash - def extend(self, systems): + def extend(self, systems: Iterable["System"]): """Extend a system list to this system. Parameters @@ -646,7 +677,7 @@ def apply_pbc(self): self.data["coords"] = np.matmul(ncoord, self.data["cells"]) @post_funcs.register("remove_pbc") - def remove_pbc(self, protect_layer=9): + def remove_pbc(self, protect_layer: int=9): """This method does NOT delete the definition of the cells, it (1) revises the cell to a cubic cell and ensures that the cell boundary to any atom in the system is no less than `protect_layer` @@ -661,7 +692,7 @@ def remove_pbc(self, protect_layer=9): assert protect_layer >= 0, "the protect_layer should be no less than 0" remove_pbc(self.data, protect_layer) - def affine_map(self, trans, f_idx=0): + def affine_map(self, trans, f_idx: numbers.Integral=0): assert np.linalg.det(trans) != 0 self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans) self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans) @@ -679,7 +710,7 @@ def rot_lower_triangular(self): for ii in range(self.get_nframes()): self.rot_frame_lower_triangular(ii) - def rot_frame_lower_triangular(self, f_idx=0): + def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0): qq, rr = np.linalg.qr(self.data["cells"][f_idx].T) if np.linalg.det(qq) < 0: qq = -qq @@ -696,11 +727,11 @@ def rot_frame_lower_triangular(self, f_idx=0): self.affine_map(rot, f_idx=f_idx) return np.matmul(qq, rot) - def add_atom_names(self, atom_names): + def add_atom_names(self, atom_names: List[str]): """Add atom_names that do not exist.""" self.data = add_atom_names(self.data, atom_names) - def replicate(self, ncopy): + def replicate(self, ncopy: Union[List[int], Tuple[int, int, int]]): """Replicate the each frame in the system in 3 dimensions. Each frame in the system will become a supercell. @@ -752,7 +783,7 @@ def replicate(self, ncopy): ) return tmp - def replace(self, initial_atom_type, end_atom_type, replace_num): + def replace(self, initial_atom_type: str, end_atom_type: str, replace_num: int): if type(self) is not dpdata.System: raise RuntimeError( "Must use method replace() of the instance of class dpdata.System" @@ -797,7 +828,7 @@ def replace(self, initial_atom_type, end_atom_type, replace_num): self.sort_atom_types() def perturb( - self, pert_num, cell_pert_fraction, atom_pert_distance, atom_pert_style="normal" + self, pert_num: int, cell_pert_fraction: float, atom_pert_distance: float, atom_pert_style: str="normal" ): """Perturb each frame in the system randomly. The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction. @@ -865,7 +896,7 @@ def nopbc(self): return False @nopbc.setter - def nopbc(self, value): + def nopbc(self, value: bool): self.data["nopbc"] = value def shuffle(self): @@ -874,7 +905,7 @@ def shuffle(self): self.data = self.sub_system(idx).data return idx - def predict(self, *args: Any, driver: str = "dp", **kwargs: Any) -> "LabeledSystem": + def predict(self, *args: Any, driver: Union[str, Driver] = "dp", **kwargs: Any) -> "LabeledSystem": """Predict energies and forces by a driver. Parameters @@ -926,7 +957,7 @@ def minimize( data = minimizer.minimize(self.data.copy()) return LabeledSystem(data=data) - def pick_atom_idx(self, idx, nopbc=None): + def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): """Pick atom index. Parameters @@ -942,7 +973,7 @@ def pick_atom_idx(self, idx, nopbc=None): new system """ new_sys = self.copy() - if isinstance(idx, (int, np.int64)): + if isinstance(idx, numbers.Integral): idx = np.array([idx]) for tt in self.DTYPES: if tt.name not in self.data: @@ -950,7 +981,7 @@ def pick_atom_idx(self, idx, nopbc=None): continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape = [slice(None) for _ in self.data[tt.name].shape] + new_shape: List[Union[slice, np.ndarray]] = [slice(None) for _ in self.data[tt.name].shape] new_shape[axis_natoms] = idx new_sys.data[tt.name] = self.data[tt.name][tuple(new_shape)] # recalculate atom_numbs according to atom_types @@ -962,7 +993,7 @@ def pick_atom_idx(self, idx, nopbc=None): new_sys.nopbc = nopbc return new_sys - def remove_atom_names(self, atom_names): + def remove_atom_names(self, atom_names: Union[str, Iterable[str]]): """Remove atom names and all such atoms. For example, you may not remove EP atoms in TIP4P/Ew water, which is not a real atom. @@ -988,7 +1019,7 @@ def remove_atom_names(self, atom_names): new_sys.data["atom_numbs"] = new_sys.data["atom_numbs"][: len(new_atom_names)] return new_sys - def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): + def pick_by_amber_mask(self, param: Union[str, "parmed.Structure"], maskstr: str, pass_coords: bool=False, nopbc: Optional[bool]=None): """Pick atoms by amber mask. Parameters @@ -1018,7 +1049,7 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): return self.pick_atom_idx(idx, nopbc=nopbc) @classmethod - def register_data_type(cls, *data_type: Tuple[DataType]): + def register_data_type(cls, *data_type: DataType): """Register data type. Parameters @@ -1038,7 +1069,7 @@ def register_data_type(cls, *data_type: Tuple[DataType]): cls.DTYPES = tuple(dtypes_dict.values()) -def get_cell_perturb_matrix(cell_pert_fraction): +def get_cell_perturb_matrix(cell_pert_fraction: float): if cell_pert_fraction < 0: raise RuntimeError("cell_pert_fraction can not be negative") e0 = np.random.rand(6) @@ -1053,7 +1084,7 @@ def get_cell_perturb_matrix(cell_pert_fraction): return cell_pert_matrix -def get_atom_perturb_vector(atom_pert_distance, atom_pert_style="normal"): +def get_atom_perturb_vector(atom_pert_distance: float, atom_pert_style: Literal["normal", "uniform", "const"]="normal"): random_vector = None if atom_pert_distance < 0: raise RuntimeError("atom_pert_distance can not be negative") @@ -1123,7 +1154,7 @@ class LabeledSystem(System): The number of skipped frames when loading MD trajectory. """ - DTYPES = System.DTYPES + ( + DTYPES: Tuple[DataType, ...] = System.DTYPES + ( DataType("energies", np.ndarray, (Axis.NFRAMES,)), DataType("forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)), DataType("virials", np.ndarray, (Axis.NFRAMES, 3, 3), required=False), @@ -1142,7 +1173,7 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): self.data = {**self.data, **data} self.check_data() if hasattr(fmtobj.from_labeled_system, "post_func"): - for post_f in fmtobj.from_labeled_system.post_func: + for post_f in fmtobj.from_labeled_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self @@ -1178,11 +1209,11 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") return self.__class__.from_dict({"data": self_copy.data}) - def has_virial(self): + def has_virial(self) -> bool: # return ('virials' in self.data) and (len(self.data['virials']) > 0) return "virials" in self.data - def affine_map_fv(self, trans, f_idx): + def affine_map_fv(self, trans, f_idx: numbers.Integral): assert np.linalg.det(trans) != 0 self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans) if self.has_virial(): @@ -1190,12 +1221,12 @@ def affine_map_fv(self, trans, f_idx): trans.T, np.matmul(self.data["virials"][f_idx], trans) ) - def rot_frame_lower_triangular(self, f_idx=0): + def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0): trans = System.rot_frame_lower_triangular(self, f_idx=f_idx) self.affine_map_fv(trans, f_idx=f_idx) return trans - def correction(self, hl_sys): + def correction(self, hl_sys: "LabeledSystem") -> "LabeledSystem": """Get energy and force correction between self and a high-level LabeledSystem. The self's coordinates will be kept, but energy and forces will be replaced by the correction between these two systems. @@ -1275,14 +1306,14 @@ def __init__(self, *systems, type_map=None): type_map : list of str Maps atom type to name """ - self.systems = {} + self.systems: Dict[str, System] = {} if type_map is not None: - self.atom_names = type_map + self.atom_names: List[str] = type_map else: - self.atom_names = [] + self.atom_names: List[str] = [] self.append(*systems) - def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs): + def from_fmt_obj(self, fmtobj: Format, directory, labeled:bool=True, **kwargs: Any): if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat): for dd in fmtobj.from_multi_systems(directory, **kwargs): if labeled: @@ -1306,7 +1337,7 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs): self.append(*system_list) return self - def to_fmt_obj(self, fmtobj, directory, *args, **kwargs): + def to_fmt_obj(self, fmtobj: Format, directory, *args: Any, **kwargs: Any): if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat): for fn, ss in zip( fmtobj.to_multi_systems( @@ -1325,7 +1356,7 @@ def to_fmt_obj(self, fmtobj, directory, *args, **kwargs): ) return self - def to(self, fmt: str, *args, **kwargs) -> "MultiSystems": + def to(self, fmt: str, *args: Any, **kwargs: Any) -> "MultiSystems": """Dump systems to the specific format. Parameters @@ -1369,13 +1400,13 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") @classmethod - def from_file(cls, file_name, fmt, **kwargs): + def from_file(cls, file_name, fmt: str, **kwargs: Any): multi_systems = cls() multi_systems.load_systems_from_file(file_name=file_name, fmt=fmt, **kwargs) return multi_systems @classmethod - def from_dir(cls, dir_name, file_name, fmt="auto", type_map=None): + def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: Optional[List[str]]=None): multi_systems = cls() target_file_list = sorted( glob.glob(f"./{dir_name}/**/{file_name}", recursive=True) @@ -1386,15 +1417,16 @@ def from_dir(cls, dir_name, file_name, fmt="auto", type_map=None): ) return multi_systems - def load_systems_from_file(self, file_name=None, fmt=None, **kwargs): + def load_systems_from_file(self, file_name=None, fmt: Optional[str]=None, **kwargs): + assert fmt is not None fmt = fmt.lower() return self.from_fmt_obj(load_format(fmt), file_name, **kwargs) - def get_nframes(self): + def get_nframes(self) -> int: """Returns number of frames in all systems.""" return sum(len(system) for system in self.systems.values()) - def append(self, *systems): + def append(self, *systems: Union[System, "MultiSystems"]): """Append systems or MultiSystems to systems. Parameters @@ -1411,7 +1443,7 @@ def append(self, *systems): else: raise RuntimeError("Object must be System or MultiSystems!") - def __append(self, system): + def __append(self, system: System): if not system.formula: return # prevent changing the original system @@ -1423,7 +1455,7 @@ def __append(self, system): else: self.systems[formula] = system.copy() - def check_atom_names(self, system): + def check_atom_names(self, system: System): """Make atom_names in all systems equal, prevent inconsistent atom_types.""" # new_in_system = set(system["atom_names"]) - set(self.atom_names) # new_in_self = set(self.atom_names) - set(system["atom_names"]) @@ -1444,7 +1476,7 @@ def check_atom_names(self, system): system.add_atom_names(new_in_self) system.sort_atom_names(type_map=self.atom_names) - def predict(self, *args: Any, driver="dp", **kwargs: Any) -> "MultiSystems": + def predict(self, *args: Any, driver: Union[str, Driver]="dp", **kwargs: Any) -> "MultiSystems": """Predict energies and forces by a driver. Parameters @@ -1503,7 +1535,7 @@ def minimize( new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) return new_multisystems - def pick_atom_idx(self, idx, nopbc=None): + def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): """Pick atom index. Parameters @@ -1523,7 +1555,7 @@ def pick_atom_idx(self, idx, nopbc=None): new_sys.append(ss.pick_atom_idx(idx, nopbc=nopbc)) return new_sys - def correction(self, hl_sys: "MultiSystems"): + def correction(self, hl_sys: "MultiSystems") -> "MultiSystems": """Get energy and force correction between self (assumed low-level) and a high-level MultiSystems. The self's coordinates will be kept, but energy and forces will be replaced by the correction between these two systems. @@ -1558,6 +1590,7 @@ def correction(self, hl_sys: "MultiSystems"): for nn in self.systems.keys(): ll_ss = self[nn] hl_ss = hl_sys[nn] + assert isinstance(ll_ss, LabeledSystem) corrected_sys.append(ll_ss.correction(hl_ss)) return corrected_sys @@ -1619,7 +1652,7 @@ def train_test_split( return train_systems, test_systems, test_system_idx -def get_cls_name(cls: object) -> str: +def get_cls_name(cls: Type[Any]) -> str: """Returns the fully qualified name of a class, such as `np.ndarray`. Parameters @@ -1654,7 +1687,7 @@ def add_format_methods(): for method, formatcls in Format.get_from_methods().items(): - def get_func(ff): + def get_func_from(ff): # ff is not initized when defining from_format so cannot be polluted def from_format(self, file_name, **kwargs): return self.from_fmt_obj(ff(), file_name, **kwargs) @@ -1662,22 +1695,22 @@ def from_format(self, file_name, **kwargs): from_format.__doc__ = f"Read data from :class:`{get_cls_name(ff)}` format." return from_format - setattr(System, method, get_func(formatcls)) - setattr(LabeledSystem, method, get_func(formatcls)) - setattr(MultiSystems, method, get_func(formatcls)) + setattr(System, method, get_func_from(formatcls)) + setattr(LabeledSystem, method, get_func_from(formatcls)) + setattr(MultiSystems, method, get_func_from(formatcls)) for method, formatcls in Format.get_to_methods().items(): - def get_func(ff): + def get_func_to(ff): def to_format(self, *args, **kwargs): return self.to_fmt_obj(ff(), *args, **kwargs) to_format.__doc__ = f"Dump data to :class:`{get_cls_name(ff)}` format." return to_format - setattr(System, method, get_func(formatcls)) - setattr(LabeledSystem, method, get_func(formatcls)) - setattr(MultiSystems, method, get_func(formatcls)) + setattr(System, method, get_func_to(formatcls)) + setattr(LabeledSystem, method, get_func_to(formatcls)) + setattr(MultiSystems, method, get_func_to(formatcls)) # at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized System.register_data_type(*get_data_types(labeled=False)) diff --git a/dpdata/utils.py b/dpdata/utils.py index cf4a109ee..a6b76c8cc 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -1,9 +1,20 @@ +from typing import Dict, List, Literal, Union, overload import numpy as np from dpdata.periodic_table import Element +@overload +def elements_index_map(elements: List[str], standard: bool, inverse: Literal[True]) -> Dict[int, str]: + ... +@overload +def elements_index_map(elements: List[str], standard: bool, inverse: Literal[False]=...) -> Dict[str, int]: + ... +@overload +def elements_index_map(elements: List[str], standard: bool, inverse: bool=False) -> Union[Dict[str, int], Dict[int, str]]: + ... -def elements_index_map(elements, standard=False, inverse=False): + +def elements_index_map(elements: List[str], standard: bool=False, inverse: bool=False) -> dict: if standard: elements.sort(key=lambda x: Element(x).Z) if inverse: diff --git a/pyproject.toml b/pyproject.toml index 1be79442a..ebb29663d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ 'h5py', 'wcmatch', 'importlib_metadata>=1.4; python_version < "3.8"', + 'type_extensions>=0.4.0; python_version < "3.8"', ] requires-python = ">=3.7" readme = "README.md" @@ -122,3 +123,8 @@ banned-module-level-imports = [ "monty", "scipy", ] + +[tool.pyright] +include = [ + "dpdata/*.py", +] From fbf964a6d724a47464a2e41bf52003b2eea7f3c0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:39:12 -0400 Subject: [PATCH 02/12] add pyright Signed-off-by: Jinzhe Zeng --- .github/workflows/pyright.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/pyright.yml diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 000000000..9b4732b94 --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,19 @@ +on: + - push + - pull_request + +name: Type checker +jobs: + pyright: + name: pyright + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@master + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pip install uv + - run: uv pip install --system -e . + - uses: jakebailey/pyright-action@v2 + with: + version: 1.1.363 From 665c4feda1960180e179717c02b7f9250fafa539 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:40:35 -0400 Subject: [PATCH 03/12] enable TCH Signed-off-by: Jinzhe Zeng --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ebb29663d..4f35bbab7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,6 +84,7 @@ select = [ "UP", # pyupgrade "I", # isort "TID253", # banned-module-level-imports + "TCH", # flake8-type-checking ] ignore = [ "E501", # line too long @@ -124,6 +125,9 @@ banned-module-level-imports = [ "scipy", ] +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] + [tool.pyright] include = [ "dpdata/*.py", From 2021a01b59df54b5f0380d39c24dafed7b06408e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:42:36 -0400 Subject: [PATCH 04/12] run pre-commit Signed-off-by: Jinzhe Zeng --- benchmark/test_import.py | 2 + docs/conf.py | 2 + docs/make_format.py | 2 + dpdata/__about__.py | 2 + dpdata/__init__.py | 2 + dpdata/__main__.py | 2 + dpdata/abacus/md.py | 2 + dpdata/abacus/relax.py | 2 + dpdata/abacus/scf.py | 2 + dpdata/amber/mask.py | 1 + dpdata/amber/md.py | 2 + dpdata/amber/sqm.py | 2 + dpdata/ase_calculator.py | 10 +- dpdata/bond_order_system.py | 2 + dpdata/cli.py | 8 +- dpdata/cp2k/cell.py | 1 + dpdata/cp2k/output.py | 2 + dpdata/data_type.py | 13 ++- dpdata/deepmd/comp.py | 2 + dpdata/deepmd/mixed.py | 2 + dpdata/deepmd/raw.py | 2 + dpdata/dftbplus/output.py | 4 +- dpdata/driver.py | 11 +- dpdata/fhi_aims/output.py | 2 + dpdata/format.py | 2 +- dpdata/gaussian/gjf.py | 12 +- dpdata/gaussian/log.py | 2 + dpdata/gromacs/gro.py | 2 + dpdata/lammps/dump.py | 1 + dpdata/lammps/lmp.py | 1 + dpdata/md/msd.py | 2 + dpdata/md/pbc.py | 2 + dpdata/md/rdf.py | 2 + dpdata/md/water.py | 2 + dpdata/openmx/omx.py | 2 + dpdata/orca/output.py | 4 +- dpdata/periodic_table.py | 2 + dpdata/plugin.py | 1 + dpdata/plugins/3dmol.py | 4 +- dpdata/plugins/__init__.py | 2 + dpdata/plugins/abacus.py | 2 + dpdata/plugins/amber.py | 2 + dpdata/plugins/ase.py | 36 +++--- dpdata/plugins/cp2k.py | 2 + dpdata/plugins/dftbplus.py | 2 + dpdata/plugins/fhi_aims.py | 2 + dpdata/plugins/gaussian.py | 2 + dpdata/plugins/gromacs.py | 2 + dpdata/plugins/lammps.py | 2 + dpdata/plugins/list.py | 2 + dpdata/plugins/n2p2.py | 2 + dpdata/plugins/openmx.py | 2 + dpdata/plugins/orca.py | 2 + dpdata/plugins/psi4.py | 2 + dpdata/plugins/pwmat.py | 2 + dpdata/plugins/pymatgen.py | 2 + dpdata/plugins/qe.py | 2 + dpdata/plugins/rdkit.py | 2 + dpdata/plugins/siesta.py | 2 + dpdata/plugins/vasp.py | 2 + dpdata/plugins/xyz.py | 2 + dpdata/psi4/input.py | 7 +- dpdata/psi4/output.py | 4 +- dpdata/pwmat/atomconfig.py | 2 + dpdata/pwmat/movement.py | 2 + dpdata/pymatgen/molecule.py | 2 + dpdata/pymatgen/structure.py | 2 + dpdata/qe/scf.py | 1 + dpdata/qe/traj.py | 2 + dpdata/rdkit/sanitize.py | 2 + dpdata/rdkit/utils.py | 2 + dpdata/siesta/aiMD_output.py | 1 + dpdata/siesta/output.py | 1 + dpdata/stat.py | 4 +- dpdata/system.py | 109 ++++++++++-------- dpdata/unit.py | 2 + dpdata/utils.py | 14 ++- dpdata/vasp/outcar.py | 2 + dpdata/vasp/poscar.py | 1 + dpdata/vasp/xml.py | 1 + dpdata/xyz/quip_gap_xyz.py | 2 + dpdata/xyz/xyz.py | 4 +- plugin_example/dpdata_random/__init__.py | 2 + tests/comp_sys.py | 2 + tests/context.py | 2 + tests/plugin/dpdata_plugin_test/__init__.py | 2 + tests/poscars/poscar_ref_oh.py | 2 + tests/poscars/test_lammps_dump_s_su.py | 2 + tests/pwmat/config_ref_ch4.py | 2 + tests/pwmat/config_ref_oh.py | 2 + tests/test_abacus_md.py | 2 + tests/test_abacus_pw_scf.py | 2 + tests/test_abacus_relax.py | 2 + tests/test_abacus_stru_dump.py | 2 + tests/test_amber_md.py | 2 + tests/test_amber_sqm.py | 2 + tests/test_ase_traj.py | 2 + tests/test_bond_order_system.py | 2 + tests/test_cell_to_low_triangle.py | 2 + tests/test_cli.py | 2 + tests/test_corr.py | 2 + tests/test_cp2k_aimd_output.py | 2 + tests/test_cp2k_output.py | 2 + tests/test_custom_data_type.py | 2 + tests/test_deepmd_comp.py | 2 + tests/test_deepmd_hdf5.py | 2 + tests/test_deepmd_mixed.py | 2 + tests/test_deepmd_raw.py | 2 + tests/test_dftbplus.py | 2 + tests/test_elements_index.py | 2 + tests/test_empty.py | 2 + tests/test_fhi_md_multi_elem_output.py | 2 + tests/test_fhi_md_output.py | 2 + tests/test_fhi_output.py | 2 + tests/test_from_pymatgen.py | 2 + tests/test_gaussian_driver.py | 2 + tests/test_gaussian_gjf.py | 2 + tests/test_gaussian_log.py | 2 + tests/test_gromacs_gro.py | 2 + tests/test_json.py | 2 + tests/test_lammps_dump_idx.py | 1 + tests/test_lammps_dump_shift_origin.py | 2 + tests/test_lammps_dump_skipload.py | 2 + tests/test_lammps_dump_to_system.py | 2 + tests/test_lammps_dump_unfold.py | 2 + tests/test_lammps_lmp_dump.py | 2 + tests/test_lammps_lmp_to_system.py | 2 + tests/test_lammps_read_from_trajs.py | 2 + tests/test_msd.py | 2 + tests/test_multisystems.py | 2 + tests/test_n2p2.py | 2 + tests/test_openmx.py | 2 + tests/test_openmx_check_convergence.py | 2 + tests/test_orca_spout.py | 2 + tests/test_periodic_table.py | 2 + tests/test_perturb.py | 2 + tests/test_pick_atom_idx.py | 2 + tests/test_predict.py | 2 + tests/test_psi4.py | 2 + tests/test_pwmat_config_dump.py | 2 + tests/test_pwmat_config_to_system.py | 2 + tests/test_pwmat_mlmd.py | 2 + tests/test_pwmat_movement.py | 2 + tests/test_pymatgen_molecule.py | 2 + tests/test_qe_cp_traj.py | 2 + tests/test_qe_cp_traj_skipload.py | 2 + tests/test_qe_pw_scf.py | 2 + ...test_qe_pw_scf_crystal_atomic_positions.py | 2 + tests/test_qe_pw_scf_energy_bug.py | 2 + tests/test_quip_gap_xyz.py | 2 + tests/test_remove_atom_names.py | 2 + tests/test_remove_outlier.py | 2 + tests/test_remove_pbc.py | 2 + tests/test_replace.py | 2 + tests/test_replicate.py | 2 + tests/test_shuffle.py | 2 + tests/test_siesta_aiMD_output.py | 2 + tests/test_siesta_output.py | 2 + tests/test_split_dataset.py | 2 + tests/test_sqm_driver.py | 2 + tests/test_stat.py | 2 + tests/test_system_append.py | 2 + tests/test_system_apply_pbc.py | 2 + tests/test_system_set_type.py | 2 + tests/test_to_ase.py | 2 + tests/test_to_list.py | 2 + tests/test_to_pymatgen.py | 2 + tests/test_to_pymatgen_entry.py | 2 + tests/test_type_map.py | 2 + tests/test_vasp_outcar.py | 2 + tests/test_vasp_poscar_dump.py | 2 + tests/test_vasp_poscar_to_system.py | 2 + tests/test_vasp_unconverged_outcar.py | 2 + tests/test_vasp_xml.py | 2 + tests/test_water_ions.py | 2 + tests/test_xyz.py | 2 + 176 files changed, 445 insertions(+), 110 deletions(-) diff --git a/benchmark/test_import.py b/benchmark/test_import.py index 04d461375..846d72b25 100644 --- a/benchmark/test_import.py +++ b/benchmark/test_import.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess import sys diff --git a/docs/conf.py b/docs/conf.py index 3f897fc9e..e3c0b3d4c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,6 +11,8 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +from __future__ import annotations + import os import subprocess as sp import sys diff --git a/docs/make_format.py b/docs/make_format.py index 8a7878f9d..e9c1f60d3 100644 --- a/docs/make_format.py +++ b/docs/make_format.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import csv import os from collections import defaultdict diff --git a/dpdata/__about__.py b/dpdata/__about__.py index d5cfca647..3ee47d3c2 100644 --- a/dpdata/__about__.py +++ b/dpdata/__about__.py @@ -1 +1,3 @@ +from __future__ import annotations + __version__ = "unknown" diff --git a/dpdata/__init__.py b/dpdata/__init__.py index 847554d38..f2cd233ff 100644 --- a/dpdata/__init__.py +++ b/dpdata/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from . import lammps, md, vasp from .bond_order_system import BondOrderSystem from .system import LabeledSystem, MultiSystems, System diff --git a/dpdata/__main__.py b/dpdata/__main__.py index aad1556fa..4c60f3f26 100644 --- a/dpdata/__main__.py +++ b/dpdata/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dpdata.cli import dpdata_cli if __name__ == "__main__": diff --git a/dpdata/abacus/md.py b/dpdata/abacus/md.py index b96a0fd04..fa1841777 100644 --- a/dpdata/abacus/md.py +++ b/dpdata/abacus/md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import warnings diff --git a/dpdata/abacus/relax.py b/dpdata/abacus/relax.py index fb3c8da0d..976243b82 100644 --- a/dpdata/abacus/relax.py +++ b/dpdata/abacus/relax.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import numpy as np diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index df50b010f..193e4d4b5 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import re import warnings diff --git a/dpdata/amber/mask.py b/dpdata/amber/mask.py index e3ae1e8da..cd3cb728e 100644 --- a/dpdata/amber/mask.py +++ b/dpdata/amber/mask.py @@ -1,4 +1,5 @@ """Amber mask.""" +from __future__ import annotations try: import parmed diff --git a/dpdata/amber/md.py b/dpdata/amber/md.py index 912401213..f3217fbd9 100644 --- a/dpdata/amber/md.py +++ b/dpdata/amber/md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import re diff --git a/dpdata/amber/sqm.py b/dpdata/amber/sqm.py index 5dcbf9955..1be3802a2 100644 --- a/dpdata/amber/sqm.py +++ b/dpdata/amber/sqm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.periodic_table import ELEMENTS diff --git a/dpdata/ase_calculator.py b/dpdata/ase_calculator.py index 3b5868c08..1de760a5a 100644 --- a/dpdata/ase_calculator.py +++ b/dpdata/ase_calculator.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from ase.calculators.calculator import ( # noqa: TID253 Calculator, @@ -35,9 +37,9 @@ def __init__(self, driver: Driver, **kwargs) -> None: def calculate( self, - atoms: Optional["Atoms"] = None, - properties: List[str] = ["energy", "forces"], - system_changes: List[str] = all_changes, + atoms: Atoms | None = None, + properties: list[str] = ["energy", "forces"], + system_changes: list[str] = all_changes, ): """Run calculation with a driver. diff --git a/dpdata/bond_order_system.py b/dpdata/bond_order_system.py index cd8dca059..8d129bcde 100644 --- a/dpdata/bond_order_system.py +++ b/dpdata/bond_order_system.py @@ -1,5 +1,7 @@ # %% # Bond Order System +from __future__ import annotations + from copy import deepcopy import numpy as np diff --git a/dpdata/cli.py b/dpdata/cli.py index 2e39d17d6..386707891 100644 --- a/dpdata/cli.py +++ b/dpdata/cli.py @@ -1,7 +1,7 @@ """Command line interface for dpdata.""" +from __future__ import annotations import argparse -from typing import Optional from . import __version__ from .system import LabeledSystem, MultiSystems, System @@ -59,11 +59,11 @@ def convert( *, from_file: str, from_format: str = "auto", - to_file: Optional[str] = None, - to_format: Optional[str] = None, + to_file: str | None = None, + to_format: str | None = None, no_labeled: bool = False, multi: bool = False, - type_map: Optional[list] = None, + type_map: list | None = None, **kwargs, ): """Convert files from one format to another one. diff --git a/dpdata/cp2k/cell.py b/dpdata/cp2k/cell.py index 7af73353e..a3021b815 100644 --- a/dpdata/cp2k/cell.py +++ b/dpdata/cp2k/cell.py @@ -1,4 +1,5 @@ # %% +from __future__ import annotations import numpy as np diff --git a/dpdata/cp2k/output.py b/dpdata/cp2k/output.py index c84355c46..bd827595e 100644 --- a/dpdata/cp2k/output.py +++ b/dpdata/cp2k/output.py @@ -1,4 +1,6 @@ # %% +from __future__ import annotations + import math import re from collections import OrderedDict diff --git a/dpdata/data_type.py b/dpdata/data_type.py index fdee615e2..752193fef 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -1,10 +1,11 @@ +from __future__ import annotations + from enum import Enum, unique -from typing import TYPE_CHECKING, Optional, Tuple, Union, Union +from typing import TYPE_CHECKING import numpy as np -from dpdata.bond_order_system import BondOrderSystem -from dpdata.bond_order_system import BondOrderSystem +from dpdata.bond_order_system import BondOrderSystem from dpdata.plugin import Plugin if TYPE_CHECKING: @@ -52,7 +53,7 @@ def __init__( self, name: str, dtype: type, - shape: Optional[Tuple[Union[int, Axis], ...]] = None, + shape: tuple[int | Axis, ...] | None = None, required: bool = True, ) -> None: self.name = name @@ -60,7 +61,7 @@ def __init__( self.shape = shape self.required = required - def real_shape(self, system: "System") -> Tuple[int]: + def real_shape(self, system: System) -> tuple[int]: """Returns expected real shape of a system.""" assert self.shape is not None shape = [] @@ -83,7 +84,7 @@ def real_shape(self, system: "System") -> Tuple[int]: raise RuntimeError("Shape is not an int!") return tuple(shape) - def check(self, system: "System"): + def check(self, system: System): """Check if a system has correct data of this type. Parameters diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 7b909b162..ab0044477 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import os import shutil diff --git a/dpdata/deepmd/mixed.py b/dpdata/deepmd/mixed.py index 0d0ad89d9..b25107dbc 100644 --- a/dpdata/deepmd/mixed.py +++ b/dpdata/deepmd/mixed.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import os import shutil diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index c7a64ec47..e772714a1 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import warnings diff --git a/dpdata/dftbplus/output.py b/dpdata/dftbplus/output.py index ba8f6c840..0f10c3ac9 100644 --- a/dpdata/dftbplus/output.py +++ b/dpdata/dftbplus/output.py @@ -1,9 +1,9 @@ -from typing import Tuple +from __future__ import annotations import numpy as np -def read_dftb_plus(fn_1: str, fn_2: str) -> Tuple[str, np.ndarray, float, np.ndarray]: +def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]: """Read from DFTB+ input and output. Parameters diff --git a/dpdata/driver.py b/dpdata/driver.py index 56ad94384..9a196b2e7 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -1,7 +1,8 @@ """Driver plugin system.""" +from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable, List, Type, Union +from typing import TYPE_CHECKING, Callable from .plugin import Plugin @@ -43,7 +44,7 @@ def register(key: str) -> Callable: return Driver.__DriverPlugin.register(key) @staticmethod - def get_driver(key: str) -> Type["Driver"]: + def get_driver(key: str) -> type[Driver]: """Get a driver plugin. Parameters @@ -97,7 +98,7 @@ def label(self, data: dict) -> dict: return NotImplemented @property - def ase_calculator(self) -> "ase.calculators.calculator.Calculator": + def ase_calculator(self) -> ase.calculators.calculator.Calculator: """Returns an ase calculator based on this driver.""" from .ase_calculator import DPDataCalculator @@ -130,7 +131,7 @@ class HybridDriver(Driver): This driver is the hybrid of SQM and DP. """ - def __init__(self, drivers: List[Union[dict, Driver]]) -> None: + def __init__(self, drivers: list[dict | Driver]) -> None: self.drivers = [] for driver in drivers: if isinstance(driver, Driver): @@ -200,7 +201,7 @@ def register(key: str) -> Callable: return Minimizer.__MinimizerPlugin.register(key) @staticmethod - def get_minimizer(key: str) -> Type["Minimizer"]: + def get_minimizer(key: str) -> type[Minimizer]: """Get a minimizer plugin. Parameters diff --git a/dpdata/fhi_aims/output.py b/dpdata/fhi_aims/output.py index 9947a231a..762e8bf4d 100755 --- a/dpdata/fhi_aims/output.py +++ b/dpdata/fhi_aims/output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import warnings diff --git a/dpdata/format.py b/dpdata/format.py index 777c9780f..8277df191 100644 --- a/dpdata/format.py +++ b/dpdata/format.py @@ -1,8 +1,8 @@ """Implement the format plugin system.""" +from __future__ import annotations import os from abc import ABC -from typing import Callable from .plugin import Plugin diff --git a/dpdata/gaussian/gjf.py b/dpdata/gaussian/gjf.py index 90aaf2f04..37e2897a7 100644 --- a/dpdata/gaussian/gjf.py +++ b/dpdata/gaussian/gjf.py @@ -2,19 +2,19 @@ # https://github.com/deepmodeling/dpgen/blob/0767dce7cad29367edb2e4a55fd0d8724dbda642/dpgen/generator/lib/gaussian.py#L1-L190 # under LGPL 3.0 license """Generate Gaussian input file.""" +from __future__ import annotations import itertools import re import uuid import warnings -from typing import List, Optional, Tuple, Union import numpy as np from dpdata.periodic_table import Element -def _crd2frag(symbols: List[str], crds: np.ndarray) -> Tuple[int, List[int]]: +def _crd2frag(symbols: list[str], crds: np.ndarray) -> tuple[int, list[int]]: """Detect fragments from coordinates. Parameters @@ -102,12 +102,12 @@ def detect_multiplicity(symbols: np.ndarray) -> int: def make_gaussian_input( sys_data: dict, - keywords: Union[str, List[str]], - multiplicity: Union[str, int] = "auto", + keywords: str | list[str], + multiplicity: str | int = "auto", charge: int = 0, fragment_guesses: bool = False, - basis_set: Optional[str] = None, - keywords_high_multiplicity: Optional[str] = None, + basis_set: str | None = None, + keywords_high_multiplicity: str | None = None, nproc: int = 1, ) -> str: """Make gaussian input file. diff --git a/dpdata/gaussian/log.py b/dpdata/gaussian/log.py index 66881dc1a..204cf464c 100644 --- a/dpdata/gaussian/log.py +++ b/dpdata/gaussian/log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from ..periodic_table import ELEMENTS diff --git a/dpdata/gromacs/gro.py b/dpdata/gromacs/gro.py index b643eea86..aca2443b8 100644 --- a/dpdata/gromacs/gro.py +++ b/dpdata/gromacs/gro.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import re import numpy as np diff --git a/dpdata/lammps/dump.py b/dpdata/lammps/dump.py index 906fed9ee..f0ade2b03 100644 --- a/dpdata/lammps/dump.py +++ b/dpdata/lammps/dump.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import os import sys diff --git a/dpdata/lammps/lmp.py b/dpdata/lammps/lmp.py index 317b30ed4..604b18d12 100644 --- a/dpdata/lammps/lmp.py +++ b/dpdata/lammps/lmp.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/md/msd.py b/dpdata/md/msd.py index cfb446dde..dfad95507 100644 --- a/dpdata/md/msd.py +++ b/dpdata/md/msd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .pbc import system_pbc_shift diff --git a/dpdata/md/pbc.py b/dpdata/md/pbc.py index 4eee7c654..e57576615 100644 --- a/dpdata/md/pbc.py +++ b/dpdata/md/pbc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/md/rdf.py b/dpdata/md/rdf.py index de8f1c746..b41be525b 100644 --- a/dpdata/md/rdf.py +++ b/dpdata/md/rdf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/md/water.py b/dpdata/md/water.py index 0cb82cc99..cda4ad48b 100644 --- a/dpdata/md/water.py +++ b/dpdata/md/water.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .pbc import posi_diff, posi_shift diff --git a/dpdata/openmx/omx.py b/dpdata/openmx/omx.py index bd4b7031e..d3afff00f 100644 --- a/dpdata/openmx/omx.py +++ b/dpdata/openmx/omx.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +from __future__ import annotations + import numpy as np from ..unit import ( diff --git a/dpdata/orca/output.py b/dpdata/orca/output.py index 13f072f30..183c3c85c 100644 --- a/dpdata/orca/output.py +++ b/dpdata/orca/output.py @@ -1,9 +1,9 @@ -from typing import Tuple +from __future__ import annotations import numpy as np -def read_orca_sp_output(fn: str) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray]: +def read_orca_sp_output(fn: str) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]: """Read from ORCA output. Note that both the energy and the gradient should be printed. diff --git a/dpdata/periodic_table.py b/dpdata/periodic_table.py index 6df1fd410..e6b56cb0b 100644 --- a/dpdata/periodic_table.py +++ b/dpdata/periodic_table.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from pathlib import Path diff --git a/dpdata/plugin.py b/dpdata/plugin.py index 20e51eb2d..b725f4eb1 100644 --- a/dpdata/plugin.py +++ b/dpdata/plugin.py @@ -1,4 +1,5 @@ """Base of plugin systems.""" +from __future__ import annotations class Plugin: diff --git a/dpdata/plugins/3dmol.py b/dpdata/plugins/3dmol.py index ec994dd9b..56ec25161 100644 --- a/dpdata/plugins/3dmol.py +++ b/dpdata/plugins/3dmol.py @@ -1,4 +1,4 @@ -from typing import Tuple +from __future__ import annotations import numpy as np @@ -17,7 +17,7 @@ def to_system( self, data: dict, f_idx: int = 0, - size: Tuple[int] = (300, 300), + size: tuple[int] = (300, 300), style: dict = {"stick": {}, "sphere": {"radius": 0.4}}, **kwargs, ): diff --git a/dpdata/plugins/__init__.py b/dpdata/plugins/__init__.py index 66364aa25..15634bc0a 100644 --- a/dpdata/plugins/__init__.py +++ b/dpdata/plugins/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib from pathlib import Path diff --git a/dpdata/plugins/abacus.py b/dpdata/plugins/abacus.py index 754221be0..eb2d7786f 100644 --- a/dpdata/plugins/abacus.py +++ b/dpdata/plugins/abacus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.abacus.md import dpdata.abacus.relax import dpdata.abacus.scf diff --git a/dpdata/plugins/amber.py b/dpdata/plugins/amber.py index cdc92a30b..d991ce482 100644 --- a/dpdata/plugins/amber.py +++ b/dpdata/plugins/amber.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess as sp import tempfile diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index f3347c994..1d8184838 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Optional, Type +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np @@ -22,7 +24,7 @@ class ASEStructureFormat(Format): automatic detection fails. """ - def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict: + def from_system(self, atoms: ase.Atoms, **kwargs) -> dict: """Convert ase.Atoms to a System. Parameters @@ -56,7 +58,7 @@ def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict: } return info_dict - def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict: + def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: """Convert ase.Atoms to a LabeledSystem. Energies and forces are calculated by the calculator. @@ -103,12 +105,12 @@ def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict: def from_multi_systems( self, file_name: str, - begin: Optional[int] = None, - end: Optional[int] = None, - step: Optional[int] = None, - ase_fmt: Optional[str] = None, + begin: int | None = None, + end: int | None = None, + step: int | None = None, + ase_fmt: str | None = None, **kwargs, - ) -> "ase.Atoms": + ) -> ase.Atoms: """Convert a ASE supported file to ASE Atoms. It will finally be converted to MultiSystems. @@ -195,9 +197,9 @@ class ASETrajFormat(Format): def from_system( self, file_name: str, - begin: Optional[int] = 0, - end: Optional[int] = None, - step: Optional[int] = 1, + begin: int | None = 0, + end: int | None = None, + step: int | None = 1, **kwargs, ) -> dict: """Read ASE's trajectory file to `System` of multiple frames. @@ -239,9 +241,9 @@ def from_system( def from_labeled_system( self, file_name: str, - begin: Optional[int] = 0, - end: Optional[int] = None, - step: Optional[int] = 1, + begin: int | None = 0, + end: int | None = None, + step: int | None = 1, **kwargs, ) -> dict: """Read ASE's trajectory file to `System` of multiple frames. @@ -309,7 +311,7 @@ class ASEDriver(Driver): ASE calculator """ - def __init__(self, calculator: "ase.calculators.calculator.Calculator") -> None: + def __init__(self, calculator: ase.calculators.calculator.Calculator) -> None: """Setup the driver.""" self.calculator = calculator @@ -361,9 +363,9 @@ class ASEMinimizer(Minimizer): def __init__( self, driver: Driver, - optimizer: Optional[Type["Optimizer"]] = None, + optimizer: type[Optimizer] | None = None, fmax: float = 5e-3, - max_steps: Optional[int] = None, + max_steps: int | None = None, optimizer_kwargs: dict = {}, ) -> None: self.calculator = driver.ase_calculator diff --git a/dpdata/plugins/cp2k.py b/dpdata/plugins/cp2k.py index 162098f70..f5c1b5394 100644 --- a/dpdata/plugins/cp2k.py +++ b/dpdata/plugins/cp2k.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import dpdata.cp2k.output diff --git a/dpdata/plugins/dftbplus.py b/dpdata/plugins/dftbplus.py index 5c8b46828..247fedc9e 100644 --- a/dpdata/plugins/dftbplus.py +++ b/dpdata/plugins/dftbplus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.dftbplus.output import read_dftb_plus diff --git a/dpdata/plugins/fhi_aims.py b/dpdata/plugins/fhi_aims.py index 45b181fc0..3c198aff6 100644 --- a/dpdata/plugins/fhi_aims.py +++ b/dpdata/plugins/fhi_aims.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.fhi_aims.output from dpdata.format import Format diff --git a/dpdata/plugins/gaussian.py b/dpdata/plugins/gaussian.py index a22ce8630..80cfa4076 100644 --- a/dpdata/plugins/gaussian.py +++ b/dpdata/plugins/gaussian.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess as sp import tempfile diff --git a/dpdata/plugins/gromacs.py b/dpdata/plugins/gromacs.py index 20e508355..12dece718 100644 --- a/dpdata/plugins/gromacs.py +++ b/dpdata/plugins/gromacs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.gromacs.gro from dpdata.format import Format diff --git a/dpdata/plugins/lammps.py b/dpdata/plugins/lammps.py index be89be9d0..65e7f5701 100644 --- a/dpdata/plugins/lammps.py +++ b/dpdata/plugins/lammps.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.lammps.dump import dpdata.lammps.lmp from dpdata.format import Format diff --git a/dpdata/plugins/list.py b/dpdata/plugins/list.py index 68a140748..f70368836 100644 --- a/dpdata/plugins/list.py +++ b/dpdata/plugins/list.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dpdata.format import Format diff --git a/dpdata/plugins/n2p2.py b/dpdata/plugins/n2p2.py index 7162f09fa..b70d6e6fb 100644 --- a/dpdata/plugins/n2p2.py +++ b/dpdata/plugins/n2p2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/plugins/openmx.py b/dpdata/plugins/openmx.py index 675d1d2c1..4e16566dc 100644 --- a/dpdata/plugins/openmx.py +++ b/dpdata/plugins/openmx.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.md.pbc import dpdata.openmx.omx from dpdata.format import Format diff --git a/dpdata/plugins/orca.py b/dpdata/plugins/orca.py index 2585743e1..3d7fa38a7 100644 --- a/dpdata/plugins/orca.py +++ b/dpdata/plugins/orca.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/plugins/psi4.py b/dpdata/plugins/psi4.py index ec7d9df1b..c3b1ee1b8 100644 --- a/dpdata/plugins/psi4.py +++ b/dpdata/plugins/psi4.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/plugins/pwmat.py b/dpdata/plugins/pwmat.py index 11257c4d0..80f219b6c 100644 --- a/dpdata/plugins/pwmat.py +++ b/dpdata/plugins/pwmat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import dpdata.pwmat.atomconfig diff --git a/dpdata/plugins/pymatgen.py b/dpdata/plugins/pymatgen.py index e7e527ff7..322298c3c 100644 --- a/dpdata/plugins/pymatgen.py +++ b/dpdata/plugins/pymatgen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import dpdata.pymatgen.molecule diff --git a/dpdata/plugins/qe.py b/dpdata/plugins/qe.py index 6a98eedd8..682bb202e 100644 --- a/dpdata/plugins/qe.py +++ b/dpdata/plugins/qe.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.md.pbc import dpdata.qe.scf import dpdata.qe.traj diff --git a/dpdata/plugins/rdkit.py b/dpdata/plugins/rdkit.py index c7cef07fc..f01b277d6 100644 --- a/dpdata/plugins/rdkit.py +++ b/dpdata/plugins/rdkit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.rdkit.utils from dpdata.format import Format diff --git a/dpdata/plugins/siesta.py b/dpdata/plugins/siesta.py index 662b5c0e0..906eeb51f 100644 --- a/dpdata/plugins/siesta.py +++ b/dpdata/plugins/siesta.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.siesta.aiMD_output import dpdata.siesta.output from dpdata.format import Format diff --git a/dpdata/plugins/vasp.py b/dpdata/plugins/vasp.py index c182bb956..d0681cebf 100644 --- a/dpdata/plugins/vasp.py +++ b/dpdata/plugins/vasp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import dpdata.vasp.outcar diff --git a/dpdata/plugins/xyz.py b/dpdata/plugins/xyz.py index fdb5bf3b1..322bf77cb 100644 --- a/dpdata/plugins/xyz.py +++ b/dpdata/plugins/xyz.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/psi4/input.py b/dpdata/psi4/input.py index ad0532817..3959cb753 100644 --- a/dpdata/psi4/input.py +++ b/dpdata/psi4/input.py @@ -1,4 +1,9 @@ -import numpy as np +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np # Angston is used in Psi4 by default template = """molecule {{ diff --git a/dpdata/psi4/output.py b/dpdata/psi4/output.py index e93858de8..9ccf90e18 100644 --- a/dpdata/psi4/output.py +++ b/dpdata/psi4/output.py @@ -1,11 +1,11 @@ -from typing import Tuple +from __future__ import annotations import numpy as np from dpdata.unit import LengthConversion -def read_psi4_output(fn: str) -> Tuple[str, np.ndarray, float, np.ndarray]: +def read_psi4_output(fn: str) -> tuple[str, np.ndarray, float, np.ndarray]: """Read from Psi4 output. Note that both the energy and the gradient should be printed. diff --git a/dpdata/pwmat/atomconfig.py b/dpdata/pwmat/atomconfig.py index f128aa5f8..62eff77ca 100644 --- a/dpdata/pwmat/atomconfig.py +++ b/dpdata/pwmat/atomconfig.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +from __future__ import annotations + import numpy as np from ..periodic_table import ELEMENTS diff --git a/dpdata/pwmat/movement.py b/dpdata/pwmat/movement.py index 748744d6d..ccfd819db 100644 --- a/dpdata/pwmat/movement.py +++ b/dpdata/pwmat/movement.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings import numpy as np diff --git a/dpdata/pymatgen/molecule.py b/dpdata/pymatgen/molecule.py index fc05b07ad..8d397984a 100644 --- a/dpdata/pymatgen/molecule.py +++ b/dpdata/pymatgen/molecule.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import Counter import numpy as np diff --git a/dpdata/pymatgen/structure.py b/dpdata/pymatgen/structure.py index 9f47baee8..36e411c02 100644 --- a/dpdata/pymatgen/structure.py +++ b/dpdata/pymatgen/structure.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/qe/scf.py b/dpdata/qe/scf.py index cd9c6f283..37e5fbab6 100755 --- a/dpdata/qe/scf.py +++ b/dpdata/qe/scf.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import os diff --git a/dpdata/qe/traj.py b/dpdata/qe/traj.py index e27990cbe..1fbf0f71c 100644 --- a/dpdata/qe/traj.py +++ b/dpdata/qe/traj.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +from __future__ import annotations + import warnings import numpy as np diff --git a/dpdata/rdkit/sanitize.py b/dpdata/rdkit/sanitize.py index 45060abc2..2b0d76634 100644 --- a/dpdata/rdkit/sanitize.py +++ b/dpdata/rdkit/sanitize.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import time from copy import deepcopy diff --git a/dpdata/rdkit/utils.py b/dpdata/rdkit/utils.py index 9c7e50afb..efeef6070 100644 --- a/dpdata/rdkit/utils.py +++ b/dpdata/rdkit/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/siesta/aiMD_output.py b/dpdata/siesta/aiMD_output.py index 4e1890ecc..daa4f6a25 100644 --- a/dpdata/siesta/aiMD_output.py +++ b/dpdata/siesta/aiMD_output.py @@ -1,4 +1,5 @@ # !/usr/bin/python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/siesta/output.py b/dpdata/siesta/output.py index 7418d5433..0c944d5b5 100644 --- a/dpdata/siesta/output.py +++ b/dpdata/siesta/output.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/stat.py b/dpdata/stat.py index 62b10c468..5ec395708 100644 --- a/dpdata/stat.py +++ b/dpdata/stat.py @@ -1,4 +1,6 @@ -from abc import ABCMeta, abstractmethod, abstractmethod, abstractproperty +from __future__ import annotations + +from abc import ABCMeta, abstractmethod from functools import lru_cache import numpy as np diff --git a/dpdata/system.py b/dpdata/system.py index 198f023df..e119f1e4f 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,16 +1,25 @@ # %% +from __future__ import annotations + import glob import hashlib import numbers import os +import sys import warnings from copy import deepcopy -import sys -from typing import TYPE_CHECKING, TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Tuple, Type, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Literal, + overload, +) + if sys.version_info < (3, 8): - from typing_extensions import TypedDict + pass else: - from typing import TypedDict + pass import numpy as np @@ -78,7 +87,7 @@ class System: data types of this class """ - DTYPES: Tuple[DataType, ...] = ( + DTYPES: tuple[DataType, ...] = ( DataType("atom_numbs", list, (Axis.NTYPES,)), DataType("atom_names", list, (Axis.NTYPES,)), DataType("atom_types", np.ndarray, (Axis.NATOMS,)), @@ -97,10 +106,10 @@ def __init__( # some formats do not use string as input file_name: Any=None, fmt: str="auto", - type_map:Optional[list[str]]=None, + type_map:list[str] | None=None, begin:int=0, step:int=1, - data: Optional[Dict[str, Any]]=None, + data: dict[str, Any] | None=None, convergence_check: bool=True, **kwargs, ): @@ -242,7 +251,7 @@ def from_fmt_obj(self, fmtobj: Format, file_name: Any, **kwargs: Any): self.post_funcs.get_plugin(post_f)(self) return self - def to(self, fmt: str, *args: Any, **kwargs: Any) -> "System": + def to(self, fmt: str, *args: Any, **kwargs: Any) -> System: """Dump systems to the specific format. Parameters @@ -280,13 +289,13 @@ def __str__(self): return ret @overload - def __getitem__(self, key: Union[int, slice, list, np.ndarray]) -> "System": + def __getitem__(self, key: int | slice | list | np.ndarray) -> System: ... @overload - def __getitem__(self, key: Literal["atom_names", "real_atom_names"]) -> List[str]: + def __getitem__(self, key: Literal["atom_names", "real_atom_names"]) -> list[str]: ... @overload - def __getitem__(self, key: Literal["atom_numbs"]) -> List[int]: + def __getitem__(self, key: Literal["atom_numbs"]) -> list[int]: ... @overload def __getitem__(self, key: Literal["nopbc"]) -> bool: @@ -330,7 +339,7 @@ def dump(self, filename: str, indent: int=4): dumpfn(self.as_dict(), filename, indent=indent) - def map_atom_types(self, type_map: Optional[Union[Dict[str, int], List[str]]]=None) -> np.ndarray: + def map_atom_types(self, type_map: dict[str, int] | list[str] | None=None) -> np.ndarray: """Map the atom types of the system. Parameters @@ -396,7 +405,7 @@ def as_dict(self) -> dict: } return d - def get_atom_names(self) -> List[str]: + def get_atom_names(self) -> list[str]: """Returns name of atoms.""" return self.data["atom_names"] @@ -404,7 +413,7 @@ def get_atom_types(self) -> np.ndarray: """Returns type of atoms.""" return self.data["atom_types"] - def get_atom_numbs(self) -> List[int]: + def get_atom_numbs(self) -> list[int]: """Returns number of atoms.""" return self.data["atom_numbs"] @@ -424,7 +433,7 @@ def copy(self): """Returns a copy of the system.""" return self.__class__.from_dict({"data": deepcopy(self.data)}) - def sub_system(self, f_idx: numbers.Integral) -> "System": + def sub_system(self, f_idx: numbers.Integral) -> System: """Construct a subsystem from the system. Parameters @@ -447,7 +456,7 @@ def sub_system(self, f_idx: numbers.Integral) -> "System": continue if tt.shape is not None and Axis.NFRAMES in tt.shape: axis_nframes = tt.shape.index(Axis.NFRAMES) - new_shape: List[Union[slice, np.ndarray]] = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape] new_shape[axis_nframes] = f_idx tmp.data[tt.name] = self.data[tt.name][tuple(new_shape)] else: @@ -455,7 +464,7 @@ def sub_system(self, f_idx: numbers.Integral) -> "System": tmp.data[tt.name] = self.data[tt.name] return tmp - def append(self, system: "System") -> bool: + def append(self, system: System) -> bool: """Append a system to this system. Parameters @@ -511,7 +520,7 @@ def append(self, system: "System") -> bool: self.data["nopbc"] = False return True - def convert_to_mixed_type(self, type_map:Optional[List[str]]=None): + def convert_to_mixed_type(self, type_map:list[str] | None=None): """Convert the data dict to mixed type format structure, in order to append systems with different formula but the same number of atoms. Change the 'atom_names' to one placeholder type 'MIXED_TOKEN' and add 'real_atom_types' to store the real type @@ -537,7 +546,7 @@ def convert_to_mixed_type(self, type_map:Optional[List[str]]=None): self.data["atom_numbs"] = [natoms] self.data["atom_names"] = ["MIXED_TOKEN"] - def sort_atom_names(self, type_map:Optional[List[str]]=None): + def sort_atom_names(self, type_map:list[str] | None=None): """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding to atom_names. If type_map is not given, atom_names will be sorted by alphabetical order. If type_map is given, atom_names will be type_map. @@ -549,7 +558,7 @@ def sort_atom_names(self, type_map:Optional[List[str]]=None): """ self.data = sort_atom_names(self.data, type_map=type_map) - def check_type_map(self, type_map:Optional[List[str]]): + def check_type_map(self, type_map:list[str] | None): """Assign atom_names to type_map if type_map is given and different from atom_names. @@ -561,7 +570,7 @@ def check_type_map(self, type_map:Optional[List[str]]): if type_map is not None and type_map != self.data["atom_names"]: self.sort_atom_names(type_map=type_map) - def apply_type_map(self, type_map: List[str]): + def apply_type_map(self, type_map: list[str]): """Customize the element symbol order and it should maintain order consistency in dpgen or deepmd-kit. It is especially recommended for multiple complexsystems with multiple elements. @@ -591,7 +600,7 @@ def sort_atom_types(self) -> np.ndarray: continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape: List[Union[slice, np.ndarray]] = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape] new_shape[axis_natoms] = idx self.data[tt.name] = self.data[tt.name][tuple(new_shape)] return idx @@ -659,7 +668,7 @@ def short_name(self) -> str: return short_formula return self.formula_hash - def extend(self, systems: Iterable["System"]): + def extend(self, systems: Iterable[System]): """Extend a system list to this system. Parameters @@ -727,11 +736,11 @@ def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0): self.affine_map(rot, f_idx=f_idx) return np.matmul(qq, rot) - def add_atom_names(self, atom_names: List[str]): + def add_atom_names(self, atom_names: list[str]): """Add atom_names that do not exist.""" self.data = add_atom_names(self.data, atom_names) - def replicate(self, ncopy: Union[List[int], Tuple[int, int, int]]): + def replicate(self, ncopy: list[int] | tuple[int, int, int]): """Replicate the each frame in the system in 3 dimensions. Each frame in the system will become a supercell. @@ -905,7 +914,7 @@ def shuffle(self): self.data = self.sub_system(idx).data return idx - def predict(self, *args: Any, driver: Union[str, Driver] = "dp", **kwargs: Any) -> "LabeledSystem": + def predict(self, *args: Any, driver: str | Driver = "dp", **kwargs: Any) -> LabeledSystem: """Predict energies and forces by a driver. Parameters @@ -934,8 +943,8 @@ def predict(self, *args: Any, driver: Union[str, Driver] = "dp", **kwargs: Any) return LabeledSystem(data=data) def minimize( - self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any - ) -> "LabeledSystem": + self, *args: Any, minimizer: str | Minimizer, **kwargs: Any + ) -> LabeledSystem: """Minimize the geometry. Parameters @@ -957,7 +966,7 @@ def minimize( data = minimizer.minimize(self.data.copy()) return LabeledSystem(data=data) - def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): + def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None): """Pick atom index. Parameters @@ -981,7 +990,7 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape: List[Union[slice, np.ndarray]] = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape] new_shape[axis_natoms] = idx new_sys.data[tt.name] = self.data[tt.name][tuple(new_shape)] # recalculate atom_numbs according to atom_types @@ -993,7 +1002,7 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): new_sys.nopbc = nopbc return new_sys - def remove_atom_names(self, atom_names: Union[str, Iterable[str]]): + def remove_atom_names(self, atom_names: str | Iterable[str]): """Remove atom names and all such atoms. For example, you may not remove EP atoms in TIP4P/Ew water, which is not a real atom. @@ -1019,7 +1028,7 @@ def remove_atom_names(self, atom_names: Union[str, Iterable[str]]): new_sys.data["atom_numbs"] = new_sys.data["atom_numbs"][: len(new_atom_names)] return new_sys - def pick_by_amber_mask(self, param: Union[str, "parmed.Structure"], maskstr: str, pass_coords: bool=False, nopbc: Optional[bool]=None): + def pick_by_amber_mask(self, param: str | parmed.Structure, maskstr: str, pass_coords: bool=False, nopbc: bool | None=None): """Pick atoms by amber mask. Parameters @@ -1154,7 +1163,7 @@ class LabeledSystem(System): The number of skipped frames when loading MD trajectory. """ - DTYPES: Tuple[DataType, ...] = System.DTYPES + ( + DTYPES: tuple[DataType, ...] = System.DTYPES + ( DataType("energies", np.ndarray, (Axis.NFRAMES,)), DataType("forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)), DataType("virials", np.ndarray, (Axis.NFRAMES, 3, 3), required=False), @@ -1226,7 +1235,7 @@ def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0): self.affine_map_fv(trans, f_idx=f_idx) return trans - def correction(self, hl_sys: "LabeledSystem") -> "LabeledSystem": + def correction(self, hl_sys: LabeledSystem) -> LabeledSystem: """Get energy and force correction between self and a high-level LabeledSystem. The self's coordinates will be kept, but energy and forces will be replaced by the correction between these two systems. @@ -1255,7 +1264,7 @@ def correction(self, hl_sys: "LabeledSystem") -> "LabeledSystem": ) return corrected_sys - def remove_outlier(self, threshold: float = 8.0) -> "LabeledSystem": + def remove_outlier(self, threshold: float = 8.0) -> LabeledSystem: r"""Remove outlier frames from the system. Remove the frames whose energies satisfy the condition @@ -1306,11 +1315,11 @@ def __init__(self, *systems, type_map=None): type_map : list of str Maps atom type to name """ - self.systems: Dict[str, System] = {} + self.systems: dict[str, System] = {} if type_map is not None: - self.atom_names: List[str] = type_map + self.atom_names: list[str] = type_map else: - self.atom_names: List[str] = [] + self.atom_names: list[str] = [] self.append(*systems) def from_fmt_obj(self, fmtobj: Format, directory, labeled:bool=True, **kwargs: Any): @@ -1356,7 +1365,7 @@ def to_fmt_obj(self, fmtobj: Format, directory, *args: Any, **kwargs: Any): ) return self - def to(self, fmt: str, *args: Any, **kwargs: Any) -> "MultiSystems": + def to(self, fmt: str, *args: Any, **kwargs: Any) -> MultiSystems: """Dump systems to the specific format. Parameters @@ -1406,7 +1415,7 @@ def from_file(cls, file_name, fmt: str, **kwargs: Any): return multi_systems @classmethod - def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: Optional[List[str]]=None): + def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: list[str] | None=None): multi_systems = cls() target_file_list = sorted( glob.glob(f"./{dir_name}/**/{file_name}", recursive=True) @@ -1417,7 +1426,7 @@ def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: Opti ) return multi_systems - def load_systems_from_file(self, file_name=None, fmt: Optional[str]=None, **kwargs): + def load_systems_from_file(self, file_name=None, fmt: str | None=None, **kwargs): assert fmt is not None fmt = fmt.lower() return self.from_fmt_obj(load_format(fmt), file_name, **kwargs) @@ -1426,7 +1435,7 @@ def get_nframes(self) -> int: """Returns number of frames in all systems.""" return sum(len(system) for system in self.systems.values()) - def append(self, *systems: Union[System, "MultiSystems"]): + def append(self, *systems: System | MultiSystems): """Append systems or MultiSystems to systems. Parameters @@ -1476,7 +1485,7 @@ def check_atom_names(self, system: System): system.add_atom_names(new_in_self) system.sort_atom_names(type_map=self.atom_names) - def predict(self, *args: Any, driver: Union[str, Driver]="dp", **kwargs: Any) -> "MultiSystems": + def predict(self, *args: Any, driver: str | Driver="dp", **kwargs: Any) -> MultiSystems: """Predict energies and forces by a driver. Parameters @@ -1501,8 +1510,8 @@ def predict(self, *args: Any, driver: Union[str, Driver]="dp", **kwargs: Any) -> return new_multisystems def minimize( - self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any - ) -> "MultiSystems": + self, *args: Any, minimizer: str | Minimizer, **kwargs: Any + ) -> MultiSystems: """Minimize geometry by a minimizer. Parameters @@ -1535,7 +1544,7 @@ def minimize( new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) return new_multisystems - def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): + def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None): """Pick atom index. Parameters @@ -1555,7 +1564,7 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: Optional[bool]=None): new_sys.append(ss.pick_atom_idx(idx, nopbc=nopbc)) return new_sys - def correction(self, hl_sys: "MultiSystems") -> "MultiSystems": + def correction(self, hl_sys: MultiSystems) -> MultiSystems: """Get energy and force correction between self (assumed low-level) and a high-level MultiSystems. The self's coordinates will be kept, but energy and forces will be replaced by the correction between these two systems. @@ -1595,8 +1604,8 @@ def correction(self, hl_sys: "MultiSystems") -> "MultiSystems": return corrected_sys def train_test_split( - self, test_size: Union[float, int], seed: Optional[int] = None - ) -> Tuple["MultiSystems", "MultiSystems", Dict[str, np.ndarray]]: + self, test_size: float | int, seed: int | None = None + ) -> tuple[MultiSystems, MultiSystems, dict[str, np.ndarray]]: """Split systems into random train and test subsets. Parameters @@ -1652,7 +1661,7 @@ def train_test_split( return train_systems, test_systems, test_system_idx -def get_cls_name(cls: Type[Any]) -> str: +def get_cls_name(cls: type[Any]) -> str: """Returns the fully qualified name of a class, such as `np.ndarray`. Parameters diff --git a/dpdata/unit.py b/dpdata/unit.py index 5fc8fe1e9..09981b969 100644 --- a/dpdata/unit.py +++ b/dpdata/unit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from scipy import constants # noqa: TID253 diff --git a/dpdata/utils.py b/dpdata/utils.py index a6b76c8cc..7acfb9513 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -1,20 +1,24 @@ -from typing import Dict, List, Literal, Union, overload +from __future__ import annotations + +from typing import Literal, overload + import numpy as np from dpdata.periodic_table import Element + @overload -def elements_index_map(elements: List[str], standard: bool, inverse: Literal[True]) -> Dict[int, str]: +def elements_index_map(elements: list[str], standard: bool, inverse: Literal[True]) -> dict[int, str]: ... @overload -def elements_index_map(elements: List[str], standard: bool, inverse: Literal[False]=...) -> Dict[str, int]: +def elements_index_map(elements: list[str], standard: bool, inverse: Literal[False]=...) -> dict[str, int]: ... @overload -def elements_index_map(elements: List[str], standard: bool, inverse: bool=False) -> Union[Dict[str, int], Dict[int, str]]: +def elements_index_map(elements: list[str], standard: bool, inverse: bool=False) -> dict[str, int] | dict[int, str]: ... -def elements_index_map(elements: List[str], standard: bool=False, inverse: bool=False) -> dict: +def elements_index_map(elements: list[str], standard: bool=False, inverse: bool=False) -> dict: if standard: elements.sort(key=lambda x: Element(x).Z) if inverse: diff --git a/dpdata/vasp/outcar.py b/dpdata/vasp/outcar.py index 0eddac91a..0fa4cb68e 100644 --- a/dpdata/vasp/outcar.py +++ b/dpdata/vasp/outcar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import warnings diff --git a/dpdata/vasp/poscar.py b/dpdata/vasp/poscar.py index fde0f8fbe..102e79041 100644 --- a/dpdata/vasp/poscar.py +++ b/dpdata/vasp/poscar.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/vasp/xml.py b/dpdata/vasp/xml.py index a534fd0cf..352b107ed 100755 --- a/dpdata/vasp/xml.py +++ b/dpdata/vasp/xml.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import xml.etree.ElementTree as ET diff --git a/dpdata/xyz/quip_gap_xyz.py b/dpdata/xyz/quip_gap_xyz.py index 068bec1fb..b23b27e07 100644 --- a/dpdata/xyz/quip_gap_xyz.py +++ b/dpdata/xyz/quip_gap_xyz.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # %% +from __future__ import annotations + import re from collections import OrderedDict diff --git a/dpdata/xyz/xyz.py b/dpdata/xyz/xyz.py index 745a97b1b..0c36ac32b 100644 --- a/dpdata/xyz/xyz.py +++ b/dpdata/xyz/xyz.py @@ -1,4 +1,4 @@ -from typing import Tuple +from __future__ import annotations import numpy as np @@ -31,7 +31,7 @@ def coord_to_xyz(coord: np.ndarray, types: list) -> str: return "\n".join(buff) -def xyz_to_coord(xyz: str) -> Tuple[np.ndarray, list]: +def xyz_to_coord(xyz: str) -> tuple[np.ndarray, list]: """Convert xyz format to coordinates and types. Parameters diff --git a/plugin_example/dpdata_random/__init__.py b/plugin_example/dpdata_random/__init__.py index 22820e0fa..cc14faca3 100644 --- a/plugin_example/dpdata_random/__init__.py +++ b/plugin_example/dpdata_random/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/tests/comp_sys.py b/tests/comp_sys.py index f4663780b..99879af61 100644 --- a/tests/comp_sys.py +++ b/tests/comp_sys.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/context.py b/tests/context.py index 77a7557d3..3214e28ea 100644 --- a/tests/context.py +++ b/tests/context.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py index b3821cb34..ef26e7c1d 100644 --- a/tests/plugin/dpdata_plugin_test/__init__.py +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.data_type import Axis, DataType, register_data_type diff --git a/tests/poscars/poscar_ref_oh.py b/tests/poscars/poscar_ref_oh.py index f120183ed..2d29aeeb6 100644 --- a/tests/poscars/poscar_ref_oh.py +++ b/tests/poscars/poscar_ref_oh.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/poscars/test_lammps_dump_s_su.py b/tests/poscars/test_lammps_dump_s_su.py index 28673dfc7..967c767aa 100644 --- a/tests/poscars/test_lammps_dump_s_su.py +++ b/tests/poscars/test_lammps_dump_s_su.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/pwmat/config_ref_ch4.py b/tests/pwmat/config_ref_ch4.py index 71aef7fe1..72499398e 100644 --- a/tests/pwmat/config_ref_ch4.py +++ b/tests/pwmat/config_ref_ch4.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/pwmat/config_ref_oh.py b/tests/pwmat/config_ref_oh.py index 6f3e05619..ad546019a 100644 --- a/tests/pwmat/config_ref_oh.py +++ b/tests/pwmat/config_ref_oh.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/test_abacus_md.py b/tests/test_abacus_md.py index 782ed5214..ddcb7734a 100644 --- a/tests/test_abacus_md.py +++ b/tests/test_abacus_md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_abacus_pw_scf.py b/tests/test_abacus_pw_scf.py index eb712fbeb..8d13dddcf 100644 --- a/tests/test_abacus_pw_scf.py +++ b/tests/test_abacus_pw_scf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_abacus_relax.py b/tests/test_abacus_relax.py index 65d73e53f..b752a4262 100644 --- a/tests/test_abacus_relax.py +++ b/tests/test_abacus_relax.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 46cb5de6a..356aa57f4 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_amber_md.py b/tests/test_amber_md.py index 3995371ee..b0a060585 100644 --- a/tests/test_amber_md.py +++ b/tests/test_amber_md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_amber_sqm.py b/tests/test_amber_sqm.py index 7f14ff84c..b7f091100 100644 --- a/tests/test_amber_sqm.py +++ b/tests/test_amber_sqm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_ase_traj.py b/tests/test_ase_traj.py index b6eab27e1..8e4a6e12f 100644 --- a/tests/test_ase_traj.py +++ b/tests/test_ase_traj.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, CompSys, IsPBC diff --git a/tests/test_bond_order_system.py b/tests/test_bond_order_system.py index 41a167fbc..104e18f1f 100644 --- a/tests/test_bond_order_system.py +++ b/tests/test_bond_order_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import os import unittest diff --git a/tests/test_cell_to_low_triangle.py b/tests/test_cell_to_low_triangle.py index c080c8e5f..34d0a90ae 100644 --- a/tests/test_cell_to_low_triangle.py +++ b/tests/test_cell_to_low_triangle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_cli.py b/tests/test_cli.py index 200a1c1ef..9d70db5ff 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess as sp import sys import unittest diff --git a/tests/test_corr.py b/tests/test_corr.py index 463c99af9..a7c6f7c4a 100644 --- a/tests/test_corr.py +++ b/tests/test_corr.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_cp2k_aimd_output.py b/tests/test_cp2k_aimd_output.py index bce242500..46f292b11 100644 --- a/tests/test_cp2k_aimd_output.py +++ b/tests/test_cp2k_aimd_output.py @@ -1,4 +1,6 @@ # %% +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys diff --git a/tests/test_cp2k_output.py b/tests/test_cp2k_output.py index 0e4b153dc..da58e87ce 100644 --- a/tests/test_cp2k_output.py +++ b/tests/test_cp2k_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 7e3278ea3..e94ba5e0f 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import h5py # noqa: TID253 diff --git a/tests/test_deepmd_comp.py b/tests/test_deepmd_comp.py index 46f8e7414..284287866 100644 --- a/tests/test_deepmd_comp.py +++ b/tests/test_deepmd_comp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_deepmd_hdf5.py b/tests/test_deepmd_hdf5.py index 20d16c370..b4a22f3c1 100644 --- a/tests/test_deepmd_hdf5.py +++ b/tests/test_deepmd_hdf5.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_deepmd_mixed.py b/tests/test_deepmd_mixed.py index 7e522e065..02044932e 100644 --- a/tests/test_deepmd_mixed.py +++ b/tests/test_deepmd_mixed.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_deepmd_raw.py b/tests/test_deepmd_raw.py index 1b0567260..af875fdea 100644 --- a/tests/test_deepmd_raw.py +++ b/tests/test_deepmd_raw.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_dftbplus.py b/tests/test_dftbplus.py index 2a2913a52..29cdaa92e 100644 --- a/tests/test_dftbplus.py +++ b/tests/test_dftbplus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_elements_index.py b/tests/test_elements_index.py index 45408b4d8..186d7b806 100644 --- a/tests/test_elements_index.py +++ b/tests/test_elements_index.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from dpdata.system import elements_index_map diff --git a/tests/test_empty.py b/tests/test_empty.py index 8787f9543..12913bab9 100644 --- a/tests/test_empty.py +++ b/tests/test_empty.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fhi_md_multi_elem_output.py b/tests/test_fhi_md_multi_elem_output.py index a20c45bdd..b11a52f54 100644 --- a/tests/test_fhi_md_multi_elem_output.py +++ b/tests/test_fhi_md_multi_elem_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fhi_md_output.py b/tests/test_fhi_md_output.py index d205e3911..391cc319a 100644 --- a/tests/test_fhi_md_output.py +++ b/tests/test_fhi_md_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fhi_output.py b/tests/test_fhi_output.py index 067e5f699..bd3582f31 100644 --- a/tests/test_fhi_output.py +++ b/tests/test_fhi_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_from_pymatgen.py b/tests/test_from_pymatgen.py index d3ddbe3e9..7689a9d5e 100644 --- a/tests/test_from_pymatgen.py +++ b/tests/test_from_pymatgen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_gaussian_driver.py b/tests/test_gaussian_driver.py index 07150bc7b..ff1638488 100644 --- a/tests/test_gaussian_driver.py +++ b/tests/test_gaussian_driver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import os import shutil diff --git a/tests/test_gaussian_gjf.py b/tests/test_gaussian_gjf.py index 2e5f4ea8f..b3819946e 100644 --- a/tests/test_gaussian_gjf.py +++ b/tests/test_gaussian_gjf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_gaussian_log.py b/tests/test_gaussian_log.py index 6622e6841..784fd5945 100644 --- a/tests/test_gaussian_log.py +++ b/tests/test_gaussian_log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_gromacs_gro.py b/tests/test_gromacs_gro.py index 2971755f1..674c65100 100644 --- a/tests/test_gromacs_gro.py +++ b/tests/test_gromacs_gro.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_json.py b/tests/test_json.py index 545e5db8c..0b6f1b9dd 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_lammps_dump_idx.py b/tests/test_lammps_dump_idx.py index 272cc222e..39379158b 100644 --- a/tests/test_lammps_dump_idx.py +++ b/tests/test_lammps_dump_idx.py @@ -1,4 +1,5 @@ # The index should map to that in the dump file +from __future__ import annotations import os import unittest diff --git a/tests/test_lammps_dump_shift_origin.py b/tests/test_lammps_dump_shift_origin.py index 4ecd6f873..a74442347 100644 --- a/tests/test_lammps_dump_shift_origin.py +++ b/tests/test_lammps_dump_shift_origin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompSys, IsPBC diff --git a/tests/test_lammps_dump_skipload.py b/tests/test_lammps_dump_skipload.py index 224ec6d1f..299e1db48 100644 --- a/tests/test_lammps_dump_skipload.py +++ b/tests/test_lammps_dump_skipload.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_dump_to_system.py b/tests/test_lammps_dump_to_system.py index af9748a51..4d634037c 100644 --- a/tests/test_lammps_dump_to_system.py +++ b/tests/test_lammps_dump_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_dump_unfold.py b/tests/test_lammps_dump_unfold.py index 1e78d9756..587602c8a 100644 --- a/tests/test_lammps_dump_unfold.py +++ b/tests/test_lammps_dump_unfold.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_lmp_dump.py b/tests/test_lammps_lmp_dump.py index 8e9cfb328..25525f76b 100644 --- a/tests/test_lammps_lmp_dump.py +++ b/tests/test_lammps_lmp_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_lmp_to_system.py b/tests/test_lammps_lmp_to_system.py index 19e133121..444b1dd43 100644 --- a/tests/test_lammps_lmp_to_system.py +++ b/tests/test_lammps_lmp_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_read_from_trajs.py b/tests/test_lammps_read_from_trajs.py index f1e5afdd1..578ae471e 100644 --- a/tests/test_lammps_read_from_trajs.py +++ b/tests/test_lammps_read_from_trajs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_msd.py b/tests/test_msd.py index 52b1ce935..7148b0b5a 100644 --- a/tests/test_msd.py +++ b/tests/test_msd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_multisystems.py b/tests/test_multisystems.py index 2bda13a9b..88d4593a1 100644 --- a/tests/test_multisystems.py +++ b/tests/test_multisystems.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_n2p2.py b/tests/test_n2p2.py index 855a27524..32ac64473 100644 --- a/tests/test_n2p2.py +++ b/tests/test_n2p2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_openmx.py b/tests/test_openmx.py index 0705ed0a6..2122e8f47 100644 --- a/tests/test_openmx.py +++ b/tests/test_openmx.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_openmx_check_convergence.py b/tests/test_openmx_check_convergence.py index 362c89c58..b19ad6e8d 100644 --- a/tests/test_openmx_check_convergence.py +++ b/tests/test_openmx_check_convergence.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_orca_spout.py b/tests/test_orca_spout.py index ecb1a5ca8..d034fbb08 100644 --- a/tests/test_orca_spout.py +++ b/tests/test_orca_spout.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_periodic_table.py b/tests/test_periodic_table.py index 6b856e913..3cf36b99b 100644 --- a/tests/test_periodic_table.py +++ b/tests/test_periodic_table.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from context import dpdata diff --git a/tests/test_perturb.py b/tests/test_perturb.py index b89a8c7f2..eea711167 100644 --- a/tests/test_perturb.py +++ b/tests/test_perturb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from unittest.mock import patch diff --git a/tests/test_pick_atom_idx.py b/tests/test_pick_atom_idx.py index 0dc069911..ef3368f35 100644 --- a/tests/test_pick_atom_idx.py +++ b/tests/test_pick_atom_idx.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompSys, IsNoPBC diff --git a/tests/test_predict.py b/tests/test_predict.py index f08125ab2..6ab00be36 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_psi4.py b/tests/test_psi4.py index b9c2124e4..93bfc4088 100644 --- a/tests/test_psi4.py +++ b/tests/test_psi4.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import tempfile import textwrap import unittest diff --git a/tests/test_pwmat_config_dump.py b/tests/test_pwmat_config_dump.py index 9389c7a97..e4d5a5a8e 100644 --- a/tests/test_pwmat_config_dump.py +++ b/tests/test_pwmat_config_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_pwmat_config_to_system.py b/tests/test_pwmat_config_to_system.py index 0956f9569..59fd73399 100644 --- a/tests/test_pwmat_config_to_system.py +++ b/tests/test_pwmat_config_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_pwmat_mlmd.py b/tests/test_pwmat_mlmd.py index 4a920c150..8dcdb1efc 100644 --- a/tests/test_pwmat_mlmd.py +++ b/tests/test_pwmat_mlmd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_pwmat_movement.py b/tests/test_pwmat_movement.py index 68a9e681a..14e976b24 100644 --- a/tests/test_pwmat_movement.py +++ b/tests/test_pwmat_movement.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_pymatgen_molecule.py b/tests/test_pymatgen_molecule.py index 231bd97ff..e6a1b5ee5 100644 --- a/tests/test_pymatgen_molecule.py +++ b/tests/test_pymatgen_molecule.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_qe_cp_traj.py b/tests/test_qe_cp_traj.py index 6a9631064..9e0629867 100644 --- a/tests/test_qe_cp_traj.py +++ b/tests/test_qe_cp_traj.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_qe_cp_traj_skipload.py b/tests/test_qe_cp_traj_skipload.py index 2964e716b..43cbe88d9 100644 --- a/tests/test_qe_cp_traj_skipload.py +++ b/tests/test_qe_cp_traj_skipload.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_qe_pw_scf.py b/tests/test_qe_pw_scf.py index 57a739fb3..8703e7c24 100644 --- a/tests/test_qe_pw_scf.py +++ b/tests/test_qe_pw_scf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_qe_pw_scf_crystal_atomic_positions.py b/tests/test_qe_pw_scf_crystal_atomic_positions.py index 01c4df21f..383ea6cd7 100644 --- a/tests/test_qe_pw_scf_crystal_atomic_positions.py +++ b/tests/test_qe_pw_scf_crystal_atomic_positions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_qe_pw_scf_energy_bug.py b/tests/test_qe_pw_scf_energy_bug.py index 8360a7a92..b66ce924b 100644 --- a/tests/test_qe_pw_scf_energy_bug.py +++ b/tests/test_qe_pw_scf_energy_bug.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from context import dpdata diff --git a/tests/test_quip_gap_xyz.py b/tests/test_quip_gap_xyz.py index b383bd2f4..a265544ce 100644 --- a/tests/test_quip_gap_xyz.py +++ b/tests/test_quip_gap_xyz.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_remove_atom_names.py b/tests/test_remove_atom_names.py index d2d4abc7e..9fbd8faf0 100644 --- a/tests/test_remove_atom_names.py +++ b/tests/test_remove_atom_names.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsNoPBC diff --git a/tests/test_remove_outlier.py b/tests/test_remove_outlier.py index b2cb52fcf..c08de0bf4 100644 --- a/tests/test_remove_outlier.py +++ b/tests/test_remove_outlier.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_remove_pbc.py b/tests/test_remove_pbc.py index d5befd771..d70a2f028 100644 --- a/tests/test_remove_pbc.py +++ b/tests/test_remove_pbc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_replace.py b/tests/test_replace.py index b16c388b5..b91941374 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from unittest.mock import patch diff --git a/tests/test_replicate.py b/tests/test_replicate.py index 99104c3ca..3add2dc02 100644 --- a/tests/test_replicate.py +++ b/tests/test_replicate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompSys, IsPBC diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index 9c4622143..3ac33c2f5 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_siesta_aiMD_output.py b/tests/test_siesta_aiMD_output.py index a1ba31b6d..4dcb04453 100644 --- a/tests/test_siesta_aiMD_output.py +++ b/tests/test_siesta_aiMD_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_siesta_output.py b/tests/test_siesta_output.py index 9ff0167a0..c649f7d0e 100644 --- a/tests/test_siesta_output.py +++ b/tests/test_siesta_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_split_dataset.py b/tests/test_split_dataset.py index a5419b7b1..ac0960cfe 100644 --- a/tests/test_split_dataset.py +++ b/tests/test_split_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_sqm_driver.py b/tests/test_sqm_driver.py index 3dbc6df4a..d7c0da73b 100644 --- a/tests/test_sqm_driver.py +++ b/tests/test_sqm_driver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import shutil import unittest diff --git a/tests/test_stat.py b/tests/test_stat.py index 9ae8a175b..863cea6c6 100644 --- a/tests/test_stat.py +++ b/tests/test_stat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from context import dpdata diff --git a/tests/test_system_append.py b/tests/test_system_append.py index a2c30b238..7c325113f 100644 --- a/tests/test_system_append.py +++ b/tests/test_system_append.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_system_apply_pbc.py b/tests/test_system_apply_pbc.py index 9cf44ae08..2114cf6a8 100644 --- a/tests/test_system_apply_pbc.py +++ b/tests/test_system_apply_pbc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_system_set_type.py b/tests/test_system_set_type.py index 4bb14b621..d8362ec7b 100644 --- a/tests/test_system_set_type.py +++ b/tests/test_system_set_type.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_to_ase.py b/tests/test_to_ase.py index 60dc931d9..09b830baa 100644 --- a/tests/test_to_ase.py +++ b/tests/test_to_ase.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_to_list.py b/tests/test_to_list.py index d559ffce2..998f12650 100644 --- a/tests/test_to_list.py +++ b/tests/test_to_list.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_to_pymatgen.py b/tests/test_to_pymatgen.py index b55443d4d..72d1b27ad 100644 --- a/tests/test_to_pymatgen.py +++ b/tests/test_to_pymatgen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_to_pymatgen_entry.py b/tests/test_to_pymatgen_entry.py index 7111dcdc4..dfdeb4680 100644 --- a/tests/test_to_pymatgen_entry.py +++ b/tests/test_to_pymatgen_entry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_type_map.py b/tests/test_type_map.py index 2cc508654..92d25adac 100644 --- a/tests/test_type_map.py +++ b/tests/test_type_map.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from itertools import permutations diff --git a/tests/test_vasp_outcar.py b/tests/test_vasp_outcar.py index fb2ec1c94..832b0a91b 100644 --- a/tests/test_vasp_outcar.py +++ b/tests/test_vasp_outcar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_vasp_poscar_dump.py b/tests/test_vasp_poscar_dump.py index a81cbe94b..62f215986 100644 --- a/tests/test_vasp_poscar_dump.py +++ b/tests/test_vasp_poscar_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_vasp_poscar_to_system.py b/tests/test_vasp_poscar_to_system.py index dcb83bfdf..7457d33d2 100644 --- a/tests/test_vasp_poscar_to_system.py +++ b/tests/test_vasp_poscar_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_vasp_unconverged_outcar.py b/tests/test_vasp_unconverged_outcar.py index 7e1b35353..1f3b3d2d9 100644 --- a/tests/test_vasp_unconverged_outcar.py +++ b/tests/test_vasp_unconverged_outcar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_vasp_xml.py b/tests/test_vasp_xml.py index cc0bbb41a..0b9177545 100644 --- a/tests/test_vasp_xml.py +++ b/tests/test_vasp_xml.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_water_ions.py b/tests/test_water_ions.py index 788030f38..40c1c143c 100644 --- a/tests/test_water_ions.py +++ b/tests/test_water_ions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_xyz.py b/tests/test_xyz.py index a84ad28bc..d9bcf70ea 100644 --- a/tests/test_xyz.py +++ b/tests/test_xyz.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import tempfile import unittest From 69d4fa0514bb155c54723c21426e640e7e976c50 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:45:01 -0400 Subject: [PATCH 05/12] fix typo Signed-off-by: Jinzhe Zeng --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4f35bbab7..6efe08188 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ dependencies = [ 'h5py', 'wcmatch', 'importlib_metadata>=1.4; python_version < "3.8"', - 'type_extensions>=0.4.0; python_version < "3.8"', + 'typing_extensions; python_version < "3.8"', ] requires-python = ">=3.7" readme = "README.md" From 04e584baba1340d1cda6ed70efda3c5d7af52e0c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:45:46 -0400 Subject: [PATCH 06/12] fix import issue Signed-off-by: Jinzhe Zeng --- dpdata/data_type.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 752193fef..bbc7401d6 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -5,7 +5,6 @@ import numpy as np -from dpdata.bond_order_system import BondOrderSystem from dpdata.plugin import Plugin if TYPE_CHECKING: @@ -74,8 +73,7 @@ def real_shape(self, system: System) -> tuple[int]: shape.append(system.get_natoms()) elif ii is Axis.NBONDS: # BondOrderSystem - assert isinstance(system, BondOrderSystem) - shape.append(system.get_nbonds()) + shape.append(system.get_nbonds()) # type: ignore elif ii == -1: shape.append(AnyInt(-1)) elif isinstance(ii, int): From 1d1b95f5efc14112c431674b304e98d3d51cb90e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:46:29 -0400 Subject: [PATCH 07/12] run pre-commit Signed-off-by: Jinzhe Zeng --- .github/workflows/pyright.yml | 2 +- docs/nb/try_dpdata.ipynb | 2 + dpdata/amber/mask.py | 1 + dpdata/bond_order_system.py | 2 +- dpdata/cli.py | 1 + dpdata/driver.py | 1 + dpdata/format.py | 1 + dpdata/gaussian/gjf.py | 1 + dpdata/plugin.py | 1 + dpdata/system.py | 116 ++++++++++++++++++++++------------ dpdata/utils.py | 19 ++++-- 11 files changed, 96 insertions(+), 51 deletions(-) diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml index 9b4732b94..45195bd6b 100644 --- a/.github/workflows/pyright.yml +++ b/.github/workflows/pyright.yml @@ -1,7 +1,7 @@ on: - push - pull_request - + name: Type checker jobs: pyright: diff --git a/docs/nb/try_dpdata.ipynb b/docs/nb/try_dpdata.ipynb index 7dc225b4b..1a0a73280 100644 --- a/docs/nb/try_dpdata.ipynb +++ b/docs/nb/try_dpdata.ipynb @@ -13,6 +13,8 @@ "metadata": {}, "outputs": [], "source": [ + "from __future__ import annotations\n", + "\n", "import dpdata" ] }, diff --git a/dpdata/amber/mask.py b/dpdata/amber/mask.py index cd3cb728e..155e2a7be 100644 --- a/dpdata/amber/mask.py +++ b/dpdata/amber/mask.py @@ -1,4 +1,5 @@ """Amber mask.""" + from __future__ import annotations try: diff --git a/dpdata/bond_order_system.py b/dpdata/bond_order_system.py index 8d129bcde..7a23acca5 100644 --- a/dpdata/bond_order_system.py +++ b/dpdata/bond_order_system.py @@ -98,7 +98,7 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): mol = fmtobj.from_bond_order_system(file_name, **kwargs) self.from_rdkit_mol(mol) if hasattr(fmtobj.from_bond_order_system, "post_func"): - for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore + for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self diff --git a/dpdata/cli.py b/dpdata/cli.py index 386707891..aadff1a8d 100644 --- a/dpdata/cli.py +++ b/dpdata/cli.py @@ -1,4 +1,5 @@ """Command line interface for dpdata.""" + from __future__ import annotations import argparse diff --git a/dpdata/driver.py b/dpdata/driver.py index 9a196b2e7..b5ff53403 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -1,4 +1,5 @@ """Driver plugin system.""" + from __future__ import annotations from abc import ABC, abstractmethod diff --git a/dpdata/format.py b/dpdata/format.py index 8277df191..ade83c21c 100644 --- a/dpdata/format.py +++ b/dpdata/format.py @@ -1,4 +1,5 @@ """Implement the format plugin system.""" + from __future__ import annotations import os diff --git a/dpdata/gaussian/gjf.py b/dpdata/gaussian/gjf.py index 37e2897a7..b83dad1c2 100644 --- a/dpdata/gaussian/gjf.py +++ b/dpdata/gaussian/gjf.py @@ -2,6 +2,7 @@ # https://github.com/deepmodeling/dpgen/blob/0767dce7cad29367edb2e4a55fd0d8724dbda642/dpgen/generator/lib/gaussian.py#L1-L190 # under LGPL 3.0 license """Generate Gaussian input file.""" + from __future__ import annotations import itertools diff --git a/dpdata/plugin.py b/dpdata/plugin.py index b725f4eb1..9e18e2122 100644 --- a/dpdata/plugin.py +++ b/dpdata/plugin.py @@ -1,4 +1,5 @@ """Base of plugin systems.""" + from __future__ import annotations diff --git a/dpdata/system.py b/dpdata/system.py index e119f1e4f..94a9005b5 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -104,13 +104,13 @@ class System: def __init__( self, # some formats do not use string as input - file_name: Any=None, - fmt: str="auto", - type_map:list[str] | None=None, - begin:int=0, - step:int=1, - data: dict[str, Any] | None=None, - convergence_check: bool=True, + file_name: Any = None, + fmt: str = "auto", + type_map: list[str] | None = None, + begin: int = 0, + step: int = 1, + data: dict[str, Any] | None = None, + convergence_check: bool = True, **kwargs, ): """Constructor. @@ -231,7 +231,7 @@ def check_data(self): post_funcs = Plugin() - def from_fmt(self, file_name: Any, fmt: str="auto", **kwargs: Any): + def from_fmt(self, file_name: Any, fmt: str = "auto", **kwargs: Any): fmt = fmt.lower() if fmt == "auto": fmt = os.path.basename(file_name).split(".")[-1].lower() @@ -247,7 +247,7 @@ def from_fmt_obj(self, fmtobj: Format, file_name: Any, **kwargs: Any): self.data = {**self.data, **data} self.check_data() if hasattr(fmtobj.from_system, "post_func"): - for post_f in fmtobj.from_system.post_func: # type: ignore + for post_f in fmtobj.from_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self @@ -289,20 +289,19 @@ def __str__(self): return ret @overload - def __getitem__(self, key: int | slice | list | np.ndarray) -> System: - ... + def __getitem__(self, key: int | slice | list | np.ndarray) -> System: ... @overload - def __getitem__(self, key: Literal["atom_names", "real_atom_names"]) -> list[str]: - ... + def __getitem__( + self, key: Literal["atom_names", "real_atom_names"] + ) -> list[str]: ... @overload - def __getitem__(self, key: Literal["atom_numbs"]) -> list[int]: - ... + def __getitem__(self, key: Literal["atom_numbs"]) -> list[int]: ... @overload - def __getitem__(self, key: Literal["nopbc"]) -> bool: - ... + def __getitem__(self, key: Literal["nopbc"]) -> bool: ... @overload - def __getitem__(self, key: Literal["orig", "coords", "energies", "forces", "virials"]) -> np.ndarray: - ... + def __getitem__( + self, key: Literal["orig", "coords", "energies", "forces", "virials"] + ) -> np.ndarray: ... @overload def __getitem__(self, key: str) -> Any: # other cases, for example customized data @@ -333,13 +332,15 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") return self.__class__.from_dict({"data": self_copy.data}) - def dump(self, filename: str, indent: int=4): + def dump(self, filename: str, indent: int = 4): """Dump .json or .yaml file.""" from monty.serialization import dumpfn dumpfn(self.as_dict(), filename, indent=indent) - def map_atom_types(self, type_map: dict[str, int] | list[str] | None=None) -> np.ndarray: + def map_atom_types( + self, type_map: dict[str, int] | list[str] | None = None + ) -> np.ndarray: """Map the atom types of the system. Parameters @@ -456,7 +457,9 @@ def sub_system(self, f_idx: numbers.Integral) -> System: continue if tt.shape is not None and Axis.NFRAMES in tt.shape: axis_nframes = tt.shape.index(Axis.NFRAMES) - new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [ + slice(None) for _ in self.data[tt.name].shape + ] new_shape[axis_nframes] = f_idx tmp.data[tt.name] = self.data[tt.name][tuple(new_shape)] else: @@ -520,7 +523,7 @@ def append(self, system: System) -> bool: self.data["nopbc"] = False return True - def convert_to_mixed_type(self, type_map:list[str] | None=None): + def convert_to_mixed_type(self, type_map: list[str] | None = None): """Convert the data dict to mixed type format structure, in order to append systems with different formula but the same number of atoms. Change the 'atom_names' to one placeholder type 'MIXED_TOKEN' and add 'real_atom_types' to store the real type @@ -546,7 +549,7 @@ def convert_to_mixed_type(self, type_map:list[str] | None=None): self.data["atom_numbs"] = [natoms] self.data["atom_names"] = ["MIXED_TOKEN"] - def sort_atom_names(self, type_map:list[str] | None=None): + def sort_atom_names(self, type_map: list[str] | None = None): """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding to atom_names. If type_map is not given, atom_names will be sorted by alphabetical order. If type_map is given, atom_names will be type_map. @@ -558,7 +561,7 @@ def sort_atom_names(self, type_map:list[str] | None=None): """ self.data = sort_atom_names(self.data, type_map=type_map) - def check_type_map(self, type_map:list[str] | None): + def check_type_map(self, type_map: list[str] | None): """Assign atom_names to type_map if type_map is given and different from atom_names. @@ -600,7 +603,9 @@ def sort_atom_types(self) -> np.ndarray: continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [ + slice(None) for _ in self.data[tt.name].shape + ] new_shape[axis_natoms] = idx self.data[tt.name] = self.data[tt.name][tuple(new_shape)] return idx @@ -686,7 +691,7 @@ def apply_pbc(self): self.data["coords"] = np.matmul(ncoord, self.data["cells"]) @post_funcs.register("remove_pbc") - def remove_pbc(self, protect_layer: int=9): + def remove_pbc(self, protect_layer: int = 9): """This method does NOT delete the definition of the cells, it (1) revises the cell to a cubic cell and ensures that the cell boundary to any atom in the system is no less than `protect_layer` @@ -701,7 +706,7 @@ def remove_pbc(self, protect_layer: int=9): assert protect_layer >= 0, "the protect_layer should be no less than 0" remove_pbc(self.data, protect_layer) - def affine_map(self, trans, f_idx: numbers.Integral=0): + def affine_map(self, trans, f_idx: numbers.Integral = 0): assert np.linalg.det(trans) != 0 self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans) self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans) @@ -719,7 +724,7 @@ def rot_lower_triangular(self): for ii in range(self.get_nframes()): self.rot_frame_lower_triangular(ii) - def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0): + def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0): qq, rr = np.linalg.qr(self.data["cells"][f_idx].T) if np.linalg.det(qq) < 0: qq = -qq @@ -837,7 +842,11 @@ def replace(self, initial_atom_type: str, end_atom_type: str, replace_num: int): self.sort_atom_types() def perturb( - self, pert_num: int, cell_pert_fraction: float, atom_pert_distance: float, atom_pert_style: str="normal" + self, + pert_num: int, + cell_pert_fraction: float, + atom_pert_distance: float, + atom_pert_style: str = "normal", ): """Perturb each frame in the system randomly. The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction. @@ -914,7 +923,9 @@ def shuffle(self): self.data = self.sub_system(idx).data return idx - def predict(self, *args: Any, driver: str | Driver = "dp", **kwargs: Any) -> LabeledSystem: + def predict( + self, *args: Any, driver: str | Driver = "dp", **kwargs: Any + ) -> LabeledSystem: """Predict energies and forces by a driver. Parameters @@ -966,7 +977,7 @@ def minimize( data = minimizer.minimize(self.data.copy()) return LabeledSystem(data=data) - def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None): + def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): """Pick atom index. Parameters @@ -990,7 +1001,9 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None): continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape: list[slice | np.ndarray] = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [ + slice(None) for _ in self.data[tt.name].shape + ] new_shape[axis_natoms] = idx new_sys.data[tt.name] = self.data[tt.name][tuple(new_shape)] # recalculate atom_numbs according to atom_types @@ -1028,7 +1041,13 @@ def remove_atom_names(self, atom_names: str | Iterable[str]): new_sys.data["atom_numbs"] = new_sys.data["atom_numbs"][: len(new_atom_names)] return new_sys - def pick_by_amber_mask(self, param: str | parmed.Structure, maskstr: str, pass_coords: bool=False, nopbc: bool | None=None): + def pick_by_amber_mask( + self, + param: str | parmed.Structure, + maskstr: str, + pass_coords: bool = False, + nopbc: bool | None = None, + ): """Pick atoms by amber mask. Parameters @@ -1093,7 +1112,10 @@ def get_cell_perturb_matrix(cell_pert_fraction: float): return cell_pert_matrix -def get_atom_perturb_vector(atom_pert_distance: float, atom_pert_style: Literal["normal", "uniform", "const"]="normal"): +def get_atom_perturb_vector( + atom_pert_distance: float, + atom_pert_style: Literal["normal", "uniform", "const"] = "normal", +): random_vector = None if atom_pert_distance < 0: raise RuntimeError("atom_pert_distance can not be negative") @@ -1230,7 +1252,7 @@ def affine_map_fv(self, trans, f_idx: numbers.Integral): trans.T, np.matmul(self.data["virials"][f_idx], trans) ) - def rot_frame_lower_triangular(self, f_idx: numbers.Integral=0): + def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0): trans = System.rot_frame_lower_triangular(self, f_idx=f_idx) self.affine_map_fv(trans, f_idx=f_idx) return trans @@ -1322,7 +1344,9 @@ def __init__(self, *systems, type_map=None): self.atom_names: list[str] = [] self.append(*systems) - def from_fmt_obj(self, fmtobj: Format, directory, labeled:bool=True, **kwargs: Any): + def from_fmt_obj( + self, fmtobj: Format, directory, labeled: bool = True, **kwargs: Any + ): if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat): for dd in fmtobj.from_multi_systems(directory, **kwargs): if labeled: @@ -1415,7 +1439,13 @@ def from_file(cls, file_name, fmt: str, **kwargs: Any): return multi_systems @classmethod - def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: list[str] | None=None): + def from_dir( + cls, + dir_name: str, + file_name: str, + fmt: str = "auto", + type_map: list[str] | None = None, + ): multi_systems = cls() target_file_list = sorted( glob.glob(f"./{dir_name}/**/{file_name}", recursive=True) @@ -1426,7 +1456,7 @@ def from_dir(cls, dir_name: str, file_name: str, fmt: str="auto", type_map: list ) return multi_systems - def load_systems_from_file(self, file_name=None, fmt: str | None=None, **kwargs): + def load_systems_from_file(self, file_name=None, fmt: str | None = None, **kwargs): assert fmt is not None fmt = fmt.lower() return self.from_fmt_obj(load_format(fmt), file_name, **kwargs) @@ -1485,7 +1515,9 @@ def check_atom_names(self, system: System): system.add_atom_names(new_in_self) system.sort_atom_names(type_map=self.atom_names) - def predict(self, *args: Any, driver: str | Driver="dp", **kwargs: Any) -> MultiSystems: + def predict( + self, *args: Any, driver: str | Driver = "dp", **kwargs: Any + ) -> MultiSystems: """Predict energies and forces by a driver. Parameters @@ -1544,7 +1576,7 @@ def minimize( new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) return new_multisystems - def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None=None): + def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): """Pick atom index. Parameters @@ -1599,7 +1631,7 @@ def correction(self, hl_sys: MultiSystems) -> MultiSystems: for nn in self.systems.keys(): ll_ss = self[nn] hl_ss = hl_sys[nn] - assert isinstance(ll_ss, LabeledSystem) + assert isinstance(ll_ss, LabeledSystem) corrected_sys.append(ll_ss.correction(hl_ss)) return corrected_sys diff --git a/dpdata/utils.py b/dpdata/utils.py index 7acfb9513..030256dd6 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -8,17 +8,22 @@ @overload -def elements_index_map(elements: list[str], standard: bool, inverse: Literal[True]) -> dict[int, str]: - ... +def elements_index_map( + elements: list[str], standard: bool, inverse: Literal[True] +) -> dict[int, str]: ... @overload -def elements_index_map(elements: list[str], standard: bool, inverse: Literal[False]=...) -> dict[str, int]: - ... +def elements_index_map( + elements: list[str], standard: bool, inverse: Literal[False] = ... +) -> dict[str, int]: ... @overload -def elements_index_map(elements: list[str], standard: bool, inverse: bool=False) -> dict[str, int] | dict[int, str]: - ... +def elements_index_map( + elements: list[str], standard: bool, inverse: bool = False +) -> dict[str, int] | dict[int, str]: ... -def elements_index_map(elements: list[str], standard: bool=False, inverse: bool=False) -> dict: +def elements_index_map( + elements: list[str], standard: bool = False, inverse: bool = False +) -> dict: if standard: elements.sort(key=lambda x: Element(x).Z) if inverse: From 6e7462ffac196b09ae625a2ec2234f6ab71ecd1c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:49:11 -0400 Subject: [PATCH 08/12] Literal Signed-off-by: Jinzhe Zeng --- docs/make_format.py | 7 ++++++- dpdata/system.py | 7 +++---- dpdata/utils.py | 7 ++++++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/make_format.py b/docs/make_format.py index e9c1f60d3..2b3c03c67 100644 --- a/docs/make_format.py +++ b/docs/make_format.py @@ -2,9 +2,14 @@ import csv import os +import sys from collections import defaultdict from inspect import Parameter, Signature, cleandoc, signature -from typing import Literal + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal from numpydoc.docscrape import Parameter as numpydoc_Parameter from numpydoc.docscrape_sphinx import SphinxDocString diff --git a/dpdata/system.py b/dpdata/system.py index 94a9005b5..de3a3d0cf 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -12,14 +12,13 @@ TYPE_CHECKING, Any, Iterable, - Literal, overload, ) -if sys.version_info < (3, 8): - pass +if sys.version_info >= (3, 8): + from typing import Literal else: - pass + from typing_extensions import Literal import numpy as np diff --git a/dpdata/utils.py b/dpdata/utils.py index 030256dd6..e008120ea 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -1,7 +1,12 @@ from __future__ import annotations -from typing import Literal, overload +import sys +from typing import overload +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal import numpy as np from dpdata.periodic_table import Element From eed57512647acf4de54fe291830a1075e589b122 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 20:50:11 -0400 Subject: [PATCH 09/12] install all dependencies Signed-off-by: Jinzhe Zeng --- .github/workflows/pyright.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml index 45195bd6b..73dc81f15 100644 --- a/.github/workflows/pyright.yml +++ b/.github/workflows/pyright.yml @@ -13,7 +13,7 @@ jobs: with: python-version: '3.12' - run: pip install uv - - run: uv pip install --system -e . + - run: uv pip install --system -e .[amber,ase,pymatgen] rdkit openbabel-wheel - uses: jakebailey/pyright-action@v2 with: version: 1.1.363 From 77d798c89873595351f79ac614ad2022968f154f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 21:06:21 -0400 Subject: [PATCH 10/12] fix rest errors Signed-off-by: Jinzhe Zeng --- dpdata/system.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index de3a3d0cf..2614bc23b 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -433,7 +433,7 @@ def copy(self): """Returns a copy of the system.""" return self.__class__.from_dict({"data": deepcopy(self.data)}) - def sub_system(self, f_idx: numbers.Integral) -> System: + def sub_system(self, f_idx: int | slice | list | np.ndarray): """Construct a subsystem from the system. Parameters @@ -450,13 +450,14 @@ def sub_system(self, f_idx: numbers.Integral) -> System: # convert int to array_like if isinstance(f_idx, numbers.Integral): f_idx = np.array([f_idx]) + assert not isinstance(f_idx, int) for tt in self.DTYPES: if tt.name not in self.data: # skip optional data continue if tt.shape is not None and Axis.NFRAMES in tt.shape: axis_nframes = tt.shape.index(Axis.NFRAMES) - new_shape: list[slice | np.ndarray] = [ + new_shape: list[slice | np.ndarray | list] = [ slice(None) for _ in self.data[tt.name].shape ] new_shape[axis_nframes] = f_idx @@ -705,7 +706,7 @@ def remove_pbc(self, protect_layer: int = 9): assert protect_layer >= 0, "the protect_layer should be no less than 0" remove_pbc(self.data, protect_layer) - def affine_map(self, trans, f_idx: numbers.Integral = 0): + def affine_map(self, trans, f_idx: int | numbers.Integral = 0): assert np.linalg.det(trans) != 0 self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans) self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans) @@ -723,7 +724,7 @@ def rot_lower_triangular(self): for ii in range(self.get_nframes()): self.rot_frame_lower_triangular(ii) - def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0): + def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0): qq, rr = np.linalg.qr(self.data["cells"][f_idx].T) if np.linalg.det(qq) < 0: qq = -qq @@ -776,7 +777,7 @@ def replicate(self, ncopy: list[int] | tuple[int, int, int]): np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy) ) tmp.data["atom_types"] = np.sort( - np.tile(np.copy(data["atom_types"]), np.prod(ncopy)), kind="stable" + np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable" ) tmp.data["cells"] = np.copy(data["cells"]) for ii in range(3): @@ -976,7 +977,11 @@ def minimize( data = minimizer.minimize(self.data.copy()) return LabeledSystem(data=data) - def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): + def pick_atom_idx( + self, + idx: int | numbers.Integral | list[int] | slice | np.ndarray, + nopbc: bool | None = None, + ): """Pick atom index. Parameters @@ -994,13 +999,14 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): new_sys = self.copy() if isinstance(idx, numbers.Integral): idx = np.array([idx]) + assert not isinstance(idx, int) for tt in self.DTYPES: if tt.name not in self.data: # skip optional data continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape: list[slice | np.ndarray] = [ + new_shape: list[slice | np.ndarray | list[int]] = [ slice(None) for _ in self.data[tt.name].shape ] new_shape[axis_natoms] = idx @@ -1014,7 +1020,7 @@ def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): new_sys.nopbc = nopbc return new_sys - def remove_atom_names(self, atom_names: str | Iterable[str]): + def remove_atom_names(self, atom_names: str | list[str]): """Remove atom names and all such atoms. For example, you may not remove EP atoms in TIP4P/Ew water, which is not a real atom. @@ -1113,7 +1119,7 @@ def get_cell_perturb_matrix(cell_pert_fraction: float): def get_atom_perturb_vector( atom_pert_distance: float, - atom_pert_style: Literal["normal", "uniform", "const"] = "normal", + atom_pert_style: str = "normal", ): random_vector = None if atom_pert_distance < 0: @@ -1243,7 +1249,7 @@ def has_virial(self) -> bool: # return ('virials' in self.data) and (len(self.data['virials']) > 0) return "virials" in self.data - def affine_map_fv(self, trans, f_idx: numbers.Integral): + def affine_map_fv(self, trans, f_idx: int | numbers.Integral): assert np.linalg.det(trans) != 0 self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans) if self.has_virial(): @@ -1251,7 +1257,7 @@ def affine_map_fv(self, trans, f_idx: numbers.Integral): trans.T, np.matmul(self.data["virials"][f_idx], trans) ) - def rot_frame_lower_triangular(self, f_idx: numbers.Integral = 0): + def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0): trans = System.rot_frame_lower_triangular(self, f_idx=f_idx) self.affine_map_fv(trans, f_idx=f_idx) return trans @@ -1575,7 +1581,11 @@ def minimize( new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) return new_multisystems - def pick_atom_idx(self, idx: numbers.Integral, nopbc: bool | None = None): + def pick_atom_idx( + self, + idx: int | numbers.Integral | list[int] | slice | np.ndarray, + nopbc: bool | None = None, + ): """Pick atom index. Parameters @@ -1631,6 +1641,7 @@ def correction(self, hl_sys: MultiSystems) -> MultiSystems: ll_ss = self[nn] hl_ss = hl_sys[nn] assert isinstance(ll_ss, LabeledSystem) + assert isinstance(hl_ss, LabeledSystem) corrected_sys.append(ll_ss.correction(hl_ss)) return corrected_sys From 8b0700a9f7f43c68d2447fac7becb430fa96f5c9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 21:06:45 -0400 Subject: [PATCH 11/12] add py.typed file Signed-off-by: Jinzhe Zeng --- dpdata/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 dpdata/py.typed diff --git a/dpdata/py.typed b/dpdata/py.typed new file mode 100644 index 000000000..e69de29bb From 24b93d503288b637b89172d7dca0f88de2814b55 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 16 May 2024 23:08:10 -0400 Subject: [PATCH 12/12] remove **kwargs: dict, which is incorrect --- dpdata/plugins/amber.py | 2 +- dpdata/plugins/gaussian.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dpdata/plugins/amber.py b/dpdata/plugins/amber.py index d991ce482..42fce5528 100644 --- a/dpdata/plugins/amber.py +++ b/dpdata/plugins/amber.py @@ -126,7 +126,7 @@ class SQMDriver(Driver): -15.41111246 """ - def __init__(self, sqm_exec: str = "sqm", **kwargs: dict) -> None: + def __init__(self, sqm_exec: str = "sqm", **kwargs) -> None: self.sqm_exec = sqm_exec self.kwargs = kwargs diff --git a/dpdata/plugins/gaussian.py b/dpdata/plugins/gaussian.py index 80cfa4076..b55447b91 100644 --- a/dpdata/plugins/gaussian.py +++ b/dpdata/plugins/gaussian.py @@ -83,7 +83,7 @@ class GaussianDriver(Driver): -1102.714590995794 """ - def __init__(self, gaussian_exec: str = "g16", **kwargs: dict) -> None: + def __init__(self, gaussian_exec: str = "g16", **kwargs) -> None: self.gaussian_exec = gaussian_exec self.kwargs = kwargs