-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from mila-iqia/dataset_creation_v1
Dataset creation v1
- Loading branch information
Showing
7 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
"""Diffusion Dataset analysis. | ||
This script computes and plots different features of a dataset used to train a diffusion model. | ||
""" | ||
import os | ||
from typing import Dict | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from crystal_diffusion import ANALYSIS_RESULTS_DIR, DATA_DIR | ||
from crystal_diffusion.analysis import PLEASANT_FIG_SIZE, PLOT_STYLE_PATH | ||
from crystal_diffusion.data.parse_lammps_outputs import parse_lammps_output | ||
|
||
DATASET_NAME = 'si_diffusion_v1' | ||
|
||
|
||
def read_lammps_run(run_path: str) -> pd.DataFrame: | ||
"""Read and organize the LAMMPS output files in a dataframe. | ||
Args: | ||
run_path: path to LAMMPS output directory. Should contain a dump file and a thermo log file. | ||
Returns: | ||
output as | ||
""" | ||
dump_file = [d for d in os.listdir(run_path) if 'dump' in d] | ||
thermo_file = [d for d in os.listdir(run_path) if 'thermo' in d] | ||
|
||
df = parse_lammps_output(os.path.join(run_path, dump_file[0]), os.path.join(run_path, thermo_file[0]), None) | ||
|
||
return df | ||
|
||
|
||
def compute_metrics_for_a_run(df: pd.DataFrame) -> Dict[str, pd.Series]: | ||
"""Get the energy, forces average, RMS displacement and std dev for a single MD run. | ||
Args: | ||
df: LAMMPS output organized in a DataFrame. | ||
Returns: | ||
metrics evaluated at each MD step organized in a dict | ||
""" | ||
metrics = {} | ||
metrics['energy'] = df['energy'] | ||
force_norm_mean = df.apply(lambda row: np.mean([np.sqrt(fx**2 + fy**2 + fz**2) for fx, fy, fz in | ||
zip(row['fx'], row['fy'], row['fz'])]), axis=1) | ||
metrics['force_norm_average'] = force_norm_mean | ||
|
||
x0s = df['x'][0] | ||
y0s = df['y'][0] | ||
z0s = df['z'][0] | ||
|
||
square_displacement = df.apply(lambda row: [(x - x0) ** 2 + (y - y0) ** 2 + (z - z0) ** 2 for x, y, z, x0, y0, z0 in | ||
zip(row['x'], row['y'], row['z'], x0s, y0s, z0s)], axis=1) | ||
|
||
metrics['root_mean_square_displacement'] = square_displacement.apply(lambda row: np.sqrt(np.mean(row))) | ||
|
||
metrics['std_displacement'] = np.std(square_displacement.apply(np.sqrt)) | ||
|
||
return metrics | ||
|
||
|
||
def plot_metrics_runs(dataset_name: str, mode: str = 'train'): | ||
"""Compute and plot metrics for a dataset made up of several MD runs. | ||
Args: | ||
dataset_name: name of the dataset - should match the name of the folder in DATA_DIR | ||
mode (optional): analyze train or valid data. Defaults to train. | ||
""" | ||
assert mode in ["train", "valid"], f"Mode should be train or valid. Got {mode}" | ||
dataset_path = os.path.join(DATA_DIR, dataset_name) | ||
|
||
list_runs = [d for d in os.listdir(dataset_path) if os.path.isdir(os.path.join(dataset_path, d)) | ||
and d.startswith(f"{mode}_run") and not d.endswith('backup')] | ||
|
||
metrics = {} | ||
for run in list_runs: | ||
df = read_lammps_run(os.path.join(dataset_path, run)) | ||
metrics_run = compute_metrics_for_a_run(df) | ||
metrics[run] = metrics_run | ||
|
||
plt.style.use(PLOT_STYLE_PATH) | ||
|
||
fig, axs = plt.subplots(4, 1, figsize=(PLEASANT_FIG_SIZE[0], 4 * PLEASANT_FIG_SIZE[1])) | ||
fig.suptitle("MD runs properties") | ||
|
||
# energy | ||
axs[0].set_title("Energy") | ||
axs[0].set_ylabel("Energy (kcal / mol)") | ||
# forces | ||
axs[1].set_title("Force Norm Averaged over Atoms") | ||
axs[1].set_ylabel(r"Force Norm (g/mol * Angstrom / fs^2)") | ||
# mean squared displacement | ||
axs[2].set_title("RMS Displacement") | ||
axs[2].set_ylabel("RMSD (Angstrom)") | ||
# std squared displacement | ||
axs[3].set_title("Std-Dev Displacement") | ||
axs[3].set_ylabel("Std Displacement (Angstrom)") | ||
|
||
legend = [] | ||
for k, m in metrics.items(): | ||
axs[0].plot(m['energy'], '-', lw=2) | ||
axs[1].plot(m['force_norm_average'], ':', lw=2) | ||
axs[2].plot(m['root_mean_square_displacement'], lw=2) | ||
axs[3].plot(m['std_displacement'], lw=2) | ||
legend.append(k) | ||
|
||
for ax in axs: | ||
ax.legend(legend) | ||
ax.set_xlabel("MD step") | ||
|
||
fig.tight_layout() | ||
|
||
fig.savefig(ANALYSIS_RESULTS_DIR.joinpath(f"{dataset_name}_{mode}_analysis.png"), dpi=300) | ||
|
||
|
||
def main(): | ||
"""Analyze training and validation set of a dataset.""" | ||
plot_metrics_runs(DATASET_NAME, mode='train') | ||
plot_metrics_runs(DATASET_NAME, mode='valid') | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
"""Utility functions for data processing.""" | ||
import os | ||
from typing import Any, AnyStr, Dict, List, Tuple | ||
|
||
import yaml | ||
|
||
|
||
def crop_lammps_yaml(lammps_dump: str, lammps_thermo: str, crop_step: int, inplace: bool = False) \ | ||
-> Tuple[List[Dict[AnyStr, Any]], Dict[AnyStr, Any]]: | ||
"""Remove the first steps of a LAMMPS run to remove structures near the starting point. | ||
Args: | ||
lammps_dump: path to LAMMPS output file as a yaml | ||
lammps_thermo: path to LAMMPS thermodynamic output file as a yaml | ||
crop_step: number of steps to remove | ||
inplace (optional): if True, overwrite the two LAMMPS file with a cropped version. If False, do not write. | ||
Defaults to False. | ||
Returns: | ||
cropped LAMMPS output file | ||
cropped LAMMPS thermodynamic output file | ||
""" | ||
if not os.path.exists(lammps_dump): | ||
raise ValueError(f'{lammps_dump} does not exist. Please provide a valid LAMMPS dump file as yaml.') | ||
|
||
if not os.path.exists(lammps_thermo): | ||
raise ValueError(f'{lammps_thermo} does not exist. Please provide a valid LAMMPS thermo log file as yaml.') | ||
|
||
# get the atom information (positions and forces) from the LAMMPS 'dump' file | ||
with open(lammps_dump, 'r') as f: | ||
dump_yaml = yaml.safe_load_all(f) | ||
dump_yaml = [d for d in dump_yaml] # generator to list | ||
# every MD iteration is saved as a separate document in the yaml file | ||
# prepare a dataframe to get all the data | ||
if crop_step >= len(dump_yaml): | ||
raise ValueError(f"Trying to remove {crop_step} steps in a run of {len(dump_yaml)} steps.") | ||
dump_yaml = dump_yaml[crop_step:] | ||
|
||
# get the total energy from the LAMMPS thermodynamic output | ||
with open(lammps_thermo, 'r') as f: | ||
thermo_yaml = yaml.safe_load(f) | ||
thermo_yaml['data'] = thermo_yaml['data'][crop_step:] | ||
|
||
if inplace: | ||
with open("test_yaml.yaml", "w") as f: | ||
yaml.dump_all(dump_yaml, f, explicit_start=True) | ||
with open("test_thermo.yaml", "w") as f: | ||
yaml.dump(thermo_yaml, f) | ||
|
||
return dump_yaml, thermo_yaml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
"""Read and crop LAMMPS outputs.""" | ||
import argparse | ||
import os | ||
|
||
import yaml | ||
|
||
from crystal_diffusion.data.utils import crop_lammps_yaml | ||
|
||
|
||
def main(): | ||
"""Read LAMMPS outputs from arguments and crops.""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--lammps_yaml', help='path to LAMMPS yaml file', required=True) | ||
parser.add_argument('--lammps_thermo', help='path to LAMMPS thermo output', required=True) | ||
parser.add_argument('--crop', type=int, help='number of steps to remove at the start of the run', required=True) | ||
parser.add_argument('--output_dir', help='path to folder where outputs will be saved', required=True) | ||
args = parser.parse_args() | ||
|
||
lammps_yaml = args.lammps_yaml | ||
lammps_thermo_yaml = args.lammps_thermo | ||
|
||
lammps_yaml, lammps_thermo_yaml = crop_lammps_yaml(lammps_yaml, lammps_thermo_yaml, args.crop, inplace=False) | ||
|
||
if not os.path.exists(args.output_dir): | ||
os.makedirs(args.output_dir) | ||
|
||
with open(os.path.join(args.output_dir, 'lammps_dump.yaml'), 'w') as f: | ||
yaml.dump_all(lammps_yaml, f, explicit_start=True) | ||
with open(os.path.join(args.output_dir, 'lammps_thermo.yaml'), 'w') as f: | ||
yaml.dump(lammps_thermo_yaml, f) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#!/bin/bash | ||
|
||
TEMPERATURE=300 | ||
BOX_SIZE=4 | ||
STEP=1000 | ||
CROP=100 | ||
NTRAIN_RUN=10 | ||
NVALID_RUN=5 | ||
|
||
NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) | ||
|
||
for SEED in $(seq 1 $NRUN); | ||
do | ||
if [ "$SEED" -le $NTRAIN_RUN ]; then | ||
MODE="train" | ||
else | ||
MODE="valid" | ||
fi | ||
echo $MODE $SEED | ||
mkdir -p "${MODE}_run_${SEED}" | ||
cd "${MODE}_run_${SEED}" | ||
lmp < ../in.si.lammps -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED | ||
|
||
# extract the thermodynamic outputs in a yaml file | ||
egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml | ||
|
||
mkdir -p "uncropped_outputs" | ||
mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ | ||
mv thermo_log.yaml uncropped_outputs/ | ||
|
||
python ../../crop_lammps_outputs.py \ | ||
--lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \ | ||
--lammps_thermo "uncropped_outputs/thermo_log.yaml" \ | ||
--crop $CROP \ | ||
--output_dir ./ | ||
|
||
cd .. | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
log log.lammps | ||
|
||
units metal | ||
atom_style atomic | ||
atom_modify map array | ||
|
||
lattice diamond 5.43 | ||
region simbox block 0 ${S} 0 ${S} 0 ${S} | ||
create_box 1 simbox | ||
create_atoms 1 region simbox | ||
|
||
mass 1 28.0855 | ||
|
||
group Si type 1 | ||
|
||
pair_style sw | ||
pair_coeff * * ../../Si.sw Si | ||
|
||
velocity all create ${T} ${SEED} | ||
|
||
dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz | ||
|
||
thermo_style yaml | ||
thermo 1 | ||
#==========================Output files======================== | ||
|
||
fix 1 all nvt temp ${T} ${T} 0.01 | ||
run ${STEP} | ||
unfix 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pytest | ||
import yaml | ||
|
||
from crystal_diffusion.data.utils import crop_lammps_yaml | ||
|
||
# Sample data for dump and thermo YAML files | ||
DUMP_YAML_CONTENT = """ | ||
- step: 0 | ||
data: {...} | ||
--- | ||
- step: 1 | ||
data: {...} | ||
--- | ||
- step: 2 | ||
data: {...} | ||
""" | ||
THERMO_YAML_CONTENT = """ | ||
data: | ||
- {...} | ||
- {...} | ||
- {...} | ||
""" | ||
|
||
|
||
@pytest.fixture | ||
def dump_file(tmpdir): | ||
file = tmpdir.join("lammps_dump.yaml") | ||
file.write(DUMP_YAML_CONTENT) | ||
return str(file) | ||
|
||
|
||
@pytest.fixture | ||
def thermo_file(tmpdir): | ||
file = tmpdir.join("lammps_thermo.yaml") | ||
file.write(THERMO_YAML_CONTENT) | ||
return str(file) | ||
|
||
|
||
def test_crop_lammps_yaml(dump_file, thermo_file): | ||
crop_step = 1 | ||
# Call the function with the path to the temporary files | ||
cropped_dump, cropped_thermo = crop_lammps_yaml(dump_file, thermo_file, crop_step) | ||
|
||
# Load the content to assert correctness | ||
with open(dump_file) as f: | ||
dump_yaml_content = list(yaml.safe_load_all(f)) | ||
|
||
with open(thermo_file) as f: | ||
thermo_yaml_content = yaml.safe_load(f) | ||
|
||
# Verify the function output | ||
assert len(cropped_dump) == len(dump_yaml_content) - crop_step | ||
assert len(cropped_thermo['data']) == len(thermo_yaml_content['data']) - crop_step | ||
|
||
# Testing exception for too large crop_step | ||
with pytest.raises(ValueError): | ||
crop_lammps_yaml(dump_file, thermo_file, 4) |