Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve type annotations #659

Merged
merged 12 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
on:
- push
- pull_request

name: Type checker
jobs:
pyright:
name: pyright
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pip install uv
- run: uv pip install --system -e .[amber,ase,pymatgen] rdkit openbabel-wheel
- uses: jakebailey/pyright-action@v2
with:
version: 1.1.363
2 changes: 2 additions & 0 deletions benchmark/test_import.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import subprocess
import sys

Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
from __future__ import annotations

import os
import subprocess as sp
import sys
Expand Down
9 changes: 8 additions & 1 deletion docs/make_format.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import csv
import os
import sys
from collections import defaultdict
from inspect import Parameter, Signature, cleandoc, signature
from typing import Literal

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

from numpydoc.docscrape import Parameter as numpydoc_Parameter
from numpydoc.docscrape_sphinx import SphinxDocString
Expand Down
2 changes: 2 additions & 0 deletions docs/nb/try_dpdata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import dpdata"
]
},
Expand Down
2 changes: 2 additions & 0 deletions dpdata/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from __future__ import annotations

Check warning on line 1 in dpdata/__about__.py

View check run for this annotation

Codecov / codecov/patch

dpdata/__about__.py#L1

Added line #L1 was not covered by tests

__version__ = "unknown"
2 changes: 2 additions & 0 deletions dpdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from . import lammps, md, vasp
from .bond_order_system import BondOrderSystem
from .system import LabeledSystem, MultiSystems, System
Expand Down
2 changes: 2 additions & 0 deletions dpdata/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

Check warning on line 1 in dpdata/__main__.py

View check run for this annotation

Codecov / codecov/patch

dpdata/__main__.py#L1

Added line #L1 was not covered by tests

from dpdata.cli import dpdata_cli

if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions dpdata/abacus/md.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import warnings

Expand Down
2 changes: 2 additions & 0 deletions dpdata/abacus/relax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os

import numpy as np
Expand Down
2 changes: 2 additions & 0 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import re
import warnings
Expand Down
2 changes: 2 additions & 0 deletions dpdata/amber/mask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Amber mask."""

from __future__ import annotations

try:
import parmed
except ImportError:
Expand Down
2 changes: 2 additions & 0 deletions dpdata/amber/md.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import re

Expand Down
2 changes: 2 additions & 0 deletions dpdata/amber/sqm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np

from dpdata.periodic_table import ELEMENTS
Expand Down
21 changes: 13 additions & 8 deletions dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from __future__ import annotations

from typing import TYPE_CHECKING

from ase.calculators.calculator import ( # noqa: TID253
Calculator,
Expand All @@ -23,7 +25,10 @@ class DPDataCalculator(Calculator):
dpdata driver
"""

name = "dpdata"
@property
def name(self) -> str:
return "dpdata"

implemented_properties = ["energy", "free_energy", "forces", "virial", "stress"]

def __init__(self, driver: Driver, **kwargs) -> None:
Expand All @@ -32,9 +37,9 @@ def __init__(self, driver: Driver, **kwargs) -> None:

def calculate(
self,
atoms: Optional["Atoms"] = None,
properties: List[str] = ["energy", "forces"],
system_changes: List[str] = all_changes,
atoms: Atoms | None = None,
properties: list[str] = ["energy", "forces"],
system_changes: list[str] = all_changes,
):
"""Run calculation with a driver.

Expand All @@ -48,10 +53,10 @@ def calculate(
system_changes : List[str], optional
unused, only for function signature compatibility, by default all_changes
"""
if atoms is not None:
self.atoms = atoms.copy()
assert atoms is not None
atoms = atoms.copy()

system = dpdata.System(self.atoms, fmt="ase/structure")
system = dpdata.System(atoms, fmt="ase/structure")
data = system.predict(driver=self.driver).data

self.results["energy"] = data["energies"][0]
Expand Down
9 changes: 6 additions & 3 deletions dpdata/bond_order_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# %%
# Bond Order System
from __future__ import annotations

from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -96,13 +98,14 @@
mol = fmtobj.from_bond_order_system(file_name, **kwargs)
self.from_rdkit_mol(mol)
if hasattr(fmtobj.from_bond_order_system, "post_func"):
for post_f in fmtobj.from_bond_order_system.post_func:
for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore

Check warning on line 101 in dpdata/bond_order_system.py

View check run for this annotation

Codecov / codecov/patch

dpdata/bond_order_system.py#L101

Added line #L101 was not covered by tests
self.post_funcs.get_plugin(post_f)(self)
return self

def to_fmt_obj(self, fmtobj, *args, **kwargs):
from rdkit.Chem import Conformer

assert self.rdkit_mol is not None
coderabbitai[bot] marked this conversation as resolved.
Show resolved Hide resolved
self.rdkit_mol.RemoveAllConformers()
for ii in range(self.get_nframes()):
conf = Conformer()
Expand Down Expand Up @@ -145,9 +148,9 @@
"""Return the formal charges on each atom."""
return self.data["formal_charges"]

def copy(self):
def copy(self): # type: ignore
new_mol = deepcopy(self.rdkit_mol)
self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol)
return self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol)

Check warning on line 153 in dpdata/bond_order_system.py

View check run for this annotation

Codecov / codecov/patch

dpdata/bond_order_system.py#L153

Added line #L153 was not covered by tests

def __add__(self, other):
raise NotImplementedError(
Expand Down
9 changes: 5 additions & 4 deletions dpdata/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Command line interface for dpdata."""

from __future__ import annotations

Check warning on line 3 in dpdata/cli.py

View check run for this annotation

Codecov / codecov/patch

dpdata/cli.py#L3

Added line #L3 was not covered by tests

import argparse
from typing import Optional

from . import __version__
from .system import LabeledSystem, MultiSystems, System
Expand Down Expand Up @@ -59,11 +60,11 @@
*,
from_file: str,
from_format: str = "auto",
to_file: Optional[str] = None,
to_format: Optional[str] = None,
to_file: str | None = None,
to_format: str | None = None,
no_labeled: bool = False,
multi: bool = False,
type_map: Optional[list] = None,
type_map: list | None = None,
**kwargs,
):
"""Convert files from one format to another one.
Expand Down
1 change: 1 addition & 0 deletions dpdata/cp2k/cell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# %%
from __future__ import annotations

import numpy as np

Expand Down
2 changes: 2 additions & 0 deletions dpdata/cp2k/output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# %%
from __future__ import annotations

import math
import re
from collections import OrderedDict
Expand Down
13 changes: 8 additions & 5 deletions dpdata/data_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from enum import Enum, unique
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -50,16 +52,17 @@ def __init__(
self,
name: str,
dtype: type,
shape: Tuple[int, Axis] = None,
shape: tuple[int | Axis, ...] | None = None,
required: bool = True,
) -> None:
self.name = name
self.dtype = dtype
self.shape = shape
self.required = required

def real_shape(self, system: "System") -> Tuple[int]:
def real_shape(self, system: System) -> tuple[int]:
"""Returns expected real shape of a system."""
assert self.shape is not None
shape = []
for ii in self.shape:
if ii is Axis.NFRAMES:
Expand All @@ -70,7 +73,7 @@ def real_shape(self, system: "System") -> Tuple[int]:
shape.append(system.get_natoms())
elif ii is Axis.NBONDS:
# BondOrderSystem
shape.append(system.get_nbonds())
shape.append(system.get_nbonds()) # type: ignore
elif ii == -1:
shape.append(AnyInt(-1))
elif isinstance(ii, int):
Expand All @@ -79,7 +82,7 @@ def real_shape(self, system: "System") -> Tuple[int]:
raise RuntimeError("Shape is not an int!")
return tuple(shape)

def check(self, system: "System"):
def check(self, system: System):
"""Check if a system has correct data of this type.

Parameters
Expand Down
2 changes: 2 additions & 0 deletions dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import glob
import os
import shutil
Expand Down
2 changes: 2 additions & 0 deletions dpdata/deepmd/mixed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import glob
import os
import shutil
Expand Down
2 changes: 2 additions & 0 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import warnings

Expand Down
4 changes: 2 additions & 2 deletions dpdata/dftbplus/output.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Tuple
from __future__ import annotations

import numpy as np


def read_dftb_plus(fn_1: str, fn_2: str) -> Tuple[str, np.ndarray, float, np.ndarray]:
def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]:
"""Read from DFTB+ input and output.

Parameters
Expand Down
15 changes: 9 additions & 6 deletions dpdata/driver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Driver plugin system."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, List, Union
from typing import TYPE_CHECKING, Callable

from .plugin import Plugin

if TYPE_CHECKING:
import ase
import ase.calculators.calculator

Check warning on line 11 in dpdata/driver.py

View check run for this annotation

Codecov / codecov/patch

dpdata/driver.py#L11

Added line #L11 was not covered by tests


class Driver(ABC):
Expand Down Expand Up @@ -43,7 +45,7 @@
return Driver.__DriverPlugin.register(key)

@staticmethod
def get_driver(key: str) -> "Driver":
def get_driver(key: str) -> type[Driver]:
"""Get a driver plugin.

Parameters
Expand Down Expand Up @@ -97,7 +99,7 @@
return NotImplemented

@property
def ase_calculator(self) -> "ase.calculators.calculator.Calculator":
def ase_calculator(self) -> ase.calculators.calculator.Calculator:
"""Returns an ase calculator based on this driver."""
from .ase_calculator import DPDataCalculator

Expand Down Expand Up @@ -130,7 +132,7 @@
This driver is the hybrid of SQM and DP.
"""

def __init__(self, drivers: List[Union[dict, Driver]]) -> None:
def __init__(self, drivers: list[dict | Driver]) -> None:
self.drivers = []
for driver in drivers:
if isinstance(driver, Driver):
Expand All @@ -157,6 +159,7 @@
dict
labeled data with energies and forces
"""
labeled_data = {}
for ii, driver in enumerate(self.drivers):
lb_data = driver.label(data.copy())
if ii == 0:
Expand Down Expand Up @@ -199,7 +202,7 @@
return Minimizer.__MinimizerPlugin.register(key)

@staticmethod
def get_minimizer(key: str) -> "Minimizer":
def get_minimizer(key: str) -> type[Minimizer]:
"""Get a minimizer plugin.

Parameters
Expand Down
2 changes: 2 additions & 0 deletions dpdata/fhi_aims/output.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import re
import warnings

Expand Down
4 changes: 3 additions & 1 deletion dpdata/format.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Implement the format plugin system."""

from __future__ import annotations

import os
from abc import ABC

Expand Down Expand Up @@ -163,7 +165,7 @@
if not isinstance(func_name, (list, tuple, set)):
object.post_func = (func_name,)
else:
object.post_func = func_name
object.post_func = tuple(func_name)

Check warning on line 168 in dpdata/format.py

View check run for this annotation

Codecov / codecov/patch

dpdata/format.py#L168

Added line #L168 was not covered by tests
return object

return decorator
Expand Down
Loading