Skip to content

Commit

Permalink
Refactor typing (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Dec 22, 2023
1 parent 53805ab commit 8a2962d
Show file tree
Hide file tree
Showing 21 changed files with 3,657 additions and 3,486 deletions.
14 changes: 12 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see <https://www.gnu.org/licenses/>.
ci:
skip: [mypy]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down Expand Up @@ -48,14 +51,21 @@ repos:
args: [--py37-plus, --keep-runtime-typing]

- repo: https://github.com/pycqa/isort
rev: 5.13.0
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 23.11.0
rev: 23.12.0
hooks:
- id: black
stages: [commit]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [types-all]
exclude: 'test/conftest.py'
1 change: 1 addition & 0 deletions docs/modules/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ The following modules are contained with `tad_multicharge`.
defaults
eeq
model
typing/index
2 changes: 2 additions & 0 deletions docs/modules/typing/builtin.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. automodule:: tad_multicharge.typing.builtin
:members:
8 changes: 8 additions & 0 deletions docs/modules/typing/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _typing:

.. automodule:: tad_multicharge.typing

.. toctree::

builtin
pytorch
2 changes: 2 additions & 0 deletions docs/modules/typing/pytorch.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. automodule:: tad_multicharge.typing.pytorch
:members:
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ disallow_untyped_defs = true
warn_redundant_casts = true
warn_unreachable = true
warn_unused_ignores = true
exclude = '''
(?x)
^test?s/conftest.py$
'''


[tool.coverage.run]
Expand Down
10 changes: 5 additions & 5 deletions src/tad_multicharge/eeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@
from tad_mctc import storch
from tad_mctc.batch import real_atoms, real_pairs
from tad_mctc.ncoord import cn_eeq, erf_count
from tad_mctc.typing import DD, Any, CountingFunction, Tensor

from . import defaults
from .model import ChargeModel
from .param import eeq2019
from .typing import DD, Any, CountingFunction, Tensor, get_default_dtype

__all__ = ["EEQModel", "solve", "get_charges"]

Expand Down Expand Up @@ -93,10 +93,10 @@ def param2019(
EEQModel
Instance of the EEQ charge model class.
"""

dd: dict = {"device": device}
if dtype is not None:
dd["dtype"] = dtype
dd: DD = {
"device": device,
"dtype": dtype if dtype is not None else get_default_dtype(),
}

return cls(
eeq2019.chi.to(**dd),
Expand Down
3 changes: 2 additions & 1 deletion src/tad_multicharge/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from __future__ import annotations

import torch
from tad_mctc.typing import Tensor, TensorLike

from .typing import Tensor, TensorLike

__all__ = ["ChargeModel"]

Expand Down
25 changes: 25 additions & 0 deletions src/tad_multicharge/typing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# This file is part of tad-multicharge.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2023 Marvin Friede
#
# tad-multicharge is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-multicharge is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see <https://www.gnu.org/licenses/>.
"""
Type annotations
================
All type annotations for this project.
"""
from .builtin import *
from .pytorch import *
27 changes: 27 additions & 0 deletions src/tad_multicharge/typing/builtin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# This file is part of tad-multicharge.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2023 Marvin Friede
#
# tad-multicharge is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-multicharge is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see <https://www.gnu.org/licenses/>.
"""
Type annotations: Built-ins
===========================
Built-in type annotations are imported from the *tad-mctc* library, which
handles some version checking.
"""
from tad_mctc.typing import Any, Callable, TypedDict

__all__ = ["Any", "Callable", "TypedDict"]
42 changes: 42 additions & 0 deletions src/tad_multicharge/typing/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# This file is part of tad-multicharge.
#
# SPDX-Identifier: LGPL-3.0
# Copyright (C) 2023 Marvin Friede
#
# tad-multicharge is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tad-multicharge is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with tad-multicharge. If not, see <https://www.gnu.org/licenses/>.
"""
Type annotations: PyTorch
=========================
PyTorch-related type annotations for this project.
"""
from tad_mctc.typing import (
DD,
CountingFunction,
Molecule,
Tensor,
TensorLike,
get_default_device,
get_default_dtype,
)

__all__ = [
"DD",
"CountingFunction",
"Molecule",
"Tensor",
"TensorLike",
"get_default_device",
"get_default_dtype",
]
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def pytest_configure(config: pytest.Config) -> None:
if torch.__version__ < (2, 0, 0): # type: ignore
torch.set_default_tensor_type("torch.cuda.FloatTensor") # type: ignore
else:
torch.set_default_device(DEVICE) # type: ignore
torch.set_default_device(DEVICE) # type: ignore[attr-defined]
else:
torch.use_deterministic_algorithms(True)
DEVICE = None
Expand Down
Loading

0 comments on commit 8a2962d

Please sign in to comment.