Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update qeq module #187

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 215 additions & 51 deletions dmff/admp/qeq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import jax.numpy as jnp
from ..common.constants import DIELECTRIC
from jax import grad, vmap
from jax import grad, value_and_grad, vmap, jacfwd, jacrev
from ..classical.inter import CoulNoCutoffForce, CoulombPMEForce
from typing import Tuple, List
from ..settings import PRECISION
Expand Down Expand Up @@ -69,15 +69,19 @@ def padding_consts(const_list, max_idx):

@jit_condition()
def E_constQ(q, lagmt, const_list, const_vals):
constraint = (group_sum(q, const_list) - const_vals) * lagmt
return jnp.sum(constraint)
# constraint = (group_sum(q, const_list) - const_vals) * lagmt
# return jnp.sum(constraint)
return 0.0


@jit_condition()
def E_constP(q, lagmt, const_list, const_vals):
constraint = group_sum(q, const_list) * const_vals
return jnp.sum(constraint)

@jit_condition()
def E_noconst(q, lagmt, const_list, const_vals):
return 0.0

@vmap
@jit_condition()
Expand All @@ -87,13 +91,14 @@ def mask_to_zero(v, mask):
)


@jit_condition()
def E_sr(pos, box, pairs, q, eta, ds, buffer_scales):
@jit_condition(static_argnums=[6])
def E_sr(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
return 0.0


@jit_condition()
def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
@jit_condition(static_argnums=[6])
def E_sr2(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
ds = ds_pairs(pos, box, pairs, pbc_flag)
etasqrt = jnp.sqrt(2 * (eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2))
pre_pair = -eta_piecewise(etasqrt, ds) * DIELECTRIC
pre_self = etainv_piecewise(eta) / (jnp.sqrt(2 * jnp.pi)) * DIELECTRIC
Expand All @@ -104,8 +109,9 @@ def E_sr2(pos, box, pairs, q, eta, ds, buffer_scales):
return e_sr


@jit_condition()
def E_sr3(pos, box, pairs, q, eta, ds, buffer_scales):
@jit_condition(static_argnums=[6])
def E_sr3(pos, box, pairs, q, eta, buffer_scales, pbc_flag):
ds = ds_pairs(pos, box, pairs, pbc_flag)
etasqrt = jnp.sqrt(
eta[pairs[:, 0]] ** 2 + eta[pairs[:, 1]] ** 2 + 1e-64
) # add eta to avoid division by zero
Expand All @@ -126,13 +132,13 @@ def E_site(chi, J, q):

@jit_condition()
def E_site2(chi, J, q):
ene = (chi * q + 0.5 * J * q**2) * 96.4869
ene = (chi * q + 0.5 * J * q**2) * 96.4869 #ev to kj/mol
return jnp.sum(ene)


@jit_condition()
def E_site3(chi, J, q):
ene = chi * q * 4.184 + J * q**2 * DIELECTRIC * 2 * jnp.pi
ene = chi * q + J* q**2 # kj/mol
return jnp.sum(ene)


Expand All @@ -149,7 +155,7 @@ def E_corr(pos, box, pairs, q, kappa, neutral_flag=True):
- Q_tot * (jnp.sum(q * pos[:, 2] ** 2))
- jnp.power(Q_tot, 2) * jnp.power(Lz, 2) / 12
)
if neutral_flag:
if not neutral_flag:
# kappa = pme_potential.pme_force.kappa
pre_corr_non = -jnp.pi / (2 * V * kappa**2) * DIELECTRIC
e_corr_non = pre_corr_non * Q_tot**2
Expand Down Expand Up @@ -211,20 +217,47 @@ def __init__(
slab_flag: bool = False,
constQ: bool = True,
pbc_flag: bool = True,
part_const:bool = True,
has_aux=False,
):
self.has_aux = has_aux
self.part_const = part_const
const_vals = np.array(const_vals)
if neutral_flag:
const_vals = const_vals - np.sum(const_vals) / len(const_vals)
#if neutral_flag:
# const_vals = const_vals - np.sum(const_vals) / len(const_vals)
self.const_vals = jnp.array(const_vals)
assert len(const_list) == len(
const_vals
), "const_list and const_vals must have the same length"
n_atoms = len(init_q)
self.const_list = padding_consts(const_list, n_atoms)

const_mat = np.zeros((len(const_list), n_atoms))
for ncl, cl in enumerate(const_list):
const_mat[ncl][cl] = 1
self.const_mat = jnp.array(const_mat)

if len(const_list) != 0:
self.const_list = padding_consts(const_list, n_atoms)
#if fix part charges
self.all_const_list = self.const_list[jnp.where(self.const_list < n_atoms)]
else:
self.const_list = np.array(const_list)
self.all_const_list = self.const_list

all_fix_list = jnp.setdiff1d(jnp.array(range(n_atoms)),self.all_const_list)
fix_mat = np.zeros((len(all_fix_list),n_atoms))
for i, j in enumerate(all_fix_list):
fix_mat[i][j] = 1
self.all_fix_list = jnp.array(all_fix_list)
self.fix_mat = jnp.array(fix_mat)

self.init_q = jnp.array(init_q)
self.init_lagmt = jnp.ones((len(const_list),))

self.init_energy = True #init charge by hession inversion method
self.icount = 0
self.hessinv_stride = 1
self.qupdate_stride = 1

self.damp_mod = damp_mod
self.neutral_flag = neutral_flag
Expand All @@ -234,6 +267,8 @@ def __init__(

if constQ:
e_constraint = E_constQ
elif not part_const:
e_constraint = E_noconst
else:
e_constraint = E_constP
self.e_constraint = e_constraint
Expand Down Expand Up @@ -318,11 +353,15 @@ def coul_energy(positions, box, pairs, q, mscales):

self.coul_energy = coul_energy


def generate_get_energy(self):
@jit_condition()
def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, ds * 10, buffer_scales)
def E_full(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales):
if self.part_const:
e1 = self.e_constraint(q, lagmt, self.const_list, self.const_vals)
else:
e1 = 0
e2 = self.e_sr(pos * 10, box * 10, pairs, q, eta, buffer_scales, self.pbc_flag)
e3 = self.e_site(chi, J, q)
e4 = self.coul_energy(pos, box, pairs, q, mscales)
if self.slab_flag:
Expand All @@ -336,70 +375,195 @@ def E_full(q, lagmt, chi, J, pos, box, pairs, eta, ds, buffer_scales, mscales):
grad_E_full = grad(E_full, argnums=(0, 1))

@jit_condition()
def E_grads(
b_value, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
):
n_const = len(self.const_vals)
q = b_value[:-n_const]
lagmt = b_value[-n_const:]

g1, g2 = grad_E_full(
q, lagmt, chi, J, positions, box, pairs, eta, ds, buffer_scales, mscales
)
g = jnp.concatenate((g1, g2))
return g
def E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales):
h = jacfwd(jacrev(E_full, argnums=(0)))(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales)
return h

