Skip to content

Commit

Permalink
Merge pull request #11 from comecattin/10-center-molecule
Browse files Browse the repository at this point in the history
Center molecule
  • Loading branch information
comecattin authored Mar 12, 2024
2 parents 2eca77f + f1809cb commit d406999
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
3 changes: 2 additions & 1 deletion example/data/input_file/input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ sigma: 1.
display_energy: True
display_animation: True
printing_step: 10
writing_step: 100
writing_step: 100
center: True
33 changes: 32 additions & 1 deletion src/jax_md/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def compute_forces_and_potential_energy(
return f, potential_energy


@jax.jit
def step(
position: jnp.ndarray,
velocity: jnp.ndarray,
Expand Down Expand Up @@ -168,6 +169,7 @@ def dynamics(
n_steps: int = 1000,
writing_step: int = 100,
printing_step: int = 100,
center: bool = True,
) -> Tuple[jnp.ndarray, list, list, list]:
"""Run the dynamics of the system.
Expand All @@ -193,6 +195,8 @@ def dynamics(
Write the system position every writing_step's time step.
printing_step : int, optional
Print the system energy every printing_step's time step.
center : bool, optional
Center the position of the particules in the box, by default True
Returns
-------
Expand Down Expand Up @@ -246,7 +250,12 @@ def dynamics(

if step_i % writing_step == 0:
print('Saving positions')
position_list.append(position)
if center:
position_list.append(
position_center_box(position,box_size=box_size)
)
else:
position_list.append(position)

return (
jnp.array(position_list),
Expand All @@ -273,6 +282,28 @@ def compute_kinetic_energy(velocity: jnp.ndarray) -> float:
return jnp.sum(velocity ** 2)


def position_center_box(
position: jnp.ndarray,
box_size: float
) -> jnp.ndarray:
"""Center the position of the particules in the box.
Parameters
----------
position : jnp.ndarray
Position of the particules.
The shape of the array is (n, 3) where n is the number of particules.
box_size : float
Size of the simulation box.
Returns
-------
jnp.ndarray
Centered position of the particules.
"""
return position - jnp.floor(position / box_size) * box_size


def main():
"""Run the main function."""
parser = Parser()
Expand Down
5 changes: 5 additions & 0 deletions src/jax_md/parser_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def get_dynamics_kwargs(self):
-------
dict: Dictionary of the keyword arguments for the dynamics.
"""
try:
center = self.arguments['center']
except KeyError:
center = False
return {
'box_size': self.arguments['box_size'],
'position': self.pos,
Expand All @@ -145,6 +149,7 @@ def get_dynamics_kwargs(self):
'sigma': self.arguments['sigma'],
'printing_step': self.arguments['printing_step'],
'writing_step': self.arguments['writing_step'],
'center': center,
}


Expand Down

0 comments on commit d406999

Please sign in to comment.