Fix NaN's in ATM term (#27)
- rename typing to _typing (to avoid shadowed imports)
- fix triples mask (now masks out everything with at least two identical indices)
- add more masks to avoid NaNs in single precision (caused by exponentiation of already large numbers)
marvinfriede authored Jul 25, 2023
1 parent 03270b5 commit f9c585e
Showing 24 changed files with 490 additions and 55 deletions.
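The NaNs this commit fixes originate from padded (masked-out) triples: their distances are set to a tiny `eps` placeholder, and dividing a critical radius by `eps` yields a number so large that raising it to the ATM exponent overflows single precision. A minimal, self-contained sketch of that failure mode and of the masking that avoids it (illustrative values only, not code from this commit):

```python
import torch

eps = torch.finfo(torch.float32).eps   # ~1.19e-07, placeholder distance of a padded triple
r0 = torch.tensor([10.0, 10.0])        # critical radii
r1 = torch.tensor([3.0, eps])          # second entry belongs to a padded triple
alp = 14.0

base = r0 / r1                         # tensor([3.3333e+00, 8.3886e+07])
print(base ** ((alp + 2.0) / 3.0))     # second entry overflows to inf in float32

# masking the denominator before the exponentiation keeps everything finite
mask = torch.tensor([True, False])
safe = r0 / torch.where(mask, r1, torch.ones_like(r1))
print(safe ** ((alp + 2.0) / 3.0))     # finite for both entries
```

Overflow is not only a forward-pass problem: once `inf` enters one branch of a `torch.where`, the backward pass can still propagate NaNs through the discarded branch, which is why the changes below mask both before and after the exponentiation.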
9 changes: 8 additions & 1 deletion .pre-commit-config.yaml
@@ -15,7 +15,14 @@ repos:
rev: v2.4.0
hooks:
- id: setup-cfg-fmt
args: [--include-version-classifiers, --max-py-version, "3.11"]
args:
[
--include-version-classifiers,
--min-py-version,
"3.8",
--max-py-version,
"3.11",
]

- repo: https://github.com/asottile/pyupgrade
rev: v3.9.0
2 changes: 1 addition & 1 deletion src/tad_dftd3/__init__.py
@@ -75,7 +75,7 @@
import torch

from . import damping, data, disp, model, ncoord, reference
from .typing import (
from ._typing import (
DD,
CountingFunction,
DampingFunction,
1 change: 1 addition & 0 deletions src/tad_dftd3/typing.py → src/tad_dftd3/_typing.py
@@ -24,6 +24,7 @@
List,
NoReturn,
Optional,
Protocol,
Tuple,
TypedDict,
Union,
26 changes: 22 additions & 4 deletions src/tad_dftd3/damping/atm.py
@@ -33,7 +33,7 @@
import torch

from .. import defaults
from ..typing import DD, Tensor
from .._typing import DD, Tensor
from ..util import cdist, real_pairs, real_triples


@@ -84,9 +84,11 @@ def dispersion_atm(
srvdw = rs9 * rvdw

mask_pairs = real_pairs(numbers, diagonal=False)
mask_triples = real_triples(numbers, diagonal=False)
mask_triples = real_triples(numbers, self=False)

eps = torch.tensor(torch.finfo(positions.dtype).eps, **dd)
zero = torch.tensor(0.0, **dd)
one = torch.tensor(1.0, **dd)

# C9_ABC = s9 * sqrt(|C6_AB * C6_AC * C6_BC|)
c9 = s9 * torch.sqrt(
@@ -120,9 +122,25 @@ def dispersion_atm(
r3 = torch.where(mask_triples, r1 * r2, eps)
r5 = torch.where(mask_triples, r2 * r3, eps)

fdamp = 1.0 / (1.0 + 6.0 * (r0 / r1) ** ((alp + 2.0) / 3.0))
# dividing by tiny numbers leads to huge numbers, which result in NaN's
# upon exponentiation in the subsequent step
mask = real_triples(numbers, self=False)
base = r0 / torch.where(mask_triples, r1, one)

# to fix the previous mask, we mask again (not strictly necessary because
# `ang` is also masked and we later multiply with `ang`)
fdamp = torch.where(
mask_triples,
1.0 / (1.0 + 6.0 * base ** ((alp + 2.0) / 3.0)),
zero,
)

s = torch.where(
mask,
(r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik),
zero,
)

s = (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik)
ang = torch.where(
mask_triples * (r2ij <= cutoff2) * (r2jk <= cutoff2) * (r2jk <= cutoff2),
0.375 * s / r5 + 1.0 / r3,
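The new masking in `dispersion_atm`, condensed into a standalone helper for illustration (a sketch under the assumption of an externally supplied boolean triples mask; `damped_atm_factor` is not part of the library):

```python
import torch

def damped_atm_factor(
    r0: torch.Tensor, r1: torch.Tensor, mask: torch.Tensor, alp: float = 14.0
) -> torch.Tensor:
    """Hypothetical helper mirroring the double-masking pattern above."""
    one = torch.tensor(1.0, dtype=r1.dtype, device=r1.device)
    zero = torch.tensor(0.0, dtype=r1.dtype, device=r1.device)

    # mask the denominator first: padded triples would otherwise divide by eps
    base = r0 / torch.where(mask, r1, one)

    # mask the result as well so padded triples contribute exactly zero
    return torch.where(mask, 1.0 / (1.0 + 6.0 * base ** ((alp + 2.0) / 3.0)), zero)
```

Since both branches of a `torch.where` are evaluated, masking only the output is not enough once the unmasked branch has already overflowed; masking `r1` before the division is what keeps the exponentiation itself finite.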
2 changes: 1 addition & 1 deletion src/tad_dftd3/damping/rational.py
@@ -28,7 +28,7 @@
import torch

from .. import defaults
from ..typing import DD, Dict, Tensor
from .._typing import DD, Dict, Tensor


def rational_damping(
4 changes: 2 additions & 2 deletions src/tad_dftd3/data.py
@@ -21,7 +21,7 @@
import torch

from . import constants
from .typing import Tensor
from ._typing import Tensor

# fmt: off
covalent_rad_2009 = constants.ANGSTROM_TO_BOHR * torch.tensor([
@@ -4569,7 +4569,7 @@
) # fmt: off


def _load_vdw_rad_d3(dtype: torch.dtype = torch.float) -> Tensor:
def _load_vdw_rad_d3(dtype: torch.dtype = torch.double) -> Tensor:
# pylint: disable=import-outside-toplevel
import math

2 changes: 1 addition & 1 deletion src/tad_dftd3/disp.py
@@ -55,8 +55,8 @@
import torch

from . import data, defaults
from ._typing import DD, Any, DampingFunction, Dict, Optional, Tensor
from .damping import dispersion_atm, rational_damping
from .typing import DD, Any, DampingFunction, Dict, Optional, Tensor
from .util import cdist, real_pairs


20 changes: 8 additions & 12 deletions src/tad_dftd3/model.py
@@ -41,16 +41,12 @@
"""
import torch

from ._typing import Any, Tensor, WeightingFunction
from .reference import Reference
from .typing import Any, Tensor, WeightingFunction
from .util import real_atoms


def atomic_c6(
numbers: Tensor,
weights: Tensor,
reference: Reference,
) -> Tensor:
def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor:
"""
Calculate atomic dispersion coefficients.
@@ -132,11 +128,11 @@ def weight_references(
# exactly one. This might, however, not be the case and ultimately cause
# larger deviations in the final values.
#
# If the values become even smaller, we may have to evaluate this portion
# in double precision to retain the correct results. This must be done in
# the D4 variant because the weighting functions contains higher powers,
# which lead to values down to 1e-300.
dcn = reference.cn[numbers] - cn.unsqueeze(-1)
# This must be done in the D4 variant because the weighting functions
# contains higher powers, which lead to values down to 1e-300.
# Since there are also cases in D3, we have to evaluate this portion
# in double precision to retain the correct results and avoid nan's.
dcn = (reference.cn[numbers] - cn.unsqueeze(-1)).type(torch.double)
weights = torch.where(
mask,
weighting_function(dcn, **kwargs),
@@ -159,4 +155,4 @@ def weight_references(
torch.sum(weights, dim=-1),
torch.tensor(torch.finfo(dcn.dtype).eps, device=cn.device, dtype=dcn.dtype),
)
return weights / norms.unsqueeze(-1)
return (weights / norms.unsqueeze(-1)).type(cn.dtype)
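A condensed sketch of the mixed-precision pattern introduced above (the Gaussian weighting form and the helper name are assumptions; in the library the weighting function is passed in as `weighting_function`): promote the CN differences to float64, evaluate and normalize the weights there, then cast back to the working dtype.

```python
import torch

def normalized_weights(dcn: torch.Tensor, wf: float = 4.0) -> torch.Tensor:
    # hypothetical, simplified stand-in for the weighting step:
    # exp(-wf * dcn**2) underflows to zero in float32 for large |dcn|,
    # which can zero out whole rows and spoil the normalization
    d = dcn.type(torch.double)
    w = torch.exp(-wf * d**2)

    norm = torch.clamp(w.sum(dim=-1, keepdim=True), min=torch.finfo(torch.double).eps)
    return (w / norm).type(dcn.dtype)  # cast back to the working precision
```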
2 changes: 1 addition & 1 deletion src/tad_dftd3/ncoord.py
@@ -61,7 +61,7 @@
import torch

from . import data
from .typing import DD, Any, CountingFunction, Optional, Tensor
from ._typing import DD, Any, CountingFunction, Optional, Tensor
from .util import cdist, real_pairs


13 changes: 6 additions & 7 deletions src/tad_dftd3/reference.py
@@ -23,11 +23,11 @@

import torch

from .typing import Any, NoReturn, Optional, Tensor
from ._typing import Any, NoReturn, Optional, Tensor


def _load_cn(
dtype: torch.dtype = torch.float, device: Optional[torch.device] = None
dtype: torch.dtype = torch.double, device: Optional[torch.device] = None
) -> Tensor:
return torch.tensor(
[
@@ -133,7 +133,7 @@ def _load_cn(


def _load_c6(
dtype: torch.dtype = torch.float, device: Optional[torch.device] = None
dtype: torch.dtype = torch.double, device: Optional[torch.device] = None
) -> Tensor:
"""
Load reference C6 coefficients from file and fill them into a tensor
@@ -150,7 +150,7 @@ def _load_c6(
n_element = (math.isqrt(8 * ref.shape[0] + 1) - 1) // 2 + 1
n_reference = ref.shape[-1]
c6 = torch.zeros(
(n_element, n_element, n_reference, n_reference), dtype=dtype, device=device
(n_element, n_element, n_reference, n_reference), dtype=ref.dtype, device=device
)

for i in range(1, n_element):
@@ -188,14 +188,13 @@ def __init__(
):
if cn is None:
cn = _load_cn(
dtype if dtype is not None else torch.float,
dtype=dtype if dtype is not None else torch.double,
device=device,
)
self.cn = cn
if c6 is None:
c6 = _load_c6(
dtype if dtype is not None else torch.float,
device=device,
dtype=dtype if dtype is not None else torch.double, device=device
)
self.c6 = c6

2 changes: 1 addition & 1 deletion src/tad_dftd3/util/distance.py
@@ -20,7 +20,7 @@
"""
import torch

from ..typing import Optional, Tensor
from .._typing import Optional, Tensor

__all__ = ["cdist"]

2 changes: 1 addition & 1 deletion src/tad_dftd3/util/grad.py
@@ -21,7 +21,7 @@
import torch

from ..__version__ import __torch_version__
from ..typing import Any, Callable, Tensor, Tuple
from .._typing import Any, Callable, Tensor, Tuple

__all__ = ["jac", "hessian"]

65 changes: 63 additions & 2 deletions src/tad_dftd3/util/misc.py
@@ -21,28 +21,89 @@
"""
import torch

from ..typing import List, Optional, Size, Tensor, TensorOrTensors, Union
from .._typing import List, Optional, Size, Tensor, TensorOrTensors, Union

__all__ = ["real_atoms", "real_pairs", "real_triples", "pack", "to_number"]


def real_atoms(numbers: Tensor) -> Tensor:
"""
Create a mask for atoms, discerning padding and actual atoms.
Padding value is zero.
Parameters
----------
numbers : Tensor
Atomic numbers for all atoms.
Returns
-------
Tensor
Mask for atoms that discerns padding and real atoms.
"""
return numbers != 0


def real_pairs(numbers: Tensor, diagonal: bool = False) -> Tensor:
"""
Create a mask for pairs of atoms from atomic numbers, discerning padding
and actual atoms. Padding value is zero.
Parameters
----------
numbers : Tensor
Atomic numbers for all atoms.
diagonal : bool, optional
Flag for also writing `False` to the diagonal, i.e., to all pairs
with the same indices. Defaults to `False`, i.e., writing False
to the diagonal.
Returns
-------
Tensor
Mask for atom pairs that discerns padding and real atoms.
"""
real = real_atoms(numbers)
mask = real.unsqueeze(-2) * real.unsqueeze(-1)
if diagonal is False:
mask *= ~torch.diag_embed(torch.ones_like(real))
return mask


def real_triples(numbers: Tensor, diagonal: bool = False) -> Tensor:
def real_triples(
numbers: torch.Tensor, diagonal: bool = False, self: bool = True
) -> Tensor:
"""
Create a mask for triples from atomic numbers. Padding value is zero.
Parameters
----------
numbers : torch.Tensor
Atomic numbers for all atoms.
diagonal : bool, optional
Flag for also writing `False` to the space diagonal, i.e., to all
triples with the same indices. Defaults to `False`, i.e., writing False
to the diagonal.
self : bool, optional
Flag for also writing `False` to all triples where at least two indices
are identical. Defaults to `True`, i.e., not writing `False`.
Returns
-------
Tensor
Mask for triples.
"""
real = real_pairs(numbers, diagonal=True)
mask = real.unsqueeze(-3) * real.unsqueeze(-2) * real.unsqueeze(-1)

if diagonal is False:
mask *= ~torch.diag_embed(torch.ones_like(real))

if self is False:
mask *= ~torch.diag_embed(torch.ones_like(real), offset=0, dim1=-3, dim2=-2)
mask *= ~torch.diag_embed(torch.ones_like(real), offset=0, dim1=-3, dim2=-1)
mask *= ~torch.diag_embed(torch.ones_like(real), offset=0, dim1=-2, dim2=-1)

return mask


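A short usage sketch of the extended `real_triples` (the molecule is arbitrary; zero is the padding value): with `self=False`, every triple in which at least two indices coincide is masked out, which is exactly what the ATM term now relies on.

```python
import torch
from tad_dftd3.util import real_triples

numbers = torch.tensor([8, 1, 1, 0])  # O, H, H plus one padding atom
mask = real_triples(numbers, diagonal=False, self=False)

print(mask.shape)     # torch.Size([4, 4, 4])
print(mask[0, 1, 2])  # tensor(True): three distinct real atoms
print(mask[0, 0, 1])  # tensor(False): two identical indices
print(mask[0, 1, 3])  # tensor(False): index 3 is padding
```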
25 changes: 23 additions & 2 deletions tests/conftest.py
@@ -25,6 +25,9 @@

torch.set_printoptions(precision=10)

FAST_MODE: bool = True
"""Flag for fast gradient tests."""


def pytest_addoption(parser: pytest.Parser) -> None:
"""Set up additional command line options."""
@@ -42,6 +45,18 @@ def pytest_addoption(parser: pytest.Parser) -> None:
help="Enable JIT during tests (default = False).",
)

parser.addoption(
"--fast",
action="store_true",
help="Use `fast_mode` for gradient checks (default = True).",
)

parser.addoption(
"--slow",
action="store_true",
help="Do *not* use `fast_mode` for gradient checks (default = False).",
)

parser.addoption(
"--tpo-linewidth",
action="store",
@@ -83,9 +98,15 @@ def pytest_configure(config: pytest.Config) -> None:
torch.autograd.anomaly_mode.set_detect_anomaly(True)

if config.getoption("--jit"):
torch.jit._state.enable() # type: ignore
torch.jit._state.enable() # type: ignore # pylint: disable=protected-access
else:
torch.jit._state.disable() # type: ignore
torch.jit._state.disable() # type: ignore # pylint: disable=protected-access

global FAST_MODE
if config.getoption("--fast"):
FAST_MODE = True
if config.getoption("--slow"):
FAST_MODE = False

if config.getoption("--tpo-linewidth"):
torch.set_printoptions(linewidth=config.getoption("--tpo-linewidth"))
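A hypothetical test consuming the new flag (the `import conftest` path and the checked function are assumptions, not code from this commit; pytest puts the directory containing `conftest.py` on `sys.path`): `torch.autograd.gradcheck` accepts a `fast_mode` argument, so `--slow` simply trades runtime for the exhaustive Jacobian comparison.

```python
import torch
from torch.autograd import gradcheck

import conftest  # assumption: the tests directory is importable, as under pytest


def test_gradient_sketch() -> None:
    positions = torch.rand(4, 3, dtype=torch.double, requires_grad=True)

    # fast_mode=True uses a cheaper vector probe instead of the full Jacobian
    assert gradcheck(lambda p: (p * p).sum(), (positions,), fast_mode=conftest.FAST_MODE)
```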
2 changes: 1 addition & 1 deletion tests/molecules.py
@@ -17,7 +17,7 @@
"""
import torch

from tad_dftd3.typing import Dict, Molecule
from tad_dftd3._typing import Dict, Molecule
from tad_dftd3.util import to_number

mols: Dict[str, Molecule] = {
2 changes: 1 addition & 1 deletion tests/samples.py
@@ -17,7 +17,7 @@
"""
import torch

from tad_dftd3.typing import Dict, Molecule, Tensor, TypedDict
from tad_dftd3._typing import Dict, Molecule, Tensor, TypedDict

from .molecules import mols
from .utils import merge_nested_dicts