def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
@jit_condition()
def get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
pos = positions
ds = ds_pairs(pos, box, pairs, self.pbc_flag)
buffer_scales = pair_buffer_scales(pairs)

n_const = len(self.init_lagmt)
b_vector = jnp.concatenate((-chi, self.const_vals)) #For E_site3

if self.has_aux:
b_value = jnp.concatenate((aux["q"], aux["lagmt"]))
q = aux["q"][:len(pos)]
lagmt = aux["q"][len(pos):]
else:
b_value = jnp.concatenate([self.init_q, self.init_lagmt])
# if JAXOPT_OLD:
if True:
rf = jaxopt.ScipyRootFinding(
optimality_fun=E_grads, method="hybr", jit=False, tol=1e-10
)
q = self.init_q
lagmt = self.init_lagmt
B = E_hession(q, lagmt, chi, J, pos, box, pairs, eta, buffer_scales, mscales)

if self.part_const:
C = jnp.eye(len(q))
A = C.at[self.all_const_list].set(B[self.all_const_list])
else:
rf = jaxopt.Broyden(fun=E_grads, tol=1e-10)
b_0, _ = rf.run(
b_value,
A = self.fix_mat

b_vector = b_vector.at[self.all_fix_list].set(q[self.all_fix_list])


m0 = jnp.concatenate((A,self.const_mat),axis=0)
n0 = jnp.concatenate((jnp.transpose(self.const_mat),jnp.zeros((n_const,n_const))),axis=0)

M = jnp.concatenate((m0,n0),axis=1)

q_0 = jnp.linalg.solve(M,b_vector)
q = q_0[:len(pos)]
lagmt = q_0[len(pos):]
energy = E_full(
q,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
b_0 = jax.lax.stop_gradient(b_0)
q_0 = b_0[:-n_const]
lagmt_0 = b_0[-n_const:]
self.init_energy = False
# self.icount = self.icount + 1
if self.has_aux:
aux["q"] = q_0
aux["A"] = A
aux["m0"] = m0
aux["n0"] = n0
aux["b_vector"] = b_vector
aux["init_energy"] = self.init_energy
aux["icount"] = self.icount
return energy, aux
else:
return energy


# @jit_condition()
def get_proj_grad(func, constraint_matrix, has_aux=False):
def value_and_proj_grad(*arg, **kwargs):
value, grad = value_and_grad(func, has_aux=has_aux)(*arg, **kwargs)
a = jnp.matmul(constraint_matrix, grad.reshape(-1, 1))
b = jnp.sum(constraint_matrix * constraint_matrix, axis=1, keepdims=True)
delta_grad = jnp.matmul((a / b).T, constraint_matrix)
proj_grad = grad - delta_grad.reshape(-1)
return value, proj_grad
return value_and_proj_grad

@jit_condition()
def get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
if self.init_energy:
if self.has_aux:
energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy, aux
else:
energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy
if not self.icount % self.hessinv_stride :
if self.has_aux:
energy,aux = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy, aux
else:
energy = get_init_energy(positions, box, pairs, mscales, eta, chi, J, aux)
return energy

func = get_proj_grad(E_full,self.const_mat)
solver = jaxopt.LBFGS(
fun=func,
value_and_grad=True,
tol=1e-2,
)
pos = positions
buffer_scales = pair_buffer_scales(pairs)
if self.has_aux:
q = aux["q"][:len(pos)]
lagmt = aux["q"][len(pos):]
else:
q = self.init_q
lagmt = self.init_lagmt

res = solver.run(
q,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
buffer_scales,
mscales,
)
q_opt = res.params
energy = E_full(
q_0,
lagmt_0,
q_opt,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
ds,
buffer_scales,
mscales,
)
if self.has_aux:
aux["q"] = q_0
aux["lagmt"] = lagmt_0
aux["q"] = aux['q'].at[:len(pos)].set(q_opt)
return energy, aux
else:
return energy
# @jit_condition()
def get_energy(positions, box, pairs, mscales, eta, chi, J, aux=None):
if self.has_aux :
if "const_vals" in aux.keys():
self.const_vals = aux["const_vals"]
if "hessinv_stride" in aux.keys():
self.hessinv_stride = aux["hessinv_stride"]
if "qupdate_stride" in aux.keys():
self.qupdate_stride = aux["qupdate_stride"]
if not self.icount % self.qupdate_stride :
if self.has_aux:
# aux["q"] = aux['q'].at[:len(pos)].set(q)
energy, aux = get_step_energy(positions, box, pairs, mscales, eta, chi, J, aux)
self.icount = self.icount + 1
aux["icount"] = self.icount
return energy, aux
else:
self.icount = self.icount + 1
energy = get_step_energy(positions, box, pairs, mscales, eta, chi, J )
return energy

else:
self.icount = self.icount + 1
# print(self.icount)
pos = positions
buffer_scales = pair_buffer_scales(pairs)
if self.has_aux:
q = aux["q"][:len(pos)]
lagmt = aux["q"][len(pos):]
else:
q = self.init_q
lagmt = self.init_lagmt
energy = E_full(
q,
lagmt,
chi,
J,
positions,
box,
pairs,
eta,
buffer_scales,
mscales,
)

if self.has_aux:
aux = aux
# aux["q"] = aux['q'].at[:len(pos)].set(q)
aux["icount"] = self.icount
return energy, aux
else:
return energy

return get_energy

Loading
Loading