Skip to content

Commit

Permalink
Fix examples (#130)
Browse files Browse the repository at this point in the history
* Add sGNN generator
fixed a few problems in ADMPPmeGenerator

* remove debugging codes

* Fix the two examples:
* fluctuated_leading_term_waterff
* peg_slater_isa
Make map_atomtype and map_poltype available in generator

* Fix the jupyter notebook

* Update fluctuated_leading_term_waterff and water_fullpol example

* Fix the ADMP examples, also improve the way map_atomtypes and
map_poltypes are accessed.

* Break a long line in code

* Fix the ADMP aux test
  • Loading branch information
KuangYu authored Nov 3, 2023
1 parent 2171e4b commit 9302690
Show file tree
Hide file tree
Showing 24 changed files with 3,284 additions and 3,168 deletions.
4 changes: 3 additions & 1 deletion dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def optimize_Uind(
flag = False
else: # converged
flag = True
n_cycles = i
else:

def update_U(i, U):
Expand All @@ -418,7 +419,8 @@ def update_U(i, U):

U = jax.lax.fori_loop(0, steps_pol, update_U, U)
flag = True
return U, flag, steps_pol
n_cycles = steps_pol
return U, flag, n_cycles


def setup_ewald_parameters(
Expand Down
2 changes: 1 addition & 1 deletion dmff/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .classical import *
from .admp import *
from .ml import *
from .qeq import *
from .qeq import *
35 changes: 27 additions & 8 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def createPotential(
for i in range(n_atoms):
atype = atoms[i].meta[self.key_type]
map_atomtype[i] = self._find_atype_key_index(atype)
topdata._meta[self.name+"_map_atomtype"] = map_atomtype
# here box is only used to setup ewald parameters, no need to be differentiable
if lpme:
box = topdata.getPeriodicBoxVectors() * 10
Expand Down Expand Up @@ -339,6 +340,8 @@ def createPotential(
atype = atoms[i].meta[self.key_type]
map_atomtype[i] = self._find_atype_key_index(atype)

topdata._meta[self.name+"_map_atomtype"] = map_atomtype

# here box is only used to setup ewald parameters, no need to be differentiable
if lpme:
box = topdata.getPeriodicBoxVectors() * 10
Expand Down Expand Up @@ -491,6 +494,7 @@ def createPotential(
atype = atoms[i].meta[self.key_type]
map_atomtype[i] = np.where(self.atom_keys == atype)[0][0]

topdata._meta[self.name+"_map_atomtype"] = map_atomtype
pot_fn_sr = generate_pairwise_interaction(TT_damping_qq_kernel, static_args={})

has_aux = False
Expand Down Expand Up @@ -630,6 +634,8 @@ def createPotential(
atype = atoms[i].meta[self.key_type]
map_atomtype[i] = np.where(self.atom_keys == atype)[0][0]

topdata._meta[self.name+"_map_atomtype"] = map_atomtype

# WORKING
pot_fn_sr = generate_pairwise_interaction(
slater_disp_damping_kernel, static_args={}
Expand Down Expand Up @@ -765,6 +771,8 @@ def createPotential(
atype = atoms[i].meta[self.key_type]
map_atomtype[i] = np.where(self.atom_keys == atype)[0][0]

topdata._meta[self.name+"_map_atomtype"] = map_atomtype

pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={})

has_aux = False
Expand Down Expand Up @@ -818,7 +826,10 @@ def createPotential(
atype = atoms[i].meta[self.key_type]
map_atomtype[i] = np.where(self.atom_keys == atype)[0][0]

pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel, static_args={})
topdata._meta[self.name+"_map_atomtype"] = map_atomtype

pot_fn_sr = generate_pairwise_interaction(slater_sr_kernel,
static_args={})

has_aux = False
if "has_aux" in kwargs and kwargs["has_aux"]:
Expand Down Expand Up @@ -1193,6 +1204,7 @@ def createPotential(
if self.lpol:
map_poltype[i] = self._find_polarize_key_index(atype)


# here box is only used to setup ewald parameters, no need to be differentiable
box = topdata.getPeriodicBoxVectors()
if box is not None:
Expand Down Expand Up @@ -1362,8 +1374,8 @@ def createPotential(
has_aux
)
self.pme_force = pme_force
topdata._meta["admp_map_atomtype"] = map_atomtype
topdata._meta["admp_map_poltype"] = map_poltype
topdata._meta[self.name+"_map_atomtype"] = map_atomtype
topdata._meta[self.name+"_map_poltype"] = map_poltype

if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True
Expand All @@ -1379,11 +1391,18 @@ def potential_fn(positions, box, pairs, params, aux=None):
tholes = params["ADMPPmeForce"]["thole"][map_poltype]

if has_aux:
energy, aux = pme_force.get_energy(
positions, box, pairs, Q_local, pol, tholes,
self.mScales, self.pScales, self.dScales,
U_init = aux["U_ind"], aux = aux
)
if aux is not None:
energy, aux = pme_force.get_energy(
positions, box, pairs, Q_local, pol, tholes,
self.mScales, self.pScales, self.dScales,
U_init = aux["U_ind"], aux = aux
)
else:
energy, aux = pme_force.get_energy(
positions, box, pairs, Q_local, pol, tholes,
self.mScales, self.pScales, self.dScales,
U_init=jnp.zeros((n_atoms,3)), aux={}
)
return energy, aux
else:
energy = pme_force.get_energy(
Expand Down
Loading

0 comments on commit 9302690

Please sign in to comment.