Skip to content

Commit

Permalink
Add equilibration support to RepexRunner.
Browse files Browse the repository at this point in the history
  • Loading branch information
lohedges committed Feb 10, 2025
1 parent c2a230e commit dc4af0f
Showing 1 changed file with 138 additions and 1 deletion.
139 changes: 138 additions & 1 deletion src/somd2/runner/_repex.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,33 @@ def get(self, index):
"""
return self._dynamics[index]

def set(self, index, dynamics):
"""
Set the dynamics object for a given index.
Parameters
----------
index: int
The index of the replica.
dynamics: sire.legacy.Convert.SOMMContext
The dynamics object.
"""
self._dynamics[index] = dynamics

def delete(self, index):
"""
Delete the dynamics object for a given index.
Parameters
----------
index: int
The index of the replica.
"""
self._dynamics[index] = None

def save_openmm_state(self, index):
"""
Save the state of the dynamics object.
Expand Down Expand Up @@ -346,14 +373,38 @@ def __init__(self, system, config):
# Store the name of the replica exchange swap acceptance matrix.
self._repex_matrix = self._config.output_directory / "repex_matrix.txt"

# Flag that we haven't equilibrated.
self._is_equilibration = False

# Create the dynamics cache.
if not self._is_restart:
dynamics_kwargs = self._dynamics_kwargs.copy()

if self._config.equilibration_time.value() > 0.0:
self._is_equilibration = True

# Overload the dynamics kwargs with the equilibration options.
dynamics_kwargs.update(
{
"constraint": (
"none"
if not self._config.equilibration_constraints
else self._config.constraint
),
"perturbable_constraint": (
"none"
if not self._config.equilibration_constraints
else self._config.perturbable_constraint
),
}
)

self._dynamics_cache = DynamicsCache(
self._system,
self._lambda_values,
self._rest2_scale_factors,
self._num_gpus,
self._dynamics_kwargs,
dynamics_kwargs,
)
else:
# Check to see if the simulation is already complete.
Expand Down Expand Up @@ -498,6 +549,24 @@ def run(self):
_logger.error("Minimisation cancelled. Exiting.")
exit(1)

# Equilibrate the system.
if self._is_equilibration:
for i in range(num_batches):
with ThreadPoolExecutor(max_workers=num_workers) as executor:
try:
for success, index, e in executor.map(
self._equilibrate,
replica_list[i * num_workers : (i + 1) * num_workers],
):
if not success:
_logger.error(
f"Equilibration failed for {_lam_sym} = {self._lambda_values[index]:.5f}: {e}"
)
raise e
except KeyboardInterrupt:
_logger.error("Equilibration cancelled. Exiting.")
exit(1)

# Current block number.
block = self._start_block

Expand Down Expand Up @@ -762,6 +831,74 @@ def _minimise(self, index):

return True, index, None

def _equilibrate(self, index):
"""
Equilibrate the system.
Parameters
----------
index: int
The index of the replica.
Returns
-------
success: bool
Whether the equilibration was successful.
index: int
The index of the replica.
exception: Exception
The exception if the equilibration failed.
"""
_logger.info(f"Equilibrating at {_lam_sym} = {self._lambda_values[index]:.5f}")

try:
# Get the dynamics object.
dynamics = self._dynamics_cache.get(index)

# Equilibrate.
dynamics.run(
self._config.equilibration_time,
energy_frequency=0,
frame_frequency=0,
)

# Commit the system.
system = dynamics.commit()

# Delete the dynamics object.
self._dynamics_cache.delete(index)

# Work out the device index.
device = index % self._num_gpus

_logger.info(
f"Creating production dynamics object for {_lam_sym} = "
f"{self._lambda_values[index]:.5f}"
)

# Copy the dynamics keyword arguments.
dynamics_kwargs = self._dynamics_kwargs.copy()

# Overload the device and lambda value.
dynamics_kwargs["device"] = device
dynamics_kwargs["lambda_value"] = self._lambda_values[index]
dynamics_kwargs["rest2_scale"] = self._rest2_scale_factors[index]

# Create the dynamics object.
dynamics = system.dynamics(**dynamics_kwargs)

# Set the new dynamics object.
self._dynamics_cache.set(index, dynamics)

except Exception as e:
return False, index, e

return True, index, None

def _assemble_results(self, results):
"""
Assemble the results into a matrix.
Expand Down

0 comments on commit dc4af0f

Please sign in to comment.