From 7145d5892c2d65dc0cefef38f86bbc538708f27d Mon Sep 17 00:00:00 2001 From: Wang Xinyan Date: Fri, 27 Oct 2023 18:23:23 +0800 Subject: [PATCH] Add esp calculator in ADMPPmeForce and CoulNoCutoffForce for noPBC systems --- dmff/admp/pme.py | 109 +++++++++++++++++++++++++++++++++-- dmff/classical/inter.py | 76 +++++++++++++++++------- dmff/generators/admp.py | 20 ++++--- dmff/generators/classical.py | 22 +++---- 4 files changed, 181 insertions(+), 46 deletions(-) diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index d7c8a04b1..95a7a228d 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -42,7 +42,7 @@ class ADMPPmeForce: The so called "environment paramters" means parameters that do not need to be differentiable ''' - def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None, has_aux=False): + def __init__(self, box, map_atomtype, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, lpme=True, steps_pol=None, has_aux=False): ''' Initialize the ADMPPmeForce calculator. @@ -70,6 +70,7 @@ def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, Output: ''' + self.map_atomtype = map_atomtype self.axis_type = axis_type self.axis_indices = axis_indices self.rc = rc @@ -130,18 +131,118 @@ def get_energy( U_ind, lconverg, n_cycle = self.optimize_Uind( positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, - U_init=U_init, steps_pol=self.steps_pol) + U_init=U_init * 10.0, steps_pol=self.steps_pol) # nm to angstrom # here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr ! # self.U_ind = jax.lax.stop_gradient(U_ind) energy = energy_fn(positions, box, pairs, Q_local, U_ind, pol, tholes, mScales, pScales, dScales) if aux is not None: - aux["U_ind"] = U_ind + aux["U_ind"] = U_ind * 0.1 # Angstrom to nm aux["lconverg"] = lconverg aux["n_cycle"] = n_cycle return energy, aux else: return energy return get_energy + + def generate_esp(self): + + @jit_condition() + def CD_mat(p1, p2): + mCD = jnp.zeros((1, 3)) + one_R = 1 / jnp.linalg.norm(p1 - p2 + 1e-16) + pre_factor = - one_R * one_R * one_R + delta = p2 - p1 + mCD = mCD.at[0,0].add(pre_factor * delta[0]) + mCD = mCD.at[0,1].add(pre_factor * delta[1]) + mCD = mCD.at[0,2].add(pre_factor * delta[2]) + return mCD + + @jit_condition() + def CQ_mat(p1, p2): + one_R = 1 / jnp.linalg.norm(p1 - p2 + 1e-16) + delta = p2 - p1 + CQ = jnp.zeros((3, 3)) + d_x, d_y, d_z = delta[0], delta[1], delta[2] + o_3 = one_R * one_R * one_R + o_5 = o_3 * one_R * one_R + + v = 3 * d_x**2 * o_5 - o_3 + CQ = CQ.at[0, 0].add(v) + v = 3 * d_x * d_y * o_5 + CQ = CQ.at[0, 1].add(v) + CQ = CQ.at[1, 0].add(v) + v = 3 * d_x * d_z * o_5 + CQ = CQ.at[0, 2].add(v) + CQ = CQ.at[2, 0].add(v) + v = 3 * d_y**2 * o_5 - o_3 + CQ = CQ.at[1, 1].add(v) + v = 3 * d_y * d_z * o_5 + CQ = CQ.at[1, 2].add(v) + CQ = CQ.at[2, 1].add(v) + v = 3 * d_z**2 * o_5 - o_3 + CQ = CQ.at[2, 2].add(v) + + return CQ.reshape((1, 9)) + + @jit_condition() + def esp_kernel(position, grid, Q, U): # Unit: angstrom + Qtot = Q.at[1:4].add(U) + r = grid - position + dist = jnp.linalg.norm(r+1e-16) + one_dist = 1. / dist + esp = DIELECTRIC * 0.1 * Qtot[0] * one_dist + if self.lpol or self.lmax >= 1: + # C-U&D + mCD = CD_mat(grid, position) + U = Q[1:4].reshape((3, 1)) * 0.1 + esp += DIELECTRIC * 0.1 * jnp.matmul(mCD, U) + if self.lmax >= 2: + # C-Q + mCQ = CQ_mat(grid, position) + Qmat = jnp.zeros((3, 3)) + Qmat.at[0,0].set(Qtot[4] / 300.0) + Qmat.at[1,1].set(Qtot[5] / 300.0) + Qmat.at[2,2].set(Qtot[6] / 300.0) + Qmat.at[0,1].set(Qtot[7] / 300.0) + Qmat.at[1,0].set(Qtot[7] / 300.0) + Qmat.at[0,2].set(Qtot[8] / 300.0) + Qmat.at[2,0].set(Qtot[8] / 300.0) + Qmat.at[1,2].set(Qtot[9] / 300.0) + Qmat.at[2,1].set(Qtot[9] / 300.0) + esp += DIELECTRIC * 0.1 * jnp.matmul(mCQ, Qmat.reshape((-1,1))) + + return esp.ravel() + + esp_point_kernel = jax.vmap(esp_kernel, in_axes=(0, None, 0, 0), out_axes=0) + + @jit_condition() + def esp_point(positions, grid, Q, U): + esp = esp_point_kernel(positions, grid, Q, U) + return jnp.sum(esp) + + esp_grid = jax.vmap(esp_point, in_axes=(None, 0, None, None), out_axes=0) + + if self.lpol: + @jit_condition() + def get_esp(positions, grids, Q_local, U_ind): + box = jnp.eye(3) * 1000.0 + local_frames = self.construct_local_frames(positions, box) + Q_global = rot_local2global(Q_local[self.map_atomtype], local_frames, self.lmax) + esp = esp_grid(positions, grids, Q_global, U_ind) + return esp.reshape((grids.shape[0],)) + else: + @jit_condition() + def get_esp(positions, grids, Q_local): + U_ind = jnp.zeros(positions.shape) + box = jnp.eye(3) * 1000.0 + local_frames = self.construct_local_frames(positions, box) + Q_global = rot_local2global(Q_local[self.map_atomtype], local_frames, self.lmax) + esp = esp_grid(positions, grids, Q_global, U_ind) + return esp.reshape((grids.shape[0],)) + + return get_esp + + def update_env(self, attr, val): @@ -825,7 +926,7 @@ def pme_real(positions, box, pairs, # deals with geometries dr = r1 - r2 - dr = v_pbc_shift(dr, box, box_inv) + dr = v_pbc_shift(dr, box, box_inv) + 1e-32 norm_dr = jnp.linalg.norm(dr, axis=-1) Ri = build_quasi_internal(r1, r2, dr, norm_dr) qiQI = rot_global2local(Q_extendi, Ri, lmax) diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 8dd9d8b7b..3a35dbe03 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -1,5 +1,5 @@ from typing import Iterable, Tuple, Optional - +import jax import jax.numpy as jnp import numpy as np @@ -132,8 +132,9 @@ def get_energy(box, epsilon, sigma, epsfix, sigfix): class CoulNoCutoffForce: # E=\frac{{q}_{1}{q}_{2}}{4\pi\epsilon_0\epsilon_1 r} - def __init__(self, epsilon_1=1.0, topology_matrix=None) -> None: + def __init__(self, init_charges, epsilon_1=1.0, topology_matrix=None) -> None: + self.init_charges = init_charges self.eps_1 = epsilon_1 self.top_mat = topology_matrix @@ -146,7 +147,7 @@ def get_coul_energy(dr_vec, chrgprod, box): return E - def get_energy(positions, box, pairs, charges, mscales): + def get_energy_kernel(positions, box, pairs, charges, mscales): pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) mask = pair_buffer_scales(pairs[:, :2]) cov_pair = pairs[:, 2] @@ -162,14 +163,41 @@ def get_energy(positions, box, pairs, charges, mscales): return jnp.sum(E_inter * mask) - def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales): - charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten() - return get_energy(positions, box, pairs, charges, mscales) + if self.top_mat is None: + def get_energy(positions, box, pairs, mscales): + return get_energy_kernel(positions, box, pairs, self.init_charges, mscales) + else: + def get_energy(positions, box, pairs, bcc, mscales): + charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten() + return get_energy_kernel(positions, box, pairs, charges, mscales) + return get_energy + + def generate_esp(self): + + def esp_kernel(position, grid, charge): + dist = jnp.linalg.norm(position - grid + 1e-16) + oneR = 1. / dist + return ONE_4PI_EPS0 * charge * oneR + + esp_grid_kernel = jax.vmap(esp_kernel, in_axes=(0, None, 0)) + + def esp_grid(positions, grid, charges): + return jnp.sum(esp_grid_kernel(positions, grid, charges)) + + esp_all = jax.vmap(esp_grid, in_axes=(None, 0, None)) if self.top_mat is None: - return get_energy + def get_esp(positions, grids): + charges = self.init_charges + return esp_all(positions, grids, charges).ravel() else: - return get_energy_bcc + def get_esp(positions, grids, bcc): + charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten() + return esp_all(positions, grids, charges).ravel() + + return get_esp + + class CoulReactionFieldForce: @@ -177,12 +205,14 @@ class CoulReactionFieldForce: def __init__( self, r_cut, + init_charges, epsilon_1=1.0, epsilon_solv=78.5, isPBC=True, topology_matrix=None ) -> None: + self.init_charges = init_charges self.r_cut = r_cut self.krf = (1.0 / r_cut ** 3) * (epsilon_solv - 1) / (2.0 * epsilon_solv + 1) self.crf = (1.0 / r_cut) * 3.0 * epsilon_solv / (2.0 * epsilon_solv + 1) @@ -208,7 +238,7 @@ def get_rf_energy(dr_vec, chrgprod, box): return E - def get_energy(positions, box, pairs, charges, mscales): + def get_energy_kernel(positions, box, pairs, charges, mscales): pairs = pairs.at[:, :2].set(regularize_pairs(pairs[:, :2])) mask = pair_buffer_scales(pairs[:, :2]) @@ -224,15 +254,15 @@ def get_energy(positions, box, pairs, charges, mscales): E_inter = get_rf_energy(dr_vec, chrgprod_scale, box) return jnp.sum(E_inter * mask) - - def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales): - charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten() - return get_energy(positions, box, pairs, charges, mscales) if self.top_mat is None: - return get_energy + def get_energy(positions, box, pairs, mscales): + return get_energy_kernel(positions, box, pairs, self.init_charges, mscales) else: - return get_energy_bcc + def get_energy(positions, box, pairs, bcc, mscales): + charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten() + return get_energy_kernel(positions, box, pairs, charges, mscales) + return get_energy class CoulombPMEForce: @@ -240,12 +270,14 @@ class CoulombPMEForce: def __init__( self, r_cut: float, + init_charges, kappa: float, K: Tuple[int, int, int], pme_order: int = 6, topology_matrix: Optional[jnp.array] = None, ): self.r_cut = r_cut + self.init_charges = init_charges self.lmax = 0 self.kappa = kappa self.K1, self.K2, self.K3 = K[0], K[1], K[2] @@ -255,7 +287,7 @@ def __init__( def generate_get_energy(self): - def get_energy(positions, box, pairs, charges, mscales): + def get_energy_kernel(positions, box, pairs, charges, mscales): pme_recip_fn = generate_pme_recip( Ck_fn=Ck_1, @@ -292,11 +324,11 @@ def get_energy(positions, box, pairs, charges, mscales): False, ) - def get_energy_bcc(positions, box, pairs, pre_charges, bcc, mscales): - charges = pre_charges + jnp.dot(self.top_mat, bcc).flatten() - return get_energy(positions, box, pairs, charges, mscales) - if self.top_mat is None: - return get_energy + def get_energy(positions, box, pairs, mscales): + return get_energy_kernel(positions, box, pairs, self.init_charges, mscales) else: - return get_energy_bcc + def get_energy(positions, box, pairs, bcc, mscales): + charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten() + return get_energy_kernel(positions, box, pairs, charges, mscales) + return get_energy diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index 85ec54981..c40c2efee 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -881,6 +881,7 @@ def __init__(self, ffinfo: dict, paramset: ParamSet): self.ethresh = 5e-4 self.step_pol = None self.ref_dip = "" + self.pme_force = None self.lmax = int(self.ffinfo["Forces"][self.name]["meta"]["lmax"]) @@ -1084,16 +1085,16 @@ def overwrite(self, paramset): if node["name"] in ["Atom", "Multipole"]: node["c0"] = Q_global[n_multipole, 0] if self.lmax >= 1: - node["dX"] = Q_global[n_multipole, 1] - node["dY"] = Q_global[n_multipole, 2] - node["dZ"] = Q_global[n_multipole, 3] + node["dX"] = Q_global[n_multipole, 1] * 0.1 + node["dY"] = Q_global[n_multipole, 2] * 0.1 + node["dZ"] = Q_global[n_multipole, 3] * 0.1 if self.lmax >= 2: - node["qXX"] = Q_global[n_multipole, 4] - node["qYY"] = Q_global[n_multipole, 5] - node["qZZ"] = Q_global[n_multipole, 6] - node["qXY"] = Q_global[n_multipole, 7] - node["qXZ"] = Q_global[n_multipole, 8] - node["qYZ"] = Q_global[n_multipole, 9] + node["qXX"] = Q_global[n_multipole, 4] / 300.0 + node["qYY"] = Q_global[n_multipole, 5] / 300.0 + node["qZZ"] = Q_global[n_multipole, 6] / 300.0 + node["qXY"] = Q_global[n_multipole, 7] / 300.0 + node["qXZ"] = Q_global[n_multipole, 8] / 300.0 + node["qYZ"] = Q_global[n_multipole, 9] / 300.0 if q_local_masks[n_multipole] < 0.999: node["mask"] = "true" n_multipole += 1 @@ -1338,6 +1339,7 @@ def createPotential( self.step_pol = kwargs["step_pol"] pme_force = ADMPPmeForce( box, + map_atomtype, axis_types, axis_indices, rc, diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index e34d221ce..b330c87e2 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -913,13 +913,14 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, # do not use PME if nonbondedMethod in [app.CutoffPeriodic, app.CutoffNonPeriodic]: # use Reaction Field - coulforce = CoulReactionFieldForce(r_cut, isPBC=ifPBC) + coulforce = CoulReactionFieldForce(r_cut, charges, isPBC=ifPBC) if nonbondedMethod is app.NoCutoff: # use NoCutoff - coulforce = CoulNoCutoffForce() + coulforce = CoulNoCutoffForce(init_charges=charges) else: - coulforce = CoulombPMEForce(r_cut, kappa, (K1, K2, K3)) - + coulforce = CoulombPMEForce(r_cut, charges, kappa, (K1, K2, K3)) + + self.pme_force = coulforce coulenergy = coulforce.generate_get_energy() # LJ @@ -998,8 +999,7 @@ def potential_fn(positions, box, pairs, params, aux=None): # it is jit-compatiable isinstance_jnp(positions, box, params) - coulE = coulenergy(positions, box, pairs, charges, - mscales_coul) + coulE = coulenergy(positions, box, pairs, mscales_coul) ljE = ljenergy(positions, box, pairs, params[self.name]["epsilon"], params[self.name]["sigma"], eps_nbfix, sig_nbfix, mscales_lj) @@ -1142,21 +1142,21 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, # use Reaction Field coulforce = CoulReactionFieldForce( r_cut, + charges, isPBC=ifPBC, topology_matrix=top_mat if self._use_bcc else None) if nonbondedMethod is app.NoCutoff: # use NoCutoff coulforce = CoulNoCutoffForce( - topology_matrix=top_mat if self._use_bcc else None) + charges, topology_matrix=top_mat if self._use_bcc else None) else: coulforce = CoulombPMEForce( r_cut, + charges, kappa, (K1, K2, K3), topology_matrix=top_mat if self._use_bcc else None) coulenergy = coulforce.generate_get_energy() - self.coulforce = coulforce #for qeq calculation - self.coulenergy = coulenergy #for qeq calculation has_aux = False if "has_aux" in kwargs and kwargs["has_aux"]: @@ -1170,10 +1170,10 @@ def potential_fn(positions, box, pairs, params, aux=None): isinstance_jnp(positions, box, params) if self._use_bcc: - coulE = coulenergy(positions, box, pairs, charges, + coulE = coulenergy(positions, box, pairs, params["CoulombForce"]["bcc"], mscales_coul) else: - coulE = coulenergy(positions, box, pairs, charges, + coulE = coulenergy(positions, box, pairs, mscales_coul) if has_aux: