diff --git a/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h b/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h index 4cd428fc8..dec6c22d6 100644 --- a/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h +++ b/backend/openmm_dmff_plugin/openmmapi/include/DMFFForce.h @@ -77,7 +77,17 @@ class OPENMM_EXPORT_DMFF DMFFForce : public OpenMM::Force { * @param hasAux : true if model was saved with auxilary input. */ void setHasAux(const bool hasAux); - + /** + * @brief Set the Cutoff for neighbor list fetching. + * + * @param cutoff + */ + void setCutoff(const double cutoff); + /** + * @brief get the DMFF graph file. + * + * @return const std::string& + */ const std::string& getDMFFGraphFile() const; /** * @brief Get the Coord Unit Coefficient. diff --git a/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp b/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp index 5a48abf00..9574d7050 100644 --- a/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp +++ b/backend/openmm_dmff_plugin/openmmapi/src/DMFFForce.cpp @@ -58,6 +58,10 @@ void DMFFForce::setHasAux(const bool hasAux){ this->has_aux = hasAux; } +void DMFFForce::setCutoff(const double cutoff){ + this->cutoff = cutoff; +} + double DMFFForce::getCoordUnitCoefficient() const {return coordCoeff;} double DMFFForce::getForceUnitCoefficient() const {return forceCoeff;} double DMFFForce::getEnergyUnitCoefficient() const {return energyCoeff;} diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i index 6bd7bccb5..ebe300ea6 100644 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i +++ b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin.i @@ -48,7 +48,7 @@ public: void setUnitTransformCoefficients(const double coordCoefficient, const double forceCoefficient, const double energyCoefficient); void setHasAux(const bool hasAux); - + void setCutoff(const double cutoff); /* * Add methods for casting a Force to a DMFFForce. */ diff --git a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py index 7726b3210..db846b305 100644 --- a/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py +++ b/backend/openmm_dmff_plugin/python/OpenMMDMFFPlugin/tools.py @@ -78,6 +78,16 @@ def setHasAux(self, has_aux = False): has_aux (bool, optional): Defaults to False. """ self.dmff_force.setHasAux(has_aux) + return + + def setCutoff(self, cutoff = 1.2): + """Set the cutoff for the DMFF model. + + Args: + cutoff (float, optional): Defaults to 1.2. + """ + self.dmff_force.setCutoff(cutoff) + return def createSystem(self, topology): """Create the OpenMM System object for the DMFF model. diff --git a/dmff/eann/eann.py b/dmff/eann/eann.py index bd7db7487..dc1edd5ca 100644 --- a/dmff/eann/eann.py +++ b/dmff/eann/eann.py @@ -247,7 +247,7 @@ def get_features(self, radial, dr, pairs, buffer_scales, orb_coeff): f_cut = cutoff_cosine(dr_norm, self.rc) neigh_list = jnp.concatenate((pairs,pairs[:,[1,0]]),axis=0) buffer_scales_ = jnp.concatenate((buffer_scales,buffer_scales),axis=0) - totneighbour = len(neigh_list) + totneighbour = neigh_list.shape[0] prefacs = f_cut.reshape(1, -1) angular = prefacs for ipsin in range(1,self.nipsin+1): @@ -310,7 +310,8 @@ def get_energy(positions, box, pairs, params): self.rs = params['density.rs'] self.inta = params['density.inta'] - radial_i, radial_j = get_gto(jnp.arange(len(dr_norm)), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices) + length_dr_norm = dr_norm.shape[0] + radial_i, radial_j = get_gto(jnp.arange(length_dr_norm), dr_norm, pairs, self.rc, self.rs, self.inta, self.elem_indices) radial = jnp.concatenate((radial_i,radial_j), axis=0) orb_coeff = params['density.params'][self.elem_indices,:] # (48,16) diff --git a/dmff/generators/ml.py b/dmff/generators/ml.py index ae747bb88..4bf5595ad 100644 --- a/dmff/generators/ml.py +++ b/dmff/generators/ml.py @@ -116,12 +116,19 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof n_elem, elem_indices = get_elem_indices(self.ommtopology) self.model = EANNForce(n_elem, elem_indices, n_gto=self.ngto, nipsin=self.nipsin, rc=self.rc) n_layers = self.model.n_layers - def potential_fn(positions, box, pairs, params): + + has_aux = False + if "has_aux" in kwargs and kwargs["has_aux"]: + has_aux = True + + def potential_fn(positions, box, pairs, params, aux=None): # convert unit to angstrom positions = positions * 10 box = box * 10 - - return self.model.get_energy(positions, box, pairs, params[self.name]) + if has_aux: + return self.model.get_energy(positions, box, pairs, params[self.name]), aux + else: + return self.model.get_energy(positions, box, pairs, params[self.name]) self._jaxPotential = potential_fn return potential_fn