Skip to content

Commit

Permalink
Merge pull request #11 from mila-iqia/dataset_creation_v1
Browse files Browse the repository at this point in the history
Dataset creation v1
  • Loading branch information
sblackburn86 authored Apr 2, 2024
2 parents a14da1b + 052ef1e commit 46dd060
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 0 deletions.
1 change: 1 addition & 0 deletions crystal_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
TOP_DIR = ROOT_DIR.parent
ANALYSIS_RESULTS_DIR = TOP_DIR.joinpath("analysis_results/")
ANALYSIS_RESULTS_DIR.mkdir(exist_ok=True)
DATA_DIR = TOP_DIR.joinpath("data/")
126 changes: 126 additions & 0 deletions crystal_diffusion/analysis/dataset_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Diffusion Dataset analysis.
This script computes and plots different features of a dataset used to train a diffusion model.
"""
import os
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR
from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH
from crystal_diffusion.data.parse_lammps_outputs import parse_lammps_output

# Name of the dataset analysed when this module is run as a script; it is joined to DATA_DIR in plot_metrics_runs
# and also used as prefix of the saved figure file name.
DATASET_NAME = 'si_diffusion_v1'


def read_lammps_run(run_path: str) -> pd.DataFrame:
    """Read and organize the LAMMPS output files of a single run in a dataframe.

    Args:
        run_path: path to LAMMPS output directory. Should contain a dump file and a thermo log file, identified by
            the substrings 'dump' and 'thermo' in their names.

    Returns:
        the result of parse_lammps_output called on the dump and thermo files found in run_path.

    Raises:
        FileNotFoundError: if run_path contains no entry with 'dump' or no entry with 'thermo' in its name.
    """
    # os.listdir returns entries in arbitrary order: sort them so the selected file is reproducible when several
    # entries match.
    entries = sorted(os.listdir(run_path))
    dump_files = [entry for entry in entries if 'dump' in entry]
    thermo_files = [entry for entry in entries if 'thermo' in entry]

    if not dump_files:
        raise FileNotFoundError(f"No LAMMPS dump file (name containing 'dump') found in {run_path}.")
    if not thermo_files:
        raise FileNotFoundError(f"No LAMMPS thermo file (name containing 'thermo') found in {run_path}.")

    # NOTE(review): the third argument is passed as None as in the original call - confirm its meaning against
    # parse_lammps_output.
    return parse_lammps_output(os.path.join(run_path, dump_files[0]), os.path.join(run_path, thermo_files[0]), None)


def compute_metrics_for_a_run(df: pd.DataFrame) -> Dict[str, pd.Series]:
    """Get the energy, forces average, RMS displacement and std dev for a single MD run.

    Args:
        df: LAMMPS output organized in a DataFrame. Must contain an 'energy' column and the columns 'x', 'y', 'z',
            'fx', 'fy', 'fz', where each cell holds the sequence of per-atom values for that row (MD step).

    Returns:
        metrics evaluated at each MD step organized in a dict. Keys are 'energy', 'force_norm_average',
        'root_mean_square_displacement' and 'std_displacement'; every value is a Series sharing the index of df.
        Displacements are measured with respect to the positions in the first row of df.

    Raises:
        ValueError: if df has no rows.
    """
    if df.empty:
        raise ValueError("Cannot compute metrics for an empty MD run.")

    metrics = {'energy': df['energy']}

    # norm of the force on each atom, averaged over the atoms of each step
    force_norm_mean = [
        np.mean(np.sqrt(np.asarray(fx, dtype=float) ** 2
                        + np.asarray(fy, dtype=float) ** 2
                        + np.asarray(fz, dtype=float) ** 2))
        for fx, fy, fz in zip(df['fx'], df['fy'], df['fz'])
    ]
    metrics['force_norm_average'] = pd.Series(force_norm_mean, index=df.index)

    # reference positions are taken from the first row by position (iloc), so that a DataFrame whose index does
    # not start at 0 is handled as well
    x0 = np.asarray(df['x'].iloc[0], dtype=float)
    y0 = np.asarray(df['y'].iloc[0], dtype=float)
    z0 = np.asarray(df['z'].iloc[0], dtype=float)

    # NOTE(review): positions are compared as stored; if the dump holds coordinates wrapped in the periodic box,
    # an atom crossing a boundary shows up as a large jump - confirm against the dump settings.
    rms_displacement = []
    std_displacement = []
    for xs, ys, zs in zip(df['x'], df['y'], df['z']):
        square_displacement = ((np.asarray(xs, dtype=float) - x0) ** 2
                               + (np.asarray(ys, dtype=float) - y0) ** 2
                               + (np.asarray(zs, dtype=float) - z0) ** 2)
        rms_displacement.append(np.sqrt(np.mean(square_displacement)))
        # spread of the per-atom displacement within this step (one value per MD step)
        std_displacement.append(np.std(np.sqrt(square_displacement)))

    metrics['root_mean_square_displacement'] = pd.Series(rms_displacement, index=df.index)
    metrics['std_displacement'] = pd.Series(std_displacement, index=df.index)

    return metrics


def plot_metrics_runs(dataset_name: str, mode: str = 'train'):
    """Compute and plot metrics for a dataset made up of several MD runs.

    One figure with four stacked panels (energy, mean force norm, RMS displacement, std-dev of displacement) is
    saved in ANALYSIS_RESULTS_DIR as "<dataset_name>_<mode>_analysis.png", with one curve per run.

    Args:
        dataset_name: name of the dataset - should match the name of the folder in DATA_DIR
        mode (optional): analyze train or valid data. Defaults to train.
    """
    assert mode in ["train", "valid"], f"Mode should be train or valid. Got {mode}"
    dataset_path = os.path.join(DATA_DIR, dataset_name)

    # os.listdir order is arbitrary: sort the runs so curves and legend entries come out in a reproducible order.
    # Sorting on (length, name) places e.g. train_run_2 before train_run_10.
    list_runs = sorted(
        (d for d in os.listdir(dataset_path)
         if os.path.isdir(os.path.join(dataset_path, d))
         and d.startswith(f"{mode}_run") and not d.endswith('backup')),
        key=lambda name: (len(name), name),
    )

    metrics = {run: compute_metrics_for_a_run(read_lammps_run(os.path.join(dataset_path, run)))
               for run in list_runs}

    plt.style.use(PLOT_STYLE_PATH)

    fig, axs = plt.subplots(4, 1, figsize=(PLEASANT_FIG_SIZE[0], 4 * PLEASANT_FIG_SIZE[1]))
    fig.suptitle("MD runs properties")

    # Axis units follow the LAMMPS "metal" unit system declared in the dataset input script (energy in eV,
    # force in eV / Angstrom, distance in Angstrom).
    # energy
    axs[0].set_title("Energy")
    axs[0].set_ylabel("Energy (eV)")
    # forces
    axs[1].set_title("Force Norm Averaged over Atoms")
    axs[1].set_ylabel(r"Force Norm (eV / Angstrom)")
    # root mean square displacement
    axs[2].set_title("RMS Displacement")
    axs[2].set_ylabel("RMSD (Angstrom)")
    # standard deviation of the displacement
    axs[3].set_title("Std-Dev Displacement")
    axs[3].set_ylabel("Std Displacement (Angstrom)")

    legend = []
    for run_name, run_metrics in metrics.items():
        axs[0].plot(run_metrics['energy'], '-', lw=2)
        axs[1].plot(run_metrics['force_norm_average'], ':', lw=2)
        axs[2].plot(run_metrics['root_mean_square_displacement'], lw=2)
        axs[3].plot(run_metrics['std_displacement'], lw=2)
        legend.append(run_name)

    for ax in axs:
        ax.legend(legend)
        ax.set_xlabel("MD step")

    fig.tight_layout()

    fig.savefig(ANALYSIS_RESULTS_DIR.joinpath(f"{dataset_name}_{mode}_analysis.png"), dpi=300)
    # release the figure: pyplot keeps every figure alive until it is closed explicitly
    plt.close(fig)


def main():
    """Run the dataset analysis on the training split, then on the validation split."""
    for split in ('train', 'valid'):
        plot_metrics_runs(DATASET_NAME, mode=split)


if __name__ == '__main__':
    main()
50 changes: 50 additions & 0 deletions crystal_diffusion/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Utility functions for data processing."""
import os
from typing import Any, AnyStr, Dict, List, Tuple

