diff --git a/README.md b/README.md index bbcdda88..febc4739 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Endstate correction from MM to QML potential +Endstate correction from MM to NNP ============================== [//]: # (Badges) [![GitHub Actions Build Status](https://github.com/wiederm/endstate_correction/workflows/CI/badge.svg)](https://github.com/wiederm/endstate_correction/actions?query=workflow%3ACI) diff --git a/data/ZINC00077329/switching_charmmff/ZINC00077329_neq_ws_from_mm_to_qml_200_5001.pickle b/data/ZINC00077329/switching_charmmff/ZINC00077329_neq_ws_from_mm_to_nnp_200_5001.pickle similarity index 100% rename from data/ZINC00077329/switching_charmmff/ZINC00077329_neq_ws_from_mm_to_qml_200_5001.pickle rename to data/ZINC00077329/switching_charmmff/ZINC00077329_neq_ws_from_mm_to_nnp_200_5001.pickle diff --git a/docs/conf.py b/docs/conf.py index deb1faf0..93b21194 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -186,7 +186,7 @@ "endstate_correction Documentation", author, "endstate_correction", - "Endstate reweighting from MM to QML potential", + "Endstate reweighting from MM to NNP", "Miscellaneous", ), ] diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 750457e6..c9929e73 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -12,9 +12,9 @@ This package can be installed using: How to use this package ----------------- We have prepared two scripts that should help to use this package, both are located in :code:`endstate_correction/scripts`. -We will start by describing the use of the :code:`sampling.py` script and then discuss the :code:`perform_correction.py` script. +We will start by describing the use of the :code:`generate_endstate_samples.py` script and then discuss the :code:`perform_correction.py` script. 
-A typical NEQ workflow +A typical Non-equilibrium (NEQ) switching workflow ----------------- Generate the equilibrium distribution :math:`\pi(x, \lambda=0)` ~~~~~~~~~~~~~~~~~~~~~~ @@ -57,9 +57,10 @@ Note that we explicitly define the atoms that should be perturbed from the refer If you want to perform bidirectional FEP or NEQ you need to sample at :math:`\pi(x, \lambda=0)` *and* :math:`\pi(x, \lambda=1)`. This can be controlled by setting the number using the variable :code:`nr_lambda_states`. -The default value set in the :code:`sampling.py` script is :code:`nr_lambda_states=2`, generating samples from both endstate equilibrium distributions. +The default value set in the :code:`generate_endstate_samples.py` script is :code:`nr_lambda_states=2`, generating samples from both endstate equilibrium distributions. .. code:: python + nr_lambda_states = 2 # samples equilibrium distribution at both endstates lambs = np.linspace(0, 1, nr_lambda_states) @@ -81,13 +82,18 @@ with: reference_samples=mm_samples, nr_of_switches=100, neq_switching_length=1_000, + save_endstates=False, + save_trajs=False ) This protocol is then passed to the actual function performing the protocol: :code:`perform_endstate_correction(neq_protocol)`. The particular code above defines unidirectional NEQ switching using 100 switches and a switching length of 1 ps. The direciton of the switching simulation is controlled by the sampels that are provided: -if `reference_samples` are provided, switching is performed from the reference to the target level of theory, if `target_samples` are provided, switching is performed from the target level to the reference level. +if :code:`reference_samples` are provided, switching is performed from the reference to the target level of theory, if :code:`target_samples` are provided, switching is performed from the target level to the reference level. If both samples are provided, bidirectional NEQ switching is performed (for an example see below). 
+There is the possibility to save the endstate of each switch if :code:`save_endstates` is set to :code:`True`. +Additionally, the switching trajectory of each switch can be saved if :code:`save_trajs=True`. +Both options only make sense for a NEQ protocol. Perform bidirectional NEQ from :math:`\pi(x, \lambda=0)` and :math:`\pi(x, \lambda=1)` ~~~~~~~~~~~~~~~~~~~~~~ @@ -99,7 +105,7 @@ The endstate correction can be performed using the script :code:`perform_correct method="NEQ", sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=100, neq_switching_length=1_000, ) @@ -112,6 +118,7 @@ Perform unidirectional FEP from :math:`\pi(x, \lambda=0)` The protocol has to be adopted slightly: .. code:: python + fep_protocol = Protocol( method="FEP", sim=sim, @@ -130,27 +137,52 @@ To analyse the results generated by :code:`r = perform_endstate_correction()` pa Available protocols ----------------- +unidirectional NEQ protocol (reference to target) + .. code:: python neq_protocol = Protocol( method="NEQ", sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, nr_of_switches=100, neq_switching_length=1_000, ) +bidirectional NEQ protocol (save switching trajectory and endstate of each switch) + .. code:: python neq_protocol = Protocol( - method="FEP", + method="NEQ", sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=100, neq_switching_length=1_000, save_endstates=True, save_trajs=True, ) +unidirectional FEP protocol (reference to target) + +.. code:: python + + fep_protocol = Protocol( + method="FEP", + sim=sim, + reference_samples=mm_samples, + nr_of_switches=1_000, + ) + +bidirectional FEP protocol + +.. 
code:: python + + fep_protocol = Protocol( + method="FEP", + sim=sim, + reference_samples=mm_samples, + target_samples=nnp_samples, + nr_of_switches=1_000, + ) diff --git a/docs/theory.rst b/docs/theory.rst index 405439be..af9bd61a 100644 --- a/docs/theory.rst +++ b/docs/theory.rst @@ -2,12 +2,12 @@ Theory =============== -Equilibrium free energy endstate corrections +Multistage equilibrium free energy calculations ----------------- -In equilibrium free energy calculations samples are drawn from the Boltzmann distrubtion +In multistage equilibrium free energy calculations samples are drawn from the Boltzmann distribution at specific interpolation states between thermodynamic states (in our specific case: different energetic -descriptions of the molecular system, i.e. the source level of theory and the target level of theroy) and, +descriptions of the molecular system, i.e. the reference level of theory and the target level of theory) and, given sufficient overlap of its pdfs, a free energy can be estimated. This protocol is expensive (it needs iid samples at each lambda state connecting the Boltzmann distribution at the endstates) but also reliable and accureate (with low variance). @@ -16,7 +16,7 @@ but also reliable and accureate (with low variance). -Non-equilibrium work protocol +Non-equilibrium (NEQ) work protocol ----------------- Non-equilibrium work protocols, and the fluctuation theorems connecting non-equilibrium driven @@ -26,11 +26,11 @@ A specific NEQ protocol typically consists of a series of perturbation kernel : propagation kernel :math:`\kappa_t(x,y)`, which are used in an alternating pattern to drive the system out of equilibrium. 
Each perturbation kernel :math:`\alpha` drives an alchemical coupling parameter :math:`\lambda`, and each -propagation kernel :math:`\kappa`$` propagates the coordinates of the system at fixed :math:`\lambda`$` according +propagation kernel :math:`\kappa` propagates the coordinates of the system at fixed :math:`\lambda` according to a defined MD process. The free energy difference can then be recovered using either the Jarzynski equation (if initial conformations -to seed the NEQ protocol are only drawn from :math:`\pi(x, \lambda=0)` and the NEQ protocol perturbations only -from :math:`\lambda=0` to :math:`\lambda=1`) or the Crooks' fluctuation theorem (if samples to seed the NEQ protocol +to seed the NEQ protocol are only drawn from :math:`\pi(x, \lambda=0)` (or :math:`\pi(x, \lambda=1)`) and the NEQ protocol perturbs only +from :math:`\lambda=0` to :math:`\lambda=1` (or :math:`\lambda=1` to :math:`\lambda=0`)) or the Crooks' fluctuation theorem (if samples to seed the NEQ protocol are drawn from :math:`\pi(x, \lambda=0)` and :math:`\pi(x, \lambda=1)` and the perturbation kernels are set for a bidirectinoal protocol). @@ -43,4 +43,4 @@ consists only of a single perturbation kernel :math:`\alpha_t(x,y)`, without a p to another 'endstate'. In the limiting cases of infinitely fast switching the Jarzynski equality reduces to the well-known FEP equation: :math:`e^{-\beta \Delta F} = \langle e^{−β[E(x,\lambda=1)− E(x,\lambda=0)]} \rangle_{\lambda=0}`. -:math:`\langle \rangle_{\lambda=0}` indicate that samples are drawn from the equilibrium distribution :math:`\pi(x, \lambda=0)`. +:math:`\langle \rangle_{\lambda=0}` indicates that samples are drawn from the equilibrium distribution :math:`\pi(x, \lambda=0)`. 
diff --git a/endstate_correction/__init__.py b/endstate_correction/__init__.py index e8ef8e08..54475327 100644 --- a/endstate_correction/__init__.py +++ b/endstate_correction/__init__.py @@ -1,4 +1,4 @@ -"""Endstate reweighting from MM to QML potential""" +"""Endstate reweighting from MM to NNP""" # Add imports here from .endstate_correction import * diff --git a/endstate_correction/analysis.py b/endstate_correction/analysis.py index bcc76715..51df9908 100644 --- a/endstate_correction/analysis.py +++ b/endstate_correction/analysis.py @@ -20,10 +20,10 @@ def plot_overlap_for_equilibrium_free_energy( N_k: np.array, u_kn: np.ndarray, name: str ): """ - Calculate the overlap for each state with each other state. THe overlap is normalized to be 1 for each row. + Calculate the overlap for each state with each other state. The overlap is normalized to be 1 for each row. Args: - N_k (np.array): numnber of samples for each state k + N_k (np.array): number of samples for each state k u_kn (np.ndarray): each of the potential energy functions `u` describing a state `k` are applied to each sample `n` from each of the states `k` name (str): name of the system in the plot """ @@ -43,7 +43,7 @@ def plot_overlap_for_equilibrium_free_energy( annot_kws={"size": "small"}, ) plt.title(f"Free energy estimate for {name}", fontsize=15) - plt.savefig(f"{name}_equilibrium_free_energy.png") + plt.savefig(f"{name}_overlap_equilibrium_free_energy.png") plt.show() plt.close() @@ -54,9 +54,8 @@ def plot_results_for_equilibrium_free_energy( """ Calculate the accumulated free energy along the mutation progress. 
- Args: - N_k (np.array): numnber of samples for each state k + N_k (np.array): number of samples for each state k u_kn (np.ndarray): each of the potential energy functions `u` describing a state `k` are applied to each sample `n` from each of the states `k` name (str): name of the system in the plot """ @@ -82,12 +81,12 @@ def plot_results_for_equilibrium_free_energy( plt.title(f"Free energy estimate for {name}", fontsize=15) plt.ylabel("Free energy estimate in kT", fontsize=15) plt.xlabel("lambda state (0 to 1)", fontsize=15) - plt.savefig(f"{name}_equilibrium_free_energy.png") + plt.savefig(f"{name}_results_equilibrium_free_energy.png") plt.show() plt.close() -def return_endstate_correction(results: AllResults, method:str, direction: str) -> Tuple[float, float]: - """eturn the endstate correction for a given method and direction. +def return_endstate_correction(results: AllResults, method:str = "NEQ", direction: str = "forw") -> Tuple[float, float]: + """Return the endstate correction for a given method and direction. Args: results (AllResults): instance of the AllRestults class @@ -97,6 +96,7 @@ def return_endstate_correction(results: AllResults, method:str, direction: str) Raises: ValueError: if method or direction is not supported + Returns: Tuple[float, float]: endstate correction delta_f, endstate correction error """ @@ -196,10 +196,16 @@ def summarize_endstate_correction_results(results: AllResults): print(f"Equilibrium free energy: {ddG_equ}+/-{dddG_equ}") #if results.smc_results: - def plot_endstate_correction_results( name: str, results: AllResults, filename: str = "plot.png" ): + """Plot endstate correction results. + + Args: + name (str): name of the system in the plot + results (AllResults): instance of the AllResults class + filename (str, optional): file the figure is saved to. Defaults to "plot.png". 
+ """ assert type(results) == AllResults ########################################################### @@ -285,13 +291,14 @@ def plot_endstate_correction_results( label=r"$\Delta$E(QML$\rightarrow$MM)", color=c4, ) + axs[ax_index].legend() ########################################################### # ------------------- Plot results ------------------------ if multiple_results > 1: ax_index += 1 - axs[ax_index].set_title(rf"{name} - offset $\Delta$G(MM$\rightarrow$QML)") + axs[ax_index].set_title(rf"{name} - offset $\Delta$G(MM$\rightarrow$NNP)") ddG_list, dddG_list, names = [], [], [] if results.equ_results: @@ -422,6 +429,7 @@ def plot_endstate_correction_results( label=r"stddev $\Delta$E(QML$\rightarrow$MM)", color=c4, ) + # plot 1 kT limit axs[ax_index].axhline(y=1.0, color=c7, linestyle=":") axs[ax_index].axhline(y=2.0, color=c5, linestyle=":") @@ -432,320 +440,4 @@ def plot_endstate_correction_results( plt.tight_layout() plt.savefig(filename) - plt.show() - - -# plotting torsion profiles -########################################################################################################################################################################### - - -# generate molecule picture with atom indices -def save_mol_pic(zinc_id: str, ff: str): - from rdkit import Chem - from rdkit.Chem import AllChem, Draw - from rdkit.Chem.Draw import IPythonConsole - - IPythonConsole.drawOptions.addAtomIndices = True - from rdkit.Chem.Draw import rdMolDraw2D - - # get name - name, _ = zinc_systems[zinc_id] - # generate openff Molecule - mol = generate_molecule(name=name, forcefield=ff) - # convert openff object to rdkit mol object - mol_rd = mol.to_rdkit() - - # remove explicit H atoms - if zinc_id == 4: - # NOTE: FIXME: this is a temporary workaround to fix the wrong indexing in rdkit - # when using the RemoveHs() function - mol_draw = Chem.RWMol(mol_rd) - # remove all explicit H atoms, except the ones on the ring and on N atoms (for correct indexing) - for run in 
range(1, 7): - n_atoms = mol_draw.GetNumAtoms() - mol_draw.RemoveAtom(n_atoms - 7) - else: - # remove explicit H atoms - mol_draw = Chem.RemoveHs(mol_rd) - - # get 2D representation - AllChem.Compute2DCoords(mol_draw) - # formatting - d = rdMolDraw2D.MolDraw2DCairo(1500, 1000) - d.drawOptions().fixedFontSize = 90 - d.drawOptions().fixedBondLength = 110 - d.drawOptions().annotationFontScale = 0.7 - d.drawOptions().addAtomIndices = True - - d.DrawMolecule(mol_draw) - d.FinishDrawing() - if not os.path.isdir(f"mol_pics_{ff}"): - os.makedirs(f"mol_pics_{ff}") - d.WriteDrawingText(f"mol_pics_{ff}/{name}_{ff}.png") - - -# get indices of dihedral bonds -def get_indices(rot_bond: int, rot_bond_list: list, bonds: list): - print(f"---------- Investigating bond nr {rot_bond} ----------") - - # get indices of both atoms forming an rotatable bond - atom_1_idx = (rot_bond_list[rot_bond]).atom1_index - atom_2_idx = (rot_bond_list[rot_bond]).atom2_index - - # create lists to collect neighbors of atom_1 and atom_2 - neighbors1 = [] - neighbors2 = [] - - # find neighbors of atoms forming the rotatable bond and add to index list (if heavy atom torsion) - for bond in bonds: - # get neighbors of atom_1 (of rotatable bond) - # check, if atom_1 (of rotatable bond) is the first atom in the current bond - if bond.atom1_index == atom_1_idx: - # make sure, that neighboring atom is not an hydrogen, nor atom_2 - if ( - not bond.atom2.element.name == "hydrogen" - and not bond.atom2_index == atom_2_idx - ): - neighbors1.append(bond.atom2_index) - - # check, if atom_1 (of rotatable bond) is the second atom in the current bond - elif bond.atom2_index == atom_1_idx: - # make sure, that neighboring atom is not an hydrogen, nor atom_2 - if ( - not bond.atom1.element.name == "hydrogen" - and not bond.atom1_index == atom_2_idx - ): - neighbors1.append(bond.atom1_index) - - # get neighbors of atom_2 (of rotatable bond) - # check, if atom_2 (of rotatable bond) is the first atom in the current bond - if 
bond.atom1_index == atom_2_idx: - # make sure, that neighboring atom is not an hydrogen, nor atom_1 - if ( - not bond.atom2.element.name == "hydrogen" - and not bond.atom2_index == atom_1_idx - ): - neighbors2.append(bond.atom2_index) - - # check, if atom_2 (of rotatable bond) is the second atom in the current bond - elif bond.atom2_index == atom_2_idx: - # make sure, that neighboring atom is not an hydrogen, nor atom_1 - if ( - not bond.atom1.element.name == "hydrogen" - and not bond.atom1_index == atom_1_idx - ): - neighbors2.append(bond.atom1_index) - - # check, if both atoms forming the rotatable bond have neighbors - if len(neighbors1) > 0 and len(neighbors2) > 0: - # list for final atom indices defining torsion - indices = [[neighbors1[0], atom_1_idx, atom_2_idx, neighbors2[0]]] - return indices - - else: - print(f"No heavy atom torsions found for bond {rot_bond}") - indices = [] - return indices - - -# plot torsion profiles -def visualize_torsion_profile(mol, trajectories: dataclass): - # get all bonds - bonds = mol.bonds - # get all rotatable bonds - rot_bond_list = mol.find_rotatable_bonds() - print(len(rot_bond_list), "rotatable bonds found.") - - ################################################## GET HEAVY ATOM TORSIONS ########################################################################################## - - # list for collecting bond nr, which form a dihedral angle - torsions = [] - # list for collecting all atom indices, which form a dihedral angle - all_indices = [] - # lists for traj data - torsions_mm = [] - torsions_qml = [] - # lists for traj data after switching - torsions_mm_switching = [] - torsions_qml_switching = [] - # boolean which enables plotting, if data can be retrieved - plotting = False - - for rot_bond in range(len(rot_bond_list)): - # get atom indices of current rotatable bond forming a torsion - indices = get_indices( - rot_bond=rot_bond, rot_bond_list=rot_bond_list, bonds=bonds - ) - print(indices) - - # compute dihedrals 
only if heavy atom torsion was found for rotatable bond - if len(indices) > 0: - print(f"Dihedrals are computed for bond nr {rot_bond}") - # add bond nr to list - torsions.append(rot_bond) - # add corresponding atom indices to list - all_indices.extend(indices) - - # check if traj data can be retrieved - traj_mm = trajectories.equilibrium_mm_trajectory - traj_qml = trajectories.equilibrium_qml_trajectory - - # if also 'post-switching' data has to be plotted, check if it can be retrieved - - # if both, mm and qml samples are found, compute dihedrals - if traj_mm and traj_qml: - torsions_mm.append( - md.compute_dihedrals(traj_mm, indices, periodic=True, opt=True) - ) # * 180.0 / np.pi - torsions_qml.append( - md.compute_dihedrals(traj_qml, indices, periodic=True, opt=True) - ) # * 180.0 / np.pi - plotting = True - - # additionally, compute dihedrals from 'post-switching' data - if switching and traj_mm_switching and traj_qml_switching: - torsions_mm_switching.append( - md.compute_dihedrals( - traj_mm_switching, indices, periodic=True, opt=True - ) - ) # * 180.0 / np.pi - torsions_qml_switching.append( - md.compute_dihedrals( - traj_qml_switching, indices, periodic=True, opt=True - ) - ) # * 180.0 / np.pi - elif switching and not traj_mm_switching and not traj_qml_switching: - plotting = False - - else: - print(f"Trajectory data cannot be found for {name}") - else: - print(f"No dihedrals will be computed for bond nr {rot_bond}") - - ################################################## PLOT TORSION PROFILES ########################################################################################## - - if plotting: - import matplotlib.gridspec as gridspec - - plt.style.use("fivethirtyeight") - sns.set_theme() - sns.set_palette("bright") - - # generate molecule picture - save_mol_pic(zinc_id=zinc_id, ff=ff) - - # create corresponding nr of subplots - fig = plt.figure(tight_layout=True, figsize=(8, len(torsions) * 2 + 6), dpi=400) - gs = gridspec.GridSpec( - len(torsions) + 
1, - 2, - ) - - fig.suptitle(f"Torsion profile of {name} ({ff})", fontsize=15, weight="bold") - - # flip the image, so that it is displayed correctly - image = mpimg.imread(f"mol_pics_{ff}/{name}_{ff}.png") - - # plot the molecule image on the first axis - ax = fig.add_subplot(gs[0, :]) - - ax.imshow(image) - ax.axis("off") - - # iterate over all torsions and plot results - for counter in range(1, len(torsions) + 1): - # counter for atom indices - idx_counter = counter - 1 - # plot only sampling data - if not switching: - data_histplot = { - "mm samples": torsions_mm[idx_counter].squeeze(), - "qml samples": torsions_qml[idx_counter].squeeze(), - } - - # compare to data after switching - else: - data_histplot = { - "mm samples": torsions_mm[idx_counter].squeeze(), - "qml samples": torsions_qml[idx_counter].squeeze(), - rf"qml$\rightarrow$mm endstate ({switching_length}ps switch)": torsions_mm_switching[ - idx_counter - ].squeeze(), - rf"mm$\rightarrow$qml endstate ({switching_length}ps switch)": torsions_qml_switching[ - idx_counter - ].squeeze(), - } - - # if needed, compute wasserstein distance - """ # compute wasserstein distance - w_distance = wasserstein_distance(u_values = list(chain.from_iterable(torsions_mm[idx_counter])), v_values = list(chain.from_iterable(torsions_qml[idx_counter]))) - w_distance_qml_switch_mm = wasserstein_distance(u_values = list(chain.from_iterable(torsions_qml[idx_counter])), v_values = list(chain.from_iterable(torsions_mm_switching[idx_counter]))) - w_distance_mm_switch_qml = wasserstein_distance(u_values = list(chain.from_iterable(torsions_mm[idx_counter])), v_values = list(chain.from_iterable(torsions_qml_switching[idx_counter]))) """ - - ax_violin = fig.add_subplot(gs[counter, 0]) - sns.violinplot( - ax=ax_violin, - data=[ - torsions_mm[idx_counter].squeeze(), - torsions_qml[idx_counter].squeeze(), - torsions_mm_switching[idx_counter].squeeze(), - torsions_qml_switching[idx_counter].squeeze(), - ], - orient="h", - inner="point", - 
split=True, - scale="width", - saturation=0.5, - ) - ax_kde = fig.add_subplot(gs[counter, 1]) - sns.kdeplot( - ax=ax_kde, - data=data_histplot, - common_norm=False, - shade=True, - linewidth=2, - # kde=True, - # alpha=0.5, - # stat="density", - # common_norm=False, - ) - - # adjust axis labelling - unit = np.arange(-np.pi, np.pi + np.pi / 4, step=(1 / 4 * np.pi)) - for ax in [ax_violin, ax_kde]: - # add atom indices as subplot title - ax.set_title(f"Torsion {all_indices[idx_counter]}", fontsize=13) - ax.set(xlim=(-np.pi, np.pi)) - ax.set_xticks( - unit, - ["-π", "-3π/4", "-π/2", "-π/4", "0", "π/4", "π/2", "3π/4", "π"], - ) - ax.yaxis.set_major_formatter(FormatStrFormatter("%.3f")) - ax.set_yticks([]) # remove tick values on y axis - - # if wasserstein distance is computed, it can be added as an annotation box next to the plot - """ text_div = f'Wasserstein distance\n\nmm (sampling) & qml (sampling): {w_distance:.3f}\nmm (sampling) & qml ({switching_length}ps switch): {w_distance_mm_switch_qml:.3f}\nqml (sampling) & mm ({switching_length}ps switch): {w_distance_qml_switch_mm:.3f}' - offsetbox = TextArea(text_div, - textprops=dict(ha='left', size = 13)) - xy = (0,0) - if switching_length == 5: - x_box = 1.56 - elif switching_length == 10 or switching_length == 20: - x_box = 1.575 - ab = AnnotationBbox(offsetbox, xy, - xybox=(x_box, 10), - xycoords='axes points', - boxcoords=("axes fraction", "axes points"), - box_alignment=(1, 0.08)) - #arrowprops=dict(arrowstyle="->")) - axs[counter][0].add_artist(ab) """ - - # axs[-1][0].set_xlabel("Dihedral angle") - plt.tight_layout() - if not os.path.isdir(f"torsion_profiles_{ff}"): - os.makedirs(f"torsion_profiles_{ff}") - plt.savefig(f"torsion_profiles_{ff}/{name}_{ff}_{switching_length}ps.png") - plt.show() - - else: - print(f"No torsion profile can be generated for {name}") + plt.show() \ No newline at end of file diff --git a/endstate_correction/equ.py b/endstate_correction/equ.py index 2e585f0b..a386b68e 100644 --- 
a/endstate_correction/equ.py +++ b/endstate_correction/equ.py @@ -13,14 +13,17 @@ def _collect_equ_samples( trajs: list, every_nth_frame: int = 10 ) -> Tuple[list, np.array]: - """ - Given a list of k trajectories with n samples a dictionary with the number of samples per trajektory - and a list with all samples [n_1, n_2, ...] is generated + """Generate an array with the number of samples per trajectory and + a list with all samples [n_1, n_2, ...] given a list of k trajectories with n samples. + + Args: + trajs (list): list of trajectories + every_nth_frame (int, optional): prune samples by taking only every nth sample. Defaults to 10. Returns: - Tuple(coordinates, N_k) + Tuple[list, np.array]: coordinates, N_k """ - + coordinates = [] N_k = np.zeros(len(trajs)) @@ -48,6 +51,7 @@ def calculate_u_kn( trajs (list): list of trajectories sim (Simulation): simulation object every_nth_frame (int, optional): prune the samples further by taking only every nth sample. Defaults to 2. + Returns: np.ndarray: u_kn matrix """ diff --git a/endstate_correction/neq.py b/endstate_correction/neq.py index 52c4fa7c..1ece8f1b 100644 --- a/endstate_correction/neq.py +++ b/endstate_correction/neq.py @@ -124,6 +124,14 @@ def perform_switching( def _collect_work_values(file: str) -> list: + """Return a list of work values + + Args: + file (str): pickle file containing work values + + Returns: + list: list of work values in kJ/mol + """ ws = pickle.load(open(file, "rb")).value_in_unit(unit.kilojoule_per_mole) number_of_samples = len(ws) print(f"Number of samples used: {number_of_samples}") diff --git a/endstate_correction/system.py b/endstate_correction/system.py index 3a146e4c..3da09f47 100644 --- a/endstate_correction/system.py +++ b/endstate_correction/system.py @@ -9,9 +9,15 @@ def gen_box(psf: CharmmPsfFile, crd: CharmmCrdFile) -> CharmmPsfFile: - """ - Function to create psf file containing information about the box used (only for waterbox or commplex simulations). 
Usful - when information about box size is not available (e.g. when using TF) + """Function to create psf file containing information about the box used (only for waterbox or complex simulations). + Useful when information about box size is not available (e.g. when using TF) + + Args: + psf (CharmmPsfFile): topology instance + crd (CharmmCrdFile): coordinates + + Returns: + CharmmPsfFile: topology instance containing information about the box """ coords = crd.positions diff --git a/endstate_correction/tests/test_analysis.py b/endstate_correction/tests/test_analysis.py index de7a6929..922326ad 100644 --- a/endstate_correction/tests/test_analysis.py +++ b/endstate_correction/tests/test_analysis.py @@ -24,7 +24,7 @@ def test_plotting_equilibrium_free_energy(): from .test_equ import load_equ_samples - """test if we are able to plot overlap and """ + """Test if we are able to plot overlap and equilibrium free energy results""" ######################################################## ######################################################## @@ -66,12 +66,12 @@ def test_plot_results_for_FEP_protocol(): system_name = "ZINC00079729" # start with FEP - sim, mm_samples, qml_samples = setup_ZINC00077329_system() + sim, mm_samples, nnp_samples = setup_ZINC00077329_system() fep_protocol = FEPProtocol( sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=100, ) @@ -85,7 +85,7 @@ def test_plot_results_for_FEP_protocol(): def test_plot_results_for_NEQ_protocol(): - """Perform FEP uni- and bidirectional protocol""" + """Perform NEQ uni- and bidirectional protocol""" import pickle from endstate_correction.analysis import ( @@ -117,7 +117,7 @@ def test_plot_results_for_NEQ_protocol(): def test_plot_results_for_all_protocol(): - """Perform FEP uni- and bidirectional protocol""" + """Perform FEP and NEQ uni- and bidirectional protocol""" import pickle from endstate_correction.analysis import plot_endstate_correction_results @@ 
-127,19 +127,12 @@ def test_plot_results_for_all_protocol(): system_name = "ZINC00079729" # start with NEQ - sim, mm_samples, qml_samples = setup_ZINC00077329_system() + sim, mm_samples, nnp_samples = setup_ZINC00077329_system() #################################################### # ---------------- All corrections ----------------- #################################################### - # fep_protocol = FEPProtocol( - # sim=sim, - # reference_samples=mm_samples, - # target_samples=qml_samples, - # nr_of_switches=100, - # ) - # load data r = pickle.load( open( diff --git a/endstate_correction/tests/test_endstate_correction.py b/endstate_correction/tests/test_endstate_correction.py index 8d8cb5fc..23113c21 100644 --- a/endstate_correction/tests/test_endstate_correction.py +++ b/endstate_correction/tests/test_endstate_correction.py @@ -22,12 +22,12 @@ def test_endstate_correction_imported(): assert "endstate_correction" in sys.modules -def save_pickle_results(sim, mm_samples, qml_samples, system_name): +def save_pickle_results(sim, mm_samples, nnp_samples, system_name): # generate data for plotting tests protocol = NEQProtocol( sim=sim, target_samples=mm_samples, - reference_samples=qml_samples, + reference_samples=nnp_samples, nr_of_switches=100, switching_length=100, ) @@ -60,7 +60,7 @@ def test_FEP_protocol(): """Perform FEP uni- and bidirectional protocol""" # load samples - sim, mm_samples, qml_samples = setup_ZINC00077329_system() + sim, mm_samples, nnp_samples = setup_ZINC00077329_system() #################################################### # ----------------------- FEP ---------------------- @@ -81,7 +81,7 @@ def test_FEP_protocol(): fep_protocol = FEPProtocol( sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=50, ) @@ -102,12 +102,12 @@ def test_NEQ_protocol(): from endstate_correction.protocol import perform_endstate_correction, NEQProtocol # load samples - sim, mm_samples, qml_samples = 
setup_ZINC00077329_system() + sim, mm_samples, nnp_samples = setup_ZINC00077329_system() protocol = NEQProtocol( sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=10, switching_length=50, ) @@ -166,11 +166,12 @@ def test_SMC_protocol(): reason="Skipping tests that take too long in github actions", ) def test_ALL_protocol(): + """Perform uni- and bidirectional FEP and NEQ & SMC protocol""" from endstate_correction.protocol import perform_endstate_correction, AllProtocol, FEPProtocol, NEQProtocol, SMCProtocol, AllResults # load samples - sim, mm_samples, qml_samples = setup_ZINC00077329_system() + sim, mm_samples, nnp_samples = setup_ZINC00077329_system() #################################################### # ---------------- All corrections ----------------- @@ -184,7 +185,7 @@ def test_ALL_protocol(): neq_protocol = NEQProtocol( sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=10, switching_length=50, ) @@ -228,11 +229,12 @@ def test_ALL_protocol(): def test_each_protocol(): - """Perform FEP uni- and bidirectional protocol""" + + """Test FEP and NEQ uni- and bidirectional protocols separately""" from endstate_correction.protocol import perform_endstate_correction, FEPProtocol, NEQProtocol # load samples - sim, mm_samples, qml_samples = setup_ZINC00077329_system() + sim, mm_samples, nnp_samples = setup_ZINC00077329_system() #################################################### # ----------------------- FEP ---------------------- @@ -254,7 +256,7 @@ def test_each_protocol(): fep_protocol = FEPProtocol( sim=sim, target_samples=mm_samples, - reference_samples=qml_samples, + reference_samples=nnp_samples, nr_of_switches=10, ) @@ -269,12 +271,13 @@ def test_each_protocol(): # ----------------------- NEQ ---------------------- #################################################### + neq_protocol = NEQProtocol( - sim=sim, - 
reference_samples=mm_samples, - nr_of_switches=10, - switching_length=50, - ) + sim=sim, + reference_samples=mm_samples, + nr_of_switches=10, + switching_length=50, + ) r = perform_endstate_correction(neq_protocol) assert r.equ_results == None @@ -310,7 +313,7 @@ def test_each_protocol(): protocol = NEQProtocol( sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=10, switching_length=50, save_endstates=True, diff --git a/endstate_correction/tests/test_neq.py b/endstate_correction/tests/test_neq.py index 653de2ab..f2484aa8 100644 --- a/endstate_correction/tests/test_neq.py +++ b/endstate_correction/tests/test_neq.py @@ -15,18 +15,19 @@ def test_collect_work_values(): from endstate_correction.neq import _collect_work_values nr_of_switches = 200 - path = f"data/ZINC00077329/switching_charmmff/ZINC00077329_neq_ws_from_mm_to_qml_{nr_of_switches}_5001.pickle" + path = f"data/ZINC00077329/switching_charmmff/ZINC00077329_neq_ws_from_mm_to_nnp_{nr_of_switches}_5001.pickle" ws = _collect_work_values(path) assert len(ws) == nr_of_switches def test_switching(): + """test return values of perform_switching function""" # load simulation and samples for ZINC00077329 - sim, samples_mm, samples_qml = setup_ZINC00077329_system() + sim, samples_mm, samples_nnp = setup_ZINC00077329_system() # perform instantaneous switching with predetermined coordinate set - # here, we evaluate dU_forw = dU(x)_qml - dU(x)_mm and make sure that it is the same as - # dU_rev = dU(x)_mm - dU(x)_qml + # here, we evaluate dU_forw = dU(x)_nnp - dU(x)_mm and make sure that it is the same as + # dU_rev = dU(x)_mm - dU(x)_nnp lambs = np.linspace(0, 1, 2) print(lambs) dE_list, _, _ = perform_switching( diff --git a/endstate_correction/tests/test_sampling.py b/endstate_correction/tests/test_sampling.py index ed95821e..7f69a0ed 100644 --- a/endstate_correction/tests/test_sampling.py +++ b/endstate_correction/tests/test_sampling.py @@ -33,7 +33,7 @@ def 
test_sampling(): f"{hipen_testsystem}/par_all36_cgenff.prm", f"{hipen_testsystem}/{system_name}/{system_name}.str", ) - # define region that should be treated with the qml + # define region that should be treated with the nnp sim = setup_vacuum_simulation(psf, params) sim.context.setPositions(crd.positions) sim.context.setVelocitiesToTemperature(300) @@ -58,7 +58,7 @@ def test_sampling(): f"{jctc_testsystem}/toppar/toppar_water_ions.str", ) psf = read_box(psf, f"{jctc_testsystem}/{system_name}/charmm-gui/input.config.dat") - # define region that should be treated with the qml + # define region that should be treated with the nnp chains = list(psf.topology.chains()) ml_atoms = [atom.index for atom in chains[0].atoms()] # set up system diff --git a/endstate_correction/tests/test_smc.py b/endstate_correction/tests/test_smc.py index 186ec826..b4af144b 100644 --- a/endstate_correction/tests/test_smc.py +++ b/endstate_correction/tests/test_smc.py @@ -67,7 +67,7 @@ def evolve_configuration(x): def test_SMC(_am_I_on_GH): - sim, samples_mm, samples_mm_qml = setup_ZINC00077329_system() + sim, samples_mm, samples_mm_nnp = setup_ZINC00077329_system() smc_sampler = SMC(sim=sim, samples=samples_mm) # perform SMC switching print("Performing SMC switching") diff --git a/endstate_correction/tests/test_system.py b/endstate_correction/tests/test_system.py index 223e2b35..c249357a 100644 --- a/endstate_correction/tests/test_system.py +++ b/endstate_correction/tests/test_system.py @@ -25,6 +25,14 @@ def load_endstate_system_and_samples( system_name: str, ) -> Tuple[Simulation, list, list]: + """Test if samples can be loaded and system can be created + + Args: + system_name (str): name of the system + + Returns: + Tuple[Simulation, list, list]: instance of Simulation class, MM samples, NNP samples + """ # initialize simulation and load pre-generated samples from openmm.app import CharmmCrdFile, CharmmParameterSet, CharmmPsfFile @@ -43,7 +51,7 @@ def load_endstate_system_and_samples( 
f"{hipen_testsystem}/par_all36_cgenff.prm", f"{hipen_testsystem}/{system_name}/{system_name}.str", ) - # define region that should be treated with the qml + # define region that should be treated with the nnp sim = setup_vacuum_simulation(psf=psf, params=params) sim.context.setPositions(crd.positions) n_samples = 5_000 @@ -55,13 +63,13 @@ def load_endstate_system_and_samples( top=pdb_file, ) - qml_samples = [] - qml_samples = md.load_dcd( + nnp_samples = [] + nnp_samples = md.load_dcd( f"data/{system_name}/sampling_charmmff/run01/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_1.0000.dcd", top=pdb_file, ) - return sim, mm_samples, qml_samples + return sim, mm_samples, nnp_samples def setup_ZINC00077329_system(): @@ -69,15 +77,24 @@ def setup_ZINC00077329_system(): print(f"{system_name=}") # load simulation and samples for ZINC00077329 - sim, samples_mm, samples_mm_qml = load_endstate_system_and_samples( + sim, samples_mm, samples_mm_nnp = load_endstate_system_and_samples( system_name=system_name, ) - return sim, samples_mm, samples_mm_qml + return sim, samples_mm, samples_mm_nnp def setup_vacuum_simulation( psf: CharmmPsfFile, params: CharmmParameterSet ) -> Simulation: + """Test setup simulation in vacuum. + + Args: + psf (CharmmPsfFile): topology instance + params (CharmmParameterSet): parameter + + Returns: + Simulation: instance of Simulation class + """ chains = list(psf.topology.chains()) ml_atoms = [atom.index for atom in chains[0].atoms()] print(f"{ml_atoms=}") @@ -96,6 +113,17 @@ def setup_waterbox_simulation( r_off: float = 1.2, r_on: float = 0.0, ) -> Simulation: + """Test setup simulation in waterbox. + + Args: + psf (CharmmPsfFile): topology instance + params (CharmmParameterSet): parameter + r_off (float, optional): _description_. Defaults to 1.2. + r_on (float, optional): _description_. Defaults to 0.0. 
+ + Returns: + Simulation: instance of Simulation class + """ chains = list(psf.topology.chains()) ml_atoms = [atom.index for atom in chains[0].atoms()] print(f"{ml_atoms=}") @@ -116,7 +144,7 @@ def setup_waterbox_simulation( def test_initializing_ZINC00077329_system(): - sim, samples_mm, samples_mm_qml = setup_ZINC00077329_system() + sim, samples_mm, samples_mm_nnp = setup_ZINC00077329_system() assert len(samples_mm) == 5000 @@ -137,7 +165,7 @@ def test_generate_simulation_instances_with_charmmff(): f"{hipen_testsystem}/par_all36_cgenff.prm", f"{hipen_testsystem}/{system_name}/{system_name}.str", ) - # define region that should be treated with the qml + # define region that should be treated with the nnp sim = setup_vacuum_simulation(psf, params) # set up system sim.context.setPositions(crd.positions) @@ -155,11 +183,11 @@ def test_generate_simulation_instances_with_charmmff(): ############################ ############################ - # at lambda=1.0 (qml endpoint) + # at lambda=1.0 (nnp endpoint) sim.context.setParameter("lambda_interpolate", 1.0) - e_sim_qml_endstate = get_energy(sim).value_in_unit(unit.kilojoule_per_mole) - print(e_sim_qml_endstate) - assert np.isclose(e_sim_qml_endstate, -5252411.066221259) + e_sim_nnp_endstate = get_energy(sim).value_in_unit(unit.kilojoule_per_mole) + print(e_sim_nnp_endstate) + assert np.isclose(e_sim_nnp_endstate, -5252411.066221259) ######################################################## ######################################################## @@ -175,7 +203,7 @@ def test_generate_simulation_instances_with_charmmff(): f"{jctc_testsystem}/toppar/par_all36_cgenff.prm", f"{jctc_testsystem}/toppar/toppar_water_ions.str", ) - # define region that should be treated with the qml + # define region that should be treated with the nnp sim = setup_vacuum_simulation(psf, params) sim.context.setPositions(pdb.positions) @@ -192,11 +220,11 @@ def test_generate_simulation_instances_with_charmmff(): ############################ 
############################ - # at lambda=1.0 (qml endpoint) + # at lambda=1.0 (nnp endpoint) sim.context.setParameter("lambda_interpolate", 1.0) - e_sim_qml_endstate = get_energy(sim).value_in_unit(unit.kilojoule_per_mole) - print(e_sim_qml_endstate) - assert np.isclose(e_sim_qml_endstate, -1025774.735780582) + e_sim_nnp_endstate = get_energy(sim).value_in_unit(unit.kilojoule_per_mole) + print(e_sim_nnp_endstate) + assert np.isclose(e_sim_nnp_endstate, -1025774.735780582) ######################################################## ######################################################## @@ -231,11 +259,11 @@ def test_generate_simulation_instances_with_charmmff(): ############################ ############################ - # at lambda=1.0 (qml endpoint) + # at lambda=1.0 (nnp endpoint) sim.context.setParameter("lambda_interpolate", 1.0) - e_sim_qml_endstate = get_energy(sim).value_in_unit(unit.kilojoule_per_mole) - print(e_sim_qml_endstate) - assert np.isclose(e_sim_qml_endstate, -1062775.8348574494) + e_sim_nnp_endstate = get_energy(sim).value_in_unit(unit.kilojoule_per_mole) + print(e_sim_nnp_endstate) + assert np.isclose(e_sim_nnp_endstate, -1062775.8348574494) def test_simulating(): diff --git a/endstate_correction/utils.py b/endstate_correction/utils.py index d178f103..87382479 100644 --- a/endstate_correction/utils.py +++ b/endstate_correction/utils.py @@ -11,6 +11,15 @@ def convert_pickle_to_dcd_file( dcd_output_path: str, pdb_output_path: str, ): + """Convert pickle file trajectory to dcd file. 
+ + Args: + pickle_file_path (str): path where pickle file is stored + path_to_psf (str): path where psf file is stored + path_to_crd (str): path where crd file is stored + dcd_output_path (str): path to save dcd file + pdb_output_path (str): path to save pdb file + """ # helper function that converts pickle trajectory file to dcd file f = pickle.load(open(pickle_file_path, "rb")) diff --git a/endstate_correction/vis.py b/endstate_correction/vis.py index 3826bd69..f5351c4c 100644 --- a/endstate_correction/vis.py +++ b/endstate_correction/vis.py @@ -1,45 +1,33 @@ -import pickle import nglview as ng -from endstate_correction.system import generate_molecule from endstate_correction.constant import zinc_systems import mdtraj as md +from openff.toolkit.topology import Molecule def visualize_mol( - zinc_id: str, - forcefield: str, - endstate: str, - run_id: str = "", - w_dir: str = "/data/shared/projects/endstate_correction", - switching: bool = False, - switching_length: int = 5001, + smiles: str, + traj_dir: str, ): + """Inspect conformations generated by sampling or NEQ switching. 
- # get name - name, _ = zinc_systems[zinc_id] - # generate molecule object - m = generate_molecule(name=name, forcefield=forcefield) - # write mol as pdb - m.to_file("m.pdb", file_format="pdb") - # get pickle file - if not switching: - # get correct file label - if endstate == "mm": - lamb_nr = "0.0000" - elif endstate == "qml": - lamb_nr = "1.0000" - pickle_file = f"{w_dir}/{name}/sampling_{forcefield}/run{run_id}/{name}_samples_5000_steps_1000_lamb_{lamb_nr}.pickle" - else: - pickle_file = f"{w_dir}/{name}/switching_{forcefield}/{name}_samples_5000_steps_1000_lamb_{endstate}_endstate_nr_samples_500_switching_length_{switching_length}.pickle" - # load traj - f = pickle.load(open(pickle_file, "rb")) - # load topology from pdb file + Args: + smiles (str): smiles string + traj_dir (str): path where dcd trajecotry file is stored + + Returns: + _type_: molecule visualization + """ + + molecule = Molecule.from_smiles(smiles, hydrogens_are_explicit=False) + molecule.to_file("mol.pdb", file_format="pdb") + # load trajectory and topology + f = md.load_dcd(traj_dir, top = "mol.pdb") # NOTE: pdb file is needed for mdtraj, which reads the topology # this is not very elegant # FIXME: try to load topology directly - top = md.load("m.pdb").topology + top = md.load("mol.pdb").topology # generate trajectory instance - traj = md.Trajectory(f, topology=top) + traj = md.Trajectory(f.xyz, topology=top) # align traj traj.superpose(traj) view = ng.show_mdtraj(traj) - return view + return view \ No newline at end of file diff --git a/scripts/generate_endstate_samples.py b/scripts/generate_endstate_samples.py index 07dd0659..00f22aa1 100644 --- a/scripts/generate_endstate_samples.py +++ b/scripts/generate_endstate_samples.py @@ -43,7 +43,7 @@ topology = molecule.to_topology() system = forcefield.create_openmm_system(topology) -# define region that should be treated with the qml +# define region that should be treated with the nnp ml_atoms = [atom.molecule_particle_index for atom in 
topology.atoms] integrator = LangevinIntegrator(temperature, collision_rate, stepsize) platform = Platform.getPlatformByName("CUDA") diff --git a/scripts/perform_correction.py b/scripts/perform_correction.py index 6d7c9789..d7597902 100644 --- a/scripts/perform_correction.py +++ b/scripts/perform_correction.py @@ -5,7 +5,7 @@ n_samples = 1_000 n_steps_per_sample = 1_000 run_id = 1 -traj_base = f"{system_name}/equilibrium_samples/run{run_id:0>2d}" # define directory containing MM and QML sampling data +traj_base = f"{system_name}/equilibrium_samples/run{run_id:0>2d}" # define directory containing MM and NNP sampling data output_base = f"{system_name}/switching" # --------------------------------------------- # @@ -35,9 +35,9 @@ topology = molecule.to_topology() system = forcefield.create_openmm_system(topology) -# define region that should be treated with the qml +# define region that should be treated with the nnp ml_atoms = [atom.molecule_particle_index for atom in topology.atoms] -print(ml_atoms) +print(f"{ml_atoms=}") integrator = LangevinIntegrator(temperature, collision_rate, stepsize) platform = Platform.getPlatformByName("CUDA") topology = topology.to_openmm() @@ -59,34 +59,36 @@ ] # discart first 20% of the trajectory print(f"Initializing switch from {len(mm_samples)} MM samples") # --------------------------------------------- # -# load QML samples -qml_samples = [] +# load NNP samples +nnp_samples = [] base = f"{traj_base}/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_1.0000_{env}" -qml_samples = mdtraj.load_dcd( +nnp_samples = mdtraj.load_dcd( f"{base}.dcd", top=f"{traj_base}/{system_name}.pdb", )[ int((1_000 / 100) * 20) : ] # discart first 20% of the trajectory -print(f"Initializing switch from {len(qml_samples)} QML samples") +print(f"Initializing switch from {len(nnp_samples)} NNP samples") # --------------------------------------------- # # ---------------- FEP protocol --------------- # 
--------------------------------------------- # +# bidirectional fep_protocol = Protocol( method="FEP", sim=sim, reference_samples=mm_samples, - target_samples=qml_samples, + target_samples=nnp_samples, nr_of_switches=1_000, ) # --------------------------------------------- # # ----------------- NEQ protocol -------------- # --------------------------------------------- # +# unidirectional (switching from reference to target) neq_protocol = Protocol( method="NEQ", sim=sim, reference_samples=mm_samples, - #target_samples=qml_samples, + #target_samples=nnp_samples, nr_of_switches=100, neq_switching_length=1_000, save_endstates=True, diff --git a/scripts/perform_correction_hipen.py b/scripts/perform_correction_hipen.py index 304bf8c1..6d57fa4c 100644 --- a/scripts/perform_correction_hipen.py +++ b/scripts/perform_correction_hipen.py @@ -49,7 +49,7 @@ with open("temp.pdb", "w") as outfile: PDBFile.writeFile(psf.topology, crd.positions, outfile) -# define region that should be treated with the qml +# define region that should be treated with the nnp chains = list(psf.topology.chains()) ml_atoms = [atom.index for atom in chains[0].atoms()] print(f"{ml_atoms=}") @@ -66,7 +66,7 @@ n_samples = 5_000 n_steps_per_sample = 1_000 -# define directory containing MM and QML sampling data +# define directory containing MM and NNP sampling data traj_base = f"/data/shared/projects/endstate_rew/{system_name}/sampling_charmmff/" # load MM samples @@ -82,8 +82,8 @@ mm_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer! print(f"Initializing switch from {len(mm_samples)} MM samples") -# load QML samples -qml_samples = [] +# load NNP samples +nnp_samples = [] for i in range(1, 4): base = f"{traj_base}/run0{i}/{system_name}_samples_{n_samples}_steps_{n_steps_per_sample}_lamb_1.0000" # if needed, convert pickle file to dcd @@ -92,8 +92,8 @@ f"{base}.dcd", top=psf_file, ) - qml_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer! 
-print(f"Initializing switch from {len(mm_samples)} QML samples") + nnp_samples.extend(traj[1000:].xyz * unit.nanometer) # NOTE: this is in nanometer! +print(f"Initializing switch from {len(mm_samples)} NNP samples") ######################################################## ######################################################## @@ -110,7 +110,7 @@ method="FEP", direction="bidirectional", sim=sim, - trajectories=[mm_samples, qml_samples], + trajectories=[mm_samples, nnp_samples], nr_of_switches=10, # 2_000, ) @@ -121,7 +121,7 @@ method="NEQ", direction="bidirectional", sim=sim, - trajectories=[mm_samples, qml_samples], + trajectories=[mm_samples, nnp_samples], nr_of_switches=3, # 500, neq_switching_length=5, # _000, save_endstates=True, diff --git a/setup.py b/setup.py index aaa72229..fcf71a35 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,12 @@ """ endstate_correction -Endstate reweighting from MM to QML potential +Endstate correction from MM to NNP """ import sys from setuptools import setup, find_packages import versioneer -short_description = "Endstate reweighting from MM to QML potential".split("\n")[0] +short_description = "Endstate reweighting from MM to NNP".split("\n")[0] # from https://github.com/pytest-dev/pytest-runner#conditional-requirement needs_pytest = {"pytest", "test", "ptr"}.intersection(sys.argv)