Skip to content

Commit

Permalink
Add esp calculator in ADMPPmeForce and CoulNoCutoffForce for noPBC sy…
Browse files Browse the repository at this point in the history
…stems
  • Loading branch information
WangXinyan940 committed Oct 27, 2023
1 parent 16d1179 commit 7145d58
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 46 deletions.
109 changes: 105 additions & 4 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class ADMPPmeForce:
The so called "environment paramters" means parameters that do not need to be differentiable
'''

def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None, has_aux=False):
def __init__(self, box, map_atomtype, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None, has_aux=False):
'''
Initialize the ADMPPmeForce calculator.
Expand Down Expand Up @@ -70,6 +70,7 @@ def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False,
Output:
'''
self.map_atomtype = map_atomtype
self.axis_type = axis_type
self.axis_indices = axis_indices
self.rc = rc
Expand Down Expand Up @@ -130,18 +131,118 @@ def get_energy(
U_ind, lconverg, n_cycle = self.optimize_Uind(
positions, box, pairs, Q_local, pol, tholes,
mScales, pScales, dScales,
U_init=U_init, steps_pol=self.steps_pol)
U_init=U_init * 10.0, steps_pol=self.steps_pol) # nm to angstrom
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
energy = energy_fn(positions, box, pairs, Q_local, U_ind, pol, tholes, mScales, pScales, dScales)
if aux is not None:
aux["U_ind"] = U_ind
aux["U_ind"] = U_ind * 0.1 # Angstrom to nm
aux["lconverg"] = lconverg
aux["n_cycle"] = n_cycle
return energy, aux
else:
return energy
return get_energy

def generate_esp(self):

@jit_condition()
def CD_mat(p1, p2):
mCD = jnp.zeros((1, 3))
one_R = 1 / jnp.linalg.norm(p1 - p2 + 1e-16)
pre_factor = - one_R * one_R * one_R
delta = p2 - p1
mCD = mCD.at[0,0].add(pre_factor * delta[0])
mCD = mCD.at[0,1].add(pre_factor * delta[1])
mCD = mCD.at[0,2].add(pre_factor * delta[2])
return mCD

@jit_condition()
def CQ_mat(p1, p2):
one_R = 1 / jnp.linalg.norm(p1 - p2 + 1e-16)
delta = p2 - p1
CQ = jnp.zeros((3, 3))
d_x, d_y, d_z = delta[0], delta[1], delta[2]
o_3 = one_R * one_R * one_R
o_5 = o_3 * one_R * one_R

v = 3 * d_x**2 * o_5 - o_3
CQ = CQ.at[0, 0].add(v)
v = 3 * d_x * d_y * o_5
CQ = CQ.at[0, 1].add(v)
CQ = CQ.at[1, 0].add(v)
v = 3 * d_x * d_z * o_5
CQ = CQ.at[0, 2].add(v)
CQ = CQ.at[2, 0].add(v)
v = 3 * d_y**2 * o_5 - o_3
CQ = CQ.at[1, 1].add(v)
v = 3 * d_y * d_z * o_5
CQ = CQ.at[1, 2].add(v)
CQ = CQ.at[2, 1].add(v)
v = 3 * d_z**2 * o_5 - o_3
CQ = CQ.at[2, 2].add(v)

return CQ.reshape((1, 9))

@jit_condition()
def esp_kernel(position, grid, Q, U): # Unit: angstrom
Qtot = Q.at[1:4].add(U)
r = grid - position
dist = jnp.linalg.norm(r+1e-16)
one_dist = 1. / dist
esp = DIELECTRIC * 0.1 * Qtot[0] * one_dist
if self.lpol or self.lmax >= 1:
# C-U&D
mCD = CD_mat(grid, position)
U = Q[1:4].reshape((3, 1)) * 0.1
esp += DIELECTRIC * 0.1 * jnp.matmul(mCD, U)
if self.lmax >= 2:
# C-Q
mCQ = CQ_mat(grid, position)
Qmat = jnp.zeros((3, 3))
Qmat.at[0,0].set(Qtot[4] / 300.0)
Qmat.at[1,1].set(Qtot[5] / 300.0)
Qmat.at[2,2].set(Qtot[6] / 300.0)
Qmat.at[0,1].set(Qtot[7] / 300.0)
Qmat.at[1,0].set(Qtot[7] / 300.0)
Qmat.at[0,2].set(Qtot[8] / 300.0)
Qmat.at[2,0].set(Qtot[8] / 300.0)
Qmat.at[1,2].set(Qtot[9] / 300.0)
Qmat.at[2,1].set(Qtot[9] / 300.0)
esp += DIELECTRIC * 0.1 * jnp.matmul(mCQ, Qmat.reshape((-1,1)))

return esp.ravel()

esp_point_kernel = jax.vmap(esp_kernel, in_axes=(0, None, 0, 0), out_axes=0)

@jit_condition()
def esp_point(positions, grid, Q, U):
esp = esp_point_kernel(positions, grid, Q, U)
return jnp.sum(esp)

esp_grid = jax.vmap(esp_point, in_axes=(None, 0, None, None), out_axes=0)

if self.lpol:
@jit_condition()
def get_esp(positions, grids, Q_local, U_ind):
box = jnp.eye(3) * 1000.0
local_frames = self.construct_local_frames(positions, box)
Q_global = rot_local2global(Q_local[self.map_atomtype], local_frames, self.lmax)
esp = esp_grid(positions, grids, Q_global, U_ind)
return esp.reshape((grids.shape[0],))
else:
@jit_condition()
def get_esp(positions, grids, Q_local):
U_ind = jnp.zeros(positions.shape)
box = jnp.eye(3) * 1000.0
local_frames = self.construct_local_frames(positions, box)
Q_global = rot_local2global(Q_local[self.map_atomtype], local_frames, self.lmax)
esp = esp_grid(positions, grids, Q_global, U_ind)
return esp.reshape((grids.shape[0],))

return get_esp




def update_env(self, attr, val):
Expand Down Expand Up @@ -825,7 +926,7 @@ def pme_real(positions, box, pairs,

# deals with geometries
dr = r1 - r2
dr = v_pbc_shift(dr, box, box_inv)
dr = v_pbc_shift(dr, box, box_inv) + 1e-32
norm_dr = jnp.linalg.norm(dr, axis=-1)
Ri = build_quasi_internal(r1, r2, dr, norm_dr)
qiQI = rot_global2local(Q_extendi, Ri, lmax)
Expand Down
76 changes: 54 additions & 22 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Iterable, Tuple, Optional

import jax
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -132,8 +132,9 @@ def get_energy(box, epsilon, sigma, epsfix, sigfix):
class CoulNoCutoffForce:
# E=\frac{{q}_{1}{q}_{2}}{4\pi\epsilon_0\epsilon_1 r}

def __init__(self, epsilon_1=1.0, topology_matrix=None) -> None:
def __init__(self, init_charges, epsilon_1=1.0, topology_matrix=None) -> None:

self.init_charges = init_charges
self.eps_1 = epsilon_1
self.top_mat = topology_matrix

Expand All @@ -146,7 +147,7 @@ def get_coul_energy(dr_vec, chrgprod, box):

return E

def get_energy(positions, box, pairs, charges, mscales):
def get_energy_kernel(positions, box, pairs, charges, mscales):
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
mask = pair_buffer_scales(pairs[:, :2])
cov_pair = pairs[:, 2]
Expand All @@ -162,27 +163,56 @@ def get_energy(positions, box, pairs, charges, mscales):

return jnp.sum(E_inter * mask)

def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales):
charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy(positions, box, pairs, charges, mscales)
if self.top_mat is None:
def get_energy(positions, box, pairs, mscales):
return get_energy_kernel(positions, box, pairs, self.init_charges, mscales)
else:
def get_energy(positions, box, pairs, bcc, mscales):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy_kernel(positions, box, pairs, charges, mscales)
return get_energy

def generate_esp(self):

def esp_kernel(position, grid, charge):
dist = jnp.linalg.norm(position - grid + 1e-16)
oneR = 1. / dist
return ONE_4PI_EPS0 * charge * oneR

esp_grid_kernel = jax.vmap(esp_kernel, in_axes=(0, None, 0))

def esp_grid(positions, grid, charges):
return jnp.sum(esp_grid_kernel(positions, grid, charges))

esp_all = jax.vmap(esp_grid, in_axes=(None, 0, None))

if self.top_mat is None:
return get_energy
def get_esp(positions, grids):
charges = self.init_charges
return esp_all(positions, grids, charges).ravel()
else:
return get_energy_bcc
def get_esp(positions, grids, bcc):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return esp_all(positions, grids, charges).ravel()

return get_esp




class CoulReactionFieldForce:
# E=\frac{{q}_{1}{q}_{2}}{4\pi\epsilon_0\epsilon_1}\left(\frac{1}{r}+{k}_{\mathit{rf}}{r}^{2}-{c}_{\mathit{rf}}\right)
def __init__(
self,
r_cut,
init_charges,
epsilon_1=1.0,
epsilon_solv=78.5,
isPBC=True,
topology_matrix=None
) -> None:

self.init_charges = init_charges
self.r_cut = r_cut
self.krf = (1.0 / r_cut ** 3) * (epsilon_solv - 1) / (2.0 * epsilon_solv + 1)
self.crf = (1.0 / r_cut) * 3.0 * epsilon_solv / (2.0 * epsilon_solv + 1)
Expand All @@ -208,7 +238,7 @@ def get_rf_energy(dr_vec, chrgprod, box):

return E

def get_energy(positions, box, pairs, charges, mscales):
def get_energy_kernel(positions, box, pairs, charges, mscales):
pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2]))
mask = pair_buffer_scales(pairs[:, :2])

Expand All @@ -224,28 +254,30 @@ def get_energy(positions, box, pairs, charges, mscales):
E_inter = get_rf_energy(dr_vec, chrgprod_scale, box)

return jnp.sum(E_inter * mask)

def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales):
charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy(positions, box, pairs, charges, mscales)

if self.top_mat is None:
return get_energy
def get_energy(positions, box, pairs, mscales):
return get_energy_kernel(positions, box, pairs, self.init_charges, mscales)
else:
return get_energy_bcc
def get_energy(positions, box, pairs, bcc, mscales):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy_kernel(positions, box, pairs, charges, mscales)
return get_energy


class CoulombPMEForce:

def __init__(
self,
r_cut: float,
init_charges,
kappa: float,
K: Tuple[int, int, int],
pme_order: int = 6,
topology_matrix: Optional[jnp.array] = None,
):
self.r_cut = r_cut
self.init_charges = init_charges
self.lmax = 0
self.kappa = kappa
self.K1, self.K2, self.K3 = K[0], K[1], K[2]
Expand All @@ -255,7 +287,7 @@ def __init__(

def generate_get_energy(self):

def get_energy(positions, box, pairs, charges, mscales):
def get_energy_kernel(positions, box, pairs, charges, mscales):

pme_recip_fn = generate_pme_recip(
Ck_fn=Ck_1,
Expand Down Expand Up @@ -292,11 +324,11 @@ def get_energy(positions, box, pairs, charges, mscales):
False,
)

def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales):
charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy(positions, box, pairs, charges, mscales)

if self.top_mat is None:
return get_energy
def get_energy(positions, box, pairs, mscales):
return get_energy_kernel(positions, box, pairs, self.init_charges, mscales)
else:
return get_energy_bcc
def get_energy(positions, box, pairs, bcc, mscales):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy_kernel(positions, box, pairs, charges, mscales)
return get_energy
20 changes: 11 additions & 9 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
self.ethresh = 5e-4
self.step_pol = None
self.ref_dip = ""
self.pme_force = None

self.lmax = int(self.ffinfo["Forces"][self.name]["meta"]["lmax"])

Expand Down Expand Up @@ -1084,16 +1085,16 @@ def overwrite(self, paramset):
if node["name"] in ["Atom", "Multipole"]:
node["c0"] = Q_global[n_multipole, 0]
if self.lmax >= 1:
node["dX"] = Q_global[n_multipole, 1]
node["dY"] = Q_global[n_multipole, 2]
node["dZ"] = Q_global[n_multipole, 3]
node["dX"] = Q_global[n_multipole, 1] * 0.1
node["dY"] = Q_global[n_multipole, 2] * 0.1
node["dZ"] = Q_global[n_multipole, 3] * 0.1
if self.lmax >= 2:
node["qXX"] = Q_global[n_multipole, 4]
node["qYY"] = Q_global[n_multipole, 5]
node["qZZ"] = Q_global[n_multipole, 6]
node["qXY"] = Q_global[n_multipole, 7]
node["qXZ"] = Q_global[n_multipole, 8]
node["qYZ"] = Q_global[n_multipole, 9]
node["qXX"] = Q_global[n_multipole, 4] / 300.0
node["qYY"] = Q_global[n_multipole, 5] / 300.0
node["qZZ"] = Q_global[n_multipole, 6] / 300.0
node["qXY"] = Q_global[n_multipole, 7] / 300.0
node["qXZ"] = Q_global[n_multipole, 8] / 300.0
node["qYZ"] = Q_global[n_multipole, 9] / 300.0
if q_local_masks[n_multipole] < 0.999:
node["mask"] = "true"
n_multipole += 1
Expand Down Expand Up @@ -1338,6 +1339,7 @@ def createPotential(
self.step_pol = kwargs["step_pol"]
pme_force = ADMPPmeForce(
box,
map_atomtype,
axis_types,
axis_indices,
rc,
Expand Down
Loading

0 comments on commit 7145d58

Please sign in to comment.