Skip to content

Commit

Permalink
Replace new_tensor (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored May 16, 2023
1 parent 96a67a9 commit 05ce444
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 17 deletions.
6 changes: 4 additions & 2 deletions src/tad_dftd3/damping/atm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def dispersion_atm(
Tensor
Atom-resolved ATM dispersion energy.
"""
dd = {"device": positions.device, "dtype": positions.dtype}

s9 = s9.type(positions.dtype).to(positions.device)
rs9 = rs9.type(positions.dtype).to(positions.device)
alp = alp.type(positions.dtype).to(positions.device)
Expand All @@ -85,7 +87,7 @@ def dispersion_atm(
torch.cdist(
positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"
),
positions.new_tensor(torch.finfo(positions.dtype).eps),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
),
2.0,
)
Expand All @@ -107,7 +109,7 @@ def dispersion_atm(
* (r2jk <= cutoff2)
* (r2jk <= cutoff2),
0.375 * s / r5 + 1.0 / r3,
positions.new_tensor(0.0),
torch.tensor(0.0, **dd),
)

energy = ang * fdamp * c9
Expand Down
6 changes: 4 additions & 2 deletions src/tad_dftd3/damping/rational.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def rational_damping(
Tensor
Values of the damping function.
"""
a1 = param.get("a1", distances.new_tensor(defaults.A1))
a2 = param.get("a2", distances.new_tensor(defaults.A1))
dd = {"device": distances.device, "dtype": distances.dtype}

a1 = param.get("a1", torch.tensor(defaults.A1, **dd))
a2 = param.get("a2", torch.tensor(defaults.A2, **dd))
return 1.0 / (distances.pow(order) + (a1 * torch.sqrt(qq) + a2).pow(order))
22 changes: 14 additions & 8 deletions src/tad_dftd3/disp.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ def dispersion(
Damping function evaluate distance dependent contributions.
Additional arguments are passed through to the function.
"""
dd = {"device": positions.device, "dtype": positions.dtype}

if cutoff is None:
cutoff = positions.new_tensor(50.0)
cutoff = torch.tensor(50.0, **dd)
if r4r2 is None:
r4r2 = (
data.sqrt_z_r4_over_r2[numbers].type(positions.dtype).to(positions.device)
Expand Down Expand Up @@ -155,11 +157,13 @@ def dispersion2(
Damping function evaluate distance dependent contributions.
Additional arguments are passed through to the function.
"""
dd = {"device": positions.device, "dtype": positions.dtype}

mask = real_pairs(numbers, diagonal=False)
distances = torch.where(
mask,
torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"),
positions.new_tensor(torch.finfo(positions.dtype).eps),
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
)

qq = 3 * r4r2.unsqueeze(-1) * r4r2.unsqueeze(-2)
Expand All @@ -168,19 +172,19 @@ def dispersion2(
t6 = torch.where(
mask * (distances <= cutoff),
damping_function(6, distances, qq, param, **kwargs),
positions.new_tensor(0.0),
torch.tensor(0.0, **dd),
)
t8 = torch.where(
mask * (distances <= cutoff),
damping_function(8, distances, qq, param, **kwargs),
positions.new_tensor(0.0),
torch.tensor(0.0, **dd),
)

e6 = -0.5 * torch.sum(c6 * t6, dim=-1)
e8 = -0.5 * torch.sum(c8 * t8, dim=-1)

s6 = param.get("s6", positions.new_tensor(defaults.S6))
s8 = param.get("s8", positions.new_tensor(defaults.S8))
s6 = param.get("s6", torch.tensor(defaults.S6, **dd))
s8 = param.get("s8", torch.tensor(defaults.S8, **dd))
return s6 * e6 + s8 * e8


Expand Down Expand Up @@ -220,8 +224,10 @@ def dispersion3(
Tensor
Atom-resolved three-body dispersion energy.
"""
alp = param.get("alp", positions.new_tensor(14.0))
s9 = param.get("s9", positions.new_tensor(14.0))
dd = {"device": positions.device, "dtype": positions.dtype}

alp = param.get("alp", torch.tensor(14.0, **dd))
s9 = param.get("s9", torch.tensor(1.0, **dd))
rs9 = rs9.type(positions.dtype).to(positions.device)

return dispersion_atm(numbers, positions, c6, rvdw, cutoff, s9, rs9, alp)
2 changes: 1 addition & 1 deletion src/tad_dftd3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def weight_references(
weights = torch.where(
mask,
weighting_function(reference.cn[numbers] - cn.unsqueeze(-1), **kwargs),
cn.new_tensor(0.0),
torch.tensor(0.0, device=cn.device, dtype=cn.dtype),
)
norms = torch.add(torch.sum(weights, dim=-1), epsilon)

Expand Down
9 changes: 5 additions & 4 deletions src/tad_dftd3/ncoord.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ def coordination_number(
-------
Tensor: The coordination number of each atom in the system.
"""
dd = {"device": positions.device, "dtype": positions.dtype}

if cutoff is None:
cutoff = positions.new_tensor(25.0)
cutoff = torch.tensor(25.0, **dd)
if rcov is None:
rcov = data.covalent_rad_d3[numbers].type(positions.dtype).to(positions.device)
if numbers.shape != rcov.shape:
Expand All @@ -127,18 +129,17 @@ def coordination_number(
if numbers.shape != positions.shape[:-1]:
raise ValueError("Shape of positions is not consistent with atomic numbers")

eps = positions.new_tensor(torch.finfo(positions.dtype).eps)
mask = real_pairs(numbers, diagonal=False)
distances = torch.where(
mask,
torch.cdist(positions, positions, p=2, compute_mode="use_mm_for_euclid_dist"),
eps,
torch.tensor(torch.finfo(positions.dtype).eps, **dd),
)

rc = rcov.unsqueeze(-2) + rcov.unsqueeze(-1)
cf = torch.where(
mask * (distances <= cutoff),
counting_function(distances, rc.type(distances.dtype), **kwargs),
positions.new_tensor(0.0),
torch.tensor(0.0, **dd),
)
return torch.sum(cf, dim=-1)

0 comments on commit 05ce444

Please sign in to comment.