import yaml


def crop_lammps_yaml(lammps_dump: str, lammps_thermo: str, crop_step: int, inplace: bool = False) \
        -> Tuple[List[Dict[AnyStr, Any]], Dict[AnyStr, Any]]:
    """Remove the first steps of a LAMMPS run to remove structures near the starting point.

    Args:
        lammps_dump: path to LAMMPS output file as a yaml
        lammps_thermo: path to LAMMPS thermodynamic output file as a yaml
        crop_step: number of steps to remove. Must be non-negative and smaller than the number of steps in the dump.
        inplace (optional): if True, overwrite the two LAMMPS file with a cropped version. If False, do not write.
            Defaults to False.

    Returns:
        cropped LAMMPS output file
        cropped LAMMPS thermodynamic output file

    Raises:
        ValueError: if crop_step is negative, if one of the files does not exist, or if crop_step is not smaller
            than the number of steps in the dump file.
    """
    # a negative value would be interpreted by the slices below as "keep only the last steps"
    if crop_step < 0:
        raise ValueError(f"crop_step should be a non-negative number of steps. Got {crop_step}.")

    if not os.path.exists(lammps_dump):
        raise ValueError(f'{lammps_dump} does not exist. Please provide a valid LAMMPS dump file as yaml.')

    if not os.path.exists(lammps_thermo):
        raise ValueError(f'{lammps_thermo} does not exist. Please provide a valid LAMMPS thermo log file as yaml.')

    # get the atom information (positions and forces) from the LAMMPS 'dump' file
    # every MD iteration is saved as a separate document in the yaml file
    with open(lammps_dump, 'r') as f:
        dump_yaml = list(yaml.safe_load_all(f))  # generator to list, consumed while the file is still open

    if crop_step >= len(dump_yaml):
        raise ValueError(f"Trying to remove {crop_step} steps in a run of {len(dump_yaml)} steps.")
    dump_yaml = dump_yaml[crop_step:]

    # get the total energy from the LAMMPS thermodynamic output
    with open(lammps_thermo, 'r') as f:
        thermo_yaml = yaml.safe_load(f)
    thermo_yaml['data'] = thermo_yaml['data'][crop_step:]

    if inplace:
        # overwrite the input files themselves, as documented for the inplace option
        with open(lammps_dump, "w") as f:
            yaml.dump_all(dump_yaml, f, explicit_start=True)
        with open(lammps_thermo, "w") as f:
            yaml.dump(thermo_yaml, f)

    return dump_yaml, thermo_yaml
34 changes: 34 additions & 0 deletions data/crop_lammps_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Read and crop LAMMPS outputs."""
import argparse
import os

import yaml

from crystal_diffusion.data.utils import crop_lammps_yaml


def main():
    """Crop the LAMMPS outputs given on the command line and save the result in the output folder."""
    cli = argparse.ArgumentParser()
    cli.add_argument('--lammps_yaml', help='path to LAMMPS yaml file', required=True)
    cli.add_argument('--lammps_thermo', help='path to LAMMPS thermo output', required=True)
    cli.add_argument('--crop', type=int, help='number of steps to remove at the start of the run', required=True)
    cli.add_argument('--output_dir', help='path to folder where outputs will be saved', required=True)
    options = cli.parse_args()

    # the input files are left untouched: the cropped content is only returned here
    cropped_dump, cropped_thermo = crop_lammps_yaml(
        options.lammps_yaml, options.lammps_thermo, options.crop, inplace=False
    )

    output_dir = options.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    dump_target = os.path.join(output_dir, 'lammps_dump.yaml')
    with open(dump_target, 'w') as stream:
        # one yaml document per MD step, each introduced by an explicit '---'
        yaml.dump_all(cropped_dump, stream, explicit_start=True)

    thermo_target = os.path.join(output_dir, 'lammps_thermo.yaml')
    with open(thermo_target, 'w') as stream:
        yaml.dump(cropped_thermo, stream)


if __name__ == '__main__':
    main()
38 changes: 38 additions & 0 deletions data/si_diffusion_v1/create_data.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/bin/bash
# Generate the MD runs of the dataset: for each seed, run LAMMPS in a dedicated folder, extract the
# thermodynamic log as yaml, keep the raw outputs in uncropped_outputs/ and write cropped copies in the run folder.
# The relative paths used below (../in.si.lammps, ../../crop_lammps_outputs.py) assume the script is launched
# from the directory that contains it.

# Values handed to LAMMPS as the variables T (temperature) and S (box size) of in.si.lammps
TEMPERATURE=300
BOX_SIZE=4
# Number of MD steps kept per run; LAMMPS is asked for STEP + CROP steps and the first CROP are removed afterwards
STEP=1000
CROP=100
# Seeds 1..NTRAIN_RUN produce train runs, the following NVALID_RUN seeds produce valid runs
NTRAIN_RUN=10
NVALID_RUN=5

NRUN=$(($NTRAIN_RUN + $NVALID_RUN))

for SEED in $(seq 1 $NRUN);
do
if [ "$SEED" -le $NTRAIN_RUN ]; then
MODE="train"
else
MODE="valid"
fi
echo $MODE $SEED
# each run lives in its own folder named <mode>_run_<seed>
mkdir -p "${MODE}_run_${SEED}"
cd "${MODE}_run_${SEED}"
lmp < ../in.si.lammps -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED

# extract the thermodynamic outputs in a yaml file
egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml

# keep the raw (uncropped) LAMMPS outputs aside before writing the cropped versions
mkdir -p "uncropped_outputs"
mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/
mv thermo_log.yaml uncropped_outputs/

# writes lammps_dump.yaml and lammps_thermo.yaml, without the first CROP steps, in the current run folder
python ../../crop_lammps_outputs.py \
--lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \
--lammps_thermo "uncropped_outputs/thermo_log.yaml" \
--crop $CROP \
--output_dir ./

cd ..
done
29 changes: 29 additions & 0 deletions data/si_diffusion_v1/in.si.lammps
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# LAMMPS input for one MD run of the dataset.
# Expects the variables S, T, SEED and STEP to be defined on the command line (see -v options in create_data.sh).

# log file; create_data.sh extracts the yaml thermo output from it
log log.lammps

# "metal" unit system (per the LAMMPS units documentation: Angstrom, eV, picoseconds)
units metal
atom_style atomic
atom_modify map array

# diamond lattice with a lattice constant of 5.43, filling a cubic region of S lattice cells per side
lattice diamond 5.43
region simbox block 0 ${S} 0 ${S} 0 ${S}
create_box 1 simbox
create_atoms 1 region simbox

mass 1 28.0855

group Si type 1

# sw pair style with the parameter file located two folders above the run folder
pair_style sw
pair_coeff * * ../../Si.sw Si

# initial velocities drawn for temperature T with random seed SEED
velocity all create ${T} ${SEED}

# per-atom id, type, positions and forces written at every step in a yaml dump file
dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz

# thermodynamic output at every step, in yaml format
thermo_style yaml
thermo 1
#==========================Output files========================

# NVT integration at temperature T for STEP steps
# NOTE(review): 0.01 is the thermostat damping parameter, expressed in time units - confirm the value is intended
fix 1 all nvt temp ${T} ${T} 0.01
run ${STEP}
unfix 1
57 changes: 57 additions & 0 deletions tests/data/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import yaml

from crystal_diffusion.data.utils import crop_lammps_yaml

# Sample data for dump and thermo YAML files
# The dump sample is a multi-document yaml stream: documents are separated by '---' and crop_lammps_yaml counts
# each document as one step. The thermo sample is a single document with one entry per step under 'data'.
DUMP_YAML_CONTENT = """
- step: 0
data: {...}
---
- step: 1
data: {...}
---
- step: 2
data: {...}
"""
THERMO_YAML_CONTENT = """
data:
- {...}
- {...}
- {...}
"""


@pytest.fixture
def dump_file(tmpdir):
    """Write the sample dump content in a temporary file and return its path as a string."""
    dump_path = tmpdir.join("lammps_dump.yaml")
    dump_path.write(DUMP_YAML_CONTENT)
    return str(dump_path)


@pytest.fixture
def thermo_file(tmpdir):
    """Write the sample thermo content in a temporary file and return its path as a string."""
    thermo_path = tmpdir.join("lammps_thermo.yaml")
    thermo_path.write(THERMO_YAML_CONTENT)
    return str(thermo_path)


def test_crop_lammps_yaml(dump_file, thermo_file):
    """Check that cropping removes the requested number of steps and rejects a crop longer than the run."""
    n_removed = 1
    # crop without writing anything: the temporary files keep their full content
    dump_after, thermo_after = crop_lammps_yaml(dump_file, thermo_file, n_removed)

    # reload the untouched files to get the reference lengths
    with open(dump_file) as stream:
        dump_before = list(yaml.safe_load_all(stream))
    with open(thermo_file) as stream:
        thermo_before = yaml.safe_load(stream)

    # the returned structures are shorter by exactly the number of cropped steps
    assert len(dump_after) == len(dump_before) - n_removed
    assert len(thermo_after['data']) == len(thermo_before['data']) - n_removed

    # a crop larger than the number of steps in the sample must be rejected
    with pytest.raises(ValueError):
        crop_lammps_yaml(dump_file, thermo_file, 4)

0 comments on commit 46dd060

Please sign in to comment.