Fix differentiation w.r.t. positions
awvwgk committed Apr 26, 2022
1 parent b6de889 commit 57a5a26
Showing 5 changed files with 67 additions and 19 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
@@ -1,8 +1,8 @@
 [metadata]
 name = tad-dftd3
 version = 0.0.1
-desciption = Torch autodiff DFT-D3 implementation
-long_desciption = file: README.rst
+description = Torch autodiff DFT-D3 implementation
+long_description = file: README.rst
 long_description_content_type = text/x-rst
 license = Apache-2.0
 license_files = LICENSE
29 changes: 21 additions & 8 deletions src/tad_dftd3/disp.py
@@ -93,6 +93,7 @@ def dispersion(
     rvdw: Optional[Tensor] = None,
     r4r2: Optional[Tensor] = None,
     damping_function: DampingFunction = rational_damping,
+    cutoff: Optional[Tensor] = None,
     s6: float = 1.0,
     s8: float = 1.0,
     **kwargs
@@ -120,6 +121,8 @@ def dispersion(
     s8 : float
         Scaling factor for the C8 interaction.
     """
+    if cutoff is None:
+        cutoff = torch.tensor(50.0, dtype=positions.dtype)
     if r4r2 is None:
         r4r2 = data.sqrt_z_r4_over_r2[numbers].type(positions.dtype)
     if rvdw is None:
@@ -133,19 +136,29 @@ def dispersion(
             "Shape of expectation values is not consistent with atomic numbers"
         )

+    eps = torch.tensor(torch.finfo(positions.dtype).eps, dtype=positions.dtype)
     real = numbers != 0
-    mask = ~(real.unsqueeze(-2) * real.unsqueeze(-1))
-    distances = torch.cdist(positions, positions, p=2)
-    distances[mask] = 0
-    mask.diagonal(dim1=-2, dim2=-1).fill_(True)
+    mask = real.unsqueeze(-2) * real.unsqueeze(-1)
+    mask.diagonal(dim1=-2, dim2=-1).fill_(False)
+    distances = torch.where(
+        mask,
+        torch.cdist(positions, positions, p=2),
+        eps,
+    )

     qq = 3 * r4r2.unsqueeze(-1) * r4r2.unsqueeze(-2)
     c8 = c6 * qq

-    t6 = damping_function(6, distances, rvdw, qq, **kwargs)
-    t8 = damping_function(8, distances, rvdw, qq, **kwargs)
-    t6[mask] = 0
-    t8[mask] = 0
+    t6 = torch.where(
+        mask * (distances <= cutoff),
+        damping_function(6, distances, rvdw, qq, **kwargs),
+        torch.tensor(0.0, dtype=distances.dtype),
+    )
+    t8 = torch.where(
+        mask * (distances <= cutoff),
+        damping_function(8, distances, rvdw, qq, **kwargs),
+        torch.tensor(0.0, dtype=distances.dtype),
+    )

     e6 = -0.5 * torch.sum(c6 * t6, dim=-1)
     e8 = -0.5 * torch.sum(c8 * t8, dim=-1)
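
The substance of the fix is in this hunk. The old code wrote into tensors in place (distances[mask] = 0, t6[mask] = 0), which autograd cannot always track, and it left exact zeros in the distance matrix, where the Euclidean norm is not differentiable; both can poison the backward pass with NaNs. The new code builds the mask once, substitutes a tiny eps for the masked distances, zeroes excluded pairs out of place with torch.where, and additionally skips pairs beyond the new cutoff. A standalone sketch of the pattern (not the library's code; the 1/r^6 term is a stand-in for the damped dispersion energy):

import torch

# Padding atoms carry atomic number 0; the mask keeps only real,
# off-diagonal pairs, and eps replaces the masked distances so that
# nothing downstream is evaluated at exactly zero.
positions = torch.tensor(
    [[0.0, 0.0, 0.0], [0.0, 0.0, 1.4], [0.0, 0.0, 0.0]],  # last atom = padding
    dtype=torch.float64,
    requires_grad=True,
)
numbers = torch.tensor([1, 1, 0])

eps = torch.tensor(torch.finfo(positions.dtype).eps, dtype=positions.dtype)
real = numbers != 0
mask = real.unsqueeze(-2) * real.unsqueeze(-1)
mask.diagonal(dim1=-2, dim2=-1).fill_(False)

distances = torch.where(mask, torch.cdist(positions, positions, p=2), eps)

# Stand-in for the damped C6 term; masked pairs contribute exactly zero.
energy = torch.where(
    mask, 1.0 / distances**6, torch.tensor(0.0, dtype=positions.dtype)
).sum()
energy.backward()
assert torch.isfinite(positions.grad).all()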
9 changes: 6 additions & 3 deletions src/tad_dftd3/model.py
@@ -123,10 +123,13 @@ def weight_references(
         Weights of all reference systems
     """

-    mask = reference.cn[numbers] < 0
+    mask = reference.cn[numbers] >= 0

-    weights = weighting_function(reference.cn[numbers] - cn.unsqueeze(-1), **kwargs)
-    weights[mask] = 0
+    weights = torch.where(
+        mask,
+        weighting_function(reference.cn[numbers] - cn.unsqueeze(-1), **kwargs),
+        torch.tensor(0.0, dtype=cn.dtype),
+    )
     norms = torch.add(torch.sum(weights, dim=-1), epsilon)

     return weights / norms.unsqueeze(-1)
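
The same repair applied to the reference weighting. Note that the mask flips sense: negative coordination numbers mark missing reference systems, so the new mask (>= 0) selects the valid entries and can feed torch.where directly, replacing the in-place weights[mask] = 0. A short sketch with a stand-in Gaussian weighting (an assumption for illustration; the library's actual weighting_function differs in detail):

import torch

# ref_cn uses a negative sentinel for missing reference systems,
# matching the mask in the hunk above; the Gaussian is a stand-in.
cn = torch.tensor([1.0, 2.1], dtype=torch.float64, requires_grad=True)
ref_cn = torch.tensor([[0.0, 1.0], [-1.0, 2.0]], dtype=torch.float64)
mask = ref_cn >= 0  # True only for valid references

weights = torch.where(
    mask,
    torch.exp(-4.0 * (ref_cn - cn.unsqueeze(-1)) ** 2),
    torch.tensor(0.0, dtype=cn.dtype),
)
norms = torch.sum(weights, dim=-1) + 1e-20  # epsilon guards all-masked rows
(weights / norms.unsqueeze(-1)).sum().backward()  # no in-place ops in the graph
assert torch.isfinite(cn.grad).all()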
24 changes: 18 additions & 6 deletions src/tad_dftd3/ncoord.py
@@ -91,6 +91,7 @@ def coordination_number(
     positions: Tensor,
     rcov: Optional[Tensor] = None,
     counting_function: CountingFunction = exp_count,
+    cutoff: Optional[Tensor] = None,
     **kwargs,
 ) -> Tensor:
     """
@@ -108,25 +109,36 @@ def coordination_number(
         Calculates counting value in range 0 to 1 from a batch of
         distances and covalent radii, additional parameters can
         be passed through via key-value arguments.
+    cutoff : float
+        Real-space cutoff for the evaluation of counting function

     Returns
     -------
     Tensor: The coordination number of each atom in the system.
     """
+    if cutoff is None:
+        cutoff = torch.tensor(25.0, dtype=positions.dtype)
     if rcov is None:
         rcov = data.covalent_rad_d3[numbers].type(positions.dtype)
     if numbers.shape != rcov.shape:
         raise ValueError("Shape of covalent radii is not consistent with atomic numbers")
     if numbers.shape != positions.shape[:-1]:
         raise ValueError("Shape of positions is not consistent with atomic numbers")

+    eps = torch.tensor(torch.finfo(positions.dtype).eps, dtype=positions.dtype)
     real = numbers != 0
-    mask = ~(real.unsqueeze(-2) * real.unsqueeze(-1))
-    distances = torch.cdist(positions, positions, p=2)
-    distances[mask] = 0.0
-    mask.diagonal(dim1=-2, dim2=-1).fill_(True)
+    mask = real.unsqueeze(-2) * real.unsqueeze(-1)
+    mask.diagonal(dim1=-2, dim2=-1).fill_(False)
+    distances = torch.where(
+        mask,
+        torch.cdist(positions, positions, p=2),
+        eps,
+    )

     rc = rcov.unsqueeze(-2) + rcov.unsqueeze(-1)
-    cf = counting_function(distances, rc.type(distances.dtype), **kwargs)
-    cf[mask] = 0
+    cf = torch.where(
+        mask * (distances <= cutoff),
+        counting_function(distances, rc.type(distances.dtype), **kwargs),
+        torch.tensor(0.0, dtype=distances.dtype),
+    )
     return torch.sum(cf, dim=-1)
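
coordination_number gains the same optional cutoff (defaulting to 25.0 when omitted) and the same eps/torch.where treatment of the distance matrix. A hedged usage sketch, assuming the import path matches the file layout above and coordinates in atomic units:

import torch
from tad_dftd3 import ncoord

# Hypothetical usage of the new cutoff argument. Water, roughly in bohr;
# all atoms are real, so the only masked entries are the diagonal ones
# that coordination_number handles internally.
numbers = torch.tensor([8, 1, 1])
positions = torch.tensor(
    [[0.0, 0.0, 0.0], [0.0, 1.43, 1.11], [0.0, -1.43, 1.11]],
    dtype=torch.float64,
    requires_grad=True,
)

cn = ncoord.coordination_number(
    numbers, positions, cutoff=torch.tensor(25.0, dtype=torch.float64)
)
cn.sum().backward()  # differentiable w.r.t. positions after this fix
print(positions.grad)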
20 changes: 20 additions & 0 deletions tests/test_dftd3.py
@@ -135,3 +135,23 @@ def func(*inputs):
         return dftd3(numbers, positions, input_param)

     assert torch.autograd.gradcheck(func, param)
+
+
+@pytest.mark.grad
+def test_positions_grad():
+    dtype = torch.float64
+    sample = samples.structures["C4H5NCS"]
+    numbers = sample["numbers"]
+    positions = sample["positions"].type(dtype)
+    positions.requires_grad_(True)
+    param = {
+        "s6": torch.tensor(1.00000000, dtype=dtype),
+        "s8": torch.tensor(0.78981345, dtype=dtype),
+        "a1": torch.tensor(0.49484001, dtype=dtype),
+        "a2": torch.tensor(5.73083694, dtype=dtype),
+    }
+
+    def func(positions):
+        return dftd3(numbers, positions, param)
+
+    assert torch.autograd.gradcheck(func, positions)
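
The new test pins down the regression: gradcheck compares autograd's analytic Jacobian of the dispersion energy with respect to positions against a finite-difference Jacobian, which is why the positions and parameters are cast to torch.float64 first (single precision is too noisy for gradcheck's default tolerances). A self-contained illustration of the check on a toy function:

import torch

# gradcheck perturbs each input element, builds a numerical Jacobian,
# and asserts that it matches the analytic one from backward().
x = torch.randn(6, dtype=torch.float64, requires_grad=True)

def toy_energy(pos):
    return torch.sum(torch.exp(-pos**2))  # smooth stand-in energy

assert torch.autograd.gradcheck(toy_energy, x)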
