diff --git a/.gitignore b/.gitignore index 2cc0b025..7455b4a2 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,12 @@ examples/data/ examples/*/output/ examples/*/lightning_logs/ +**/train_run*/ +**/valid_run*/ +**/processed/ +**/cache/ +**/output/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/data/SiGe_diffusion_1x1x1/config.yaml b/data/SiGe_diffusion_1x1x1/config.yaml new file mode 100644 index 00000000..0c428131 --- /dev/null +++ b/data/SiGe_diffusion_1x1x1/config.yaml @@ -0,0 +1,6 @@ +# Configuration for the dataloader +batch_size: 1024 +num_workers: 0 +max_atom: 8 +spatial_dimension: 3 +elements: [Si, Ge] \ No newline at end of file diff --git a/data/SiGe_diffusion_1x1x1/create_data.sh b/data/SiGe_diffusion_1x1x1/create_data.sh new file mode 100755 index 00000000..87177fc6 --- /dev/null +++ b/data/SiGe_diffusion_1x1x1/create_data.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +source ../data_generation_functions.sh + +TEMPERATURE=300 +BOX_SIZE=1 +STEP=10000 +CROP=10000 +NTRAIN_RUN=10 +NVALID_RUN=5 + +SW_PATH="../stillinger_weber_coefficients/SiGe.sw" +IN_PATH="in.SiGe.lammps" +CONFIG_PATH="config.yaml" + +create_data_function $TEMPERATURE $BOX_SIZE $STEP $CROP $NTRAIN_RUN $NVALID_RUN $SW_PATH $IN_PATH $CONFIG_PATH diff --git a/data/SiGe_diffusion_1x1x1/in.SiGe.lammps b/data/SiGe_diffusion_1x1x1/in.SiGe.lammps new file mode 100644 index 00000000..30afdf6b --- /dev/null +++ b/data/SiGe_diffusion_1x1x1/in.SiGe.lammps @@ -0,0 +1,34 @@ +log log.lammps + +units metal +atom_style atomic +atom_modify map array + +lattice diamond 5.5421217827 +region box block 0 ${S} 0 ${S} 0 ${S} + +create_box 2 box +create_atoms 1 box basis 1 1 basis 2 1 basis 3 1 basis 4 1 basis 5 2 basis 6 2 basis 7 2 basis 8 2 + + +mass 1 28.0855 +mass 2 72.64 + +group Si type 1 +group Ge type 2 + +pair_style sw +pair_coeff * * ${SW_PATH} Si Ge + +velocity all create ${T} ${SEED} + +dump dump_id all yaml 1 dump.${T}-${S}.yaml id element x y z fx fy fz +dump_modify dump_id element Si Ge + +thermo_style yaml +thermo 1 +#==========================Output files======================== + +fix 1 all nvt temp ${T} ${T} 0.01 +run ${STEP} +unfix 1 diff --git a/data/SiGe_diffusion_2x2x2/config.yaml b/data/SiGe_diffusion_2x2x2/config.yaml new file mode 100644 index 00000000..a02c5af3 --- /dev/null +++ b/data/SiGe_diffusion_2x2x2/config.yaml @@ -0,0 +1,6 @@ +# Configuration for the dataloader +batch_size: 1024 +num_workers: 0 +max_atom: 64 +spatial_dimension: 3 +elements: [Si, Ge] diff --git a/data/SiGe_diffusion_2x2x2/create_data.sh b/data/SiGe_diffusion_2x2x2/create_data.sh new file mode 100755 index 00000000..a7b7b38a --- /dev/null +++ b/data/SiGe_diffusion_2x2x2/create_data.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +source ../data_generation_functions.sh + +TEMPERATURE=300 +BOX_SIZE=2 +STEP=10000 +CROP=10000 +NTRAIN_RUN=10 +NVALID_RUN=5 + +SW_PATH="../stillinger_weber_coefficients/SiGe.sw" +IN_PATH="in.SiGe.lammps" +CONFIG_PATH="config.yaml" + +create_data_function $TEMPERATURE $BOX_SIZE $STEP $CROP $NTRAIN_RUN $NVALID_RUN $SW_PATH $IN_PATH $CONFIG_PATH diff --git a/data/SiGe_diffusion_2x2x2/in.SiGe.lammps b/data/SiGe_diffusion_2x2x2/in.SiGe.lammps new file mode 100644 index 00000000..30afdf6b --- /dev/null +++ b/data/SiGe_diffusion_2x2x2/in.SiGe.lammps @@ -0,0 +1,34 @@ +log log.lammps + +units metal +atom_style atomic +atom_modify map array + +lattice diamond 5.5421217827 +region box block 0 ${S} 0 ${S} 0 ${S} + +create_box 2 box +create_atoms 1 box basis 1 1 basis 2 1 basis 3 1 basis 4 1 basis 5 2 basis 6 2 basis 7 2 basis 8 2 + + +mass 1 28.0855 +mass 2 72.64 + +group Si type 1 +group Ge type 2 + +pair_style sw +pair_coeff * * ${SW_PATH} Si Ge + +velocity all create ${T} ${SEED} + +dump dump_id all yaml 1 dump.${T}-${S}.yaml id element x y z fx fy fz +dump_modify dump_id element Si Ge + +thermo_style yaml +thermo 1 +#==========================Output files======================== + +fix 1 all nvt temp ${T} ${T} 0.01 +run ${STEP} +unfix 1 diff --git a/data/SiGe_diffusion_3x3x3/config.yaml b/data/SiGe_diffusion_3x3x3/config.yaml new file mode 100644 index 00000000..299eaa3a --- /dev/null +++ b/data/SiGe_diffusion_3x3x3/config.yaml @@ -0,0 +1,6 @@ +# Configuration for the dataloader +batch_size: 1024 +num_workers: 0 +max_atom: 216 +spatial_dimension: 3 +elements: [Si, Ge] diff --git a/data/SiGe_diffusion_3x3x3/create_data.sh b/data/SiGe_diffusion_3x3x3/create_data.sh new file mode 100755 index 00000000..d8aff091 --- /dev/null +++ b/data/SiGe_diffusion_3x3x3/create_data.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +source ../data_generation_functions.sh + +TEMPERATURE=300 +BOX_SIZE=3 +STEP=10000 +CROP=10000 +NTRAIN_RUN=10 +NVALID_RUN=5 + +SW_PATH="../stillinger_weber_coefficients/SiGe.sw" +IN_PATH="in.SiGe.lammps" +CONFIG_PATH="config.yaml" + +create_data_function $TEMPERATURE $BOX_SIZE $STEP $CROP $NTRAIN_RUN $NVALID_RUN $SW_PATH $IN_PATH $CONFIG_PATH diff --git a/data/SiGe_diffusion_3x3x3/in.SiGe.lammps b/data/SiGe_diffusion_3x3x3/in.SiGe.lammps new file mode 100644 index 00000000..30afdf6b --- /dev/null +++ b/data/SiGe_diffusion_3x3x3/in.SiGe.lammps @@ -0,0 +1,34 @@ +log log.lammps + +units metal +atom_style atomic +atom_modify map array + +lattice diamond 5.5421217827 +region box block 0 ${S} 0 ${S} 0 ${S} + +create_box 2 box +create_atoms 1 box basis 1 1 basis 2 1 basis 3 1 basis 4 1 basis 5 2 basis 6 2 basis 7 2 basis 8 2 + + +mass 1 28.0855 +mass 2 72.64 + +group Si type 1 +group Ge type 2 + +pair_style sw +pair_coeff * * ${SW_PATH} Si Ge + +velocity all create ${T} ${SEED} + +dump dump_id all yaml 1 dump.${T}-${S}.yaml id element x y z fx fy fz +dump_modify dump_id element Si Ge + +thermo_style yaml +thermo 1 +#==========================Output files======================== + +fix 1 all nvt temp ${T} ${T} 0.01 +run ${STEP} +unfix 1 diff --git a/data/Si_diffusion_1x1x1/config.yaml b/data/Si_diffusion_1x1x1/config.yaml new file mode 100644 index 00000000..e2282c99 --- /dev/null +++ b/data/Si_diffusion_1x1x1/config.yaml @@ -0,0 +1,6 @@ +# Configuration for the dataloader +batch_size: 1024 +num_workers: 0 +max_atom: 8 +spatial_dimension: 3 +elements: [Si] diff --git a/data/Si_diffusion_1x1x1/create_data.sh b/data/Si_diffusion_1x1x1/create_data.sh new file mode 100755 index 00000000..b34f3ba2 --- /dev/null +++ b/data/Si_diffusion_1x1x1/create_data.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +source ../data_generation_functions.sh + +TEMPERATURE=300 +BOX_SIZE=1 +STEP=10000 +CROP=10000 +NTRAIN_RUN=10 +NVALID_RUN=5 + +SW_PATH="../stillinger_weber_coefficients/Si.sw" +IN_PATH="in.Si.lammps" +CONFIG_PATH="config.yaml" + +create_data_function $TEMPERATURE $BOX_SIZE $STEP $CROP $NTRAIN_RUN $NVALID_RUN $SW_PATH $IN_PATH $CONFIG_PATH diff --git a/data/si_diffusion_2x2x2/in.si.lammps b/data/Si_diffusion_1x1x1/in.Si.lammps old mode 100755 new mode 100644 similarity index 77% rename from data/si_diffusion_2x2x2/in.si.lammps rename to data/Si_diffusion_1x1x1/in.Si.lammps index 17f20e42..3ad49932 --- a/data/si_diffusion_2x2x2/in.si.lammps +++ b/data/Si_diffusion_1x1x1/in.Si.lammps @@ -14,11 +14,13 @@ mass 1 28.0855 group Si type 1 pair_style sw -pair_coeff * * ../../si.sw Si +pair_coeff * * ${SW_PATH} Si + velocity all create ${T} ${SEED} -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz +dump dump_id all yaml 1 dump.${T}-${S}.yaml id element x y z fx fy fz +dump_modify dump_id element Si thermo_style yaml thermo 1 diff --git a/data/Si_diffusion_2x2x2/config.yaml b/data/Si_diffusion_2x2x2/config.yaml new file mode 100644 index 00000000..a8256af2 --- /dev/null +++ b/data/Si_diffusion_2x2x2/config.yaml @@ -0,0 +1,6 @@ +# Configuration for the dataloader +batch_size: 1024 +num_workers: 0 +max_atom: 64 +spatial_dimension: 3 +elements: [Si] diff --git a/data/Si_diffusion_2x2x2/create_data.sh b/data/Si_diffusion_2x2x2/create_data.sh new file mode 100755 index 00000000..072e8822 --- /dev/null +++ b/data/Si_diffusion_2x2x2/create_data.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +source ../data_generation_functions.sh + +TEMPERATURE=300 +BOX_SIZE=2 +STEP=10000 +CROP=10000 +NTRAIN_RUN=10 +NVALID_RUN=5 + +SW_PATH="../stillinger_weber_coefficients/Si.sw" +IN_PATH="in.Si.lammps" +CONFIG_PATH="config.yaml" + +create_data_function $TEMPERATURE $BOX_SIZE $STEP $CROP $NTRAIN_RUN $NVALID_RUN $SW_PATH $IN_PATH $CONFIG_PATH diff --git a/data/si_diffusion_1x1x1_large/in.si.lammps b/data/Si_diffusion_2x2x2/in.Si.lammps old mode 100755 new mode 100644 similarity index 77% rename from data/si_diffusion_1x1x1_large/in.si.lammps rename to data/Si_diffusion_2x2x2/in.Si.lammps index 17f20e42..3ad49932 --- a/data/si_diffusion_1x1x1_large/in.si.lammps +++ b/data/Si_diffusion_2x2x2/in.Si.lammps @@ -14,11 +14,13 @@ mass 1 28.0855 group Si type 1 pair_style sw -pair_coeff * * ../../si.sw Si +pair_coeff * * ${SW_PATH} Si + velocity all create ${T} ${SEED} -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz +dump dump_id all yaml 1 dump.${T}-${S}.yaml id element x y z fx fy fz +dump_modify dump_id element Si thermo_style yaml thermo 1 diff --git a/data/Si_diffusion_3x3x3/config.yaml b/data/Si_diffusion_3x3x3/config.yaml new file mode 100644 index 00000000..fe0287db --- /dev/null +++ b/data/Si_diffusion_3x3x3/config.yaml @@ -0,0 +1,6 @@ +# Configuration for the dataloader +batch_size: 1024 +num_workers: 0 +max_atom: 216 +spatial_dimension: 3 +elements: [Si] diff --git a/data/Si_diffusion_3x3x3/create_data.sh b/data/Si_diffusion_3x3x3/create_data.sh new file mode 100755 index 00000000..6d4e581f --- /dev/null +++ b/data/Si_diffusion_3x3x3/create_data.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +source ../data_generation_functions.sh + +TEMPERATURE=300 +BOX_SIZE=3 +STEP=10000 +CROP=10000 +NTRAIN_RUN=10 +NVALID_RUN=5 + +SW_PATH="../stillinger_weber_coefficients/Si.sw" +IN_PATH="in.Si.lammps" +CONFIG_PATH="config.yaml" + +create_data_function $TEMPERATURE $BOX_SIZE $STEP $CROP $NTRAIN_RUN $NVALID_RUN $SW_PATH $IN_PATH $CONFIG_PATH diff --git a/data/si_diffusion_1x1x1/in.si.lammps b/data/Si_diffusion_3x3x3/in.Si.lammps old mode 100755 new mode 100644 similarity index 77% rename from data/si_diffusion_1x1x1/in.si.lammps rename to data/Si_diffusion_3x3x3/in.Si.lammps index 17f20e42..3ad49932 --- a/data/si_diffusion_1x1x1/in.si.lammps +++ b/data/Si_diffusion_3x3x3/in.Si.lammps @@ -14,11 +14,13 @@ mass 1 28.0855 group Si type 1 pair_style sw -pair_coeff * * ../../si.sw Si +pair_coeff * * ${SW_PATH} Si + velocity all create ${T} ${SEED} -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz +dump dump_id all yaml 1 dump.${T}-${S}.yaml id element x y z fx fy fz +dump_modify dump_id element Si thermo_style yaml thermo 1 diff --git a/data/data_generation_functions.sh b/data/data_generation_functions.sh new file mode 100644 index 00000000..b2b66f14 --- /dev/null +++ b/data/data_generation_functions.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +function create_data_function() { + # this function drives the creation training and validation data with LAMMPS. + # It assumes : + # - the function is sourced in a bash script (the "calling script") within the folder where the data is to be created. + # - the calling script is invoked in a shell with the correct python environment. + # - the LAMMPS input file follows a template and has all the passed variables defined. + # - the paths are defined with respect to the folder where the generation script is called. + + TEMPERATURE="$1" + BOX_SIZE="$2" + STEP="$3" + CROP="$4" + NTRAIN_RUN="$5" + NVALID_RUN="$6" + SW_PATH="$7" + IN_PATH="$8" + CONFIG_PATH="$9" + + NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) + + # Generate the data + for SEED in $(seq 1 $NRUN); do + if [ "$SEED" -le $NTRAIN_RUN ]; then + MODE="train" + else + MODE="valid" + fi + echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." + mkdir -p "${MODE}_run_${SEED}" + cd "${MODE}_run_${SEED}" + + # Calling LAMMPS with various arguments to keep it quiet. Also, the current location is "${MODE}_run_${SEED}", which is one + # folder away from the location of the calling script. + lmp -echo none -screen none < ../$IN_PATH -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED -v SW_PATH ../$SW_PATH + + # extract the thermodynamic outputs in a yaml file + egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml + + mkdir -p "uncropped_outputs" + mv "dump.${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ + mv thermo_log.yaml uncropped_outputs/ + + python ../../crop_lammps_outputs.py \ + --lammps_yaml "uncropped_outputs/dump.${TEMPERATURE}-${BOX_SIZE}.yaml" \ + --lammps_thermo "uncropped_outputs/thermo_log.yaml" \ + --crop $CROP \ + --output_dir ./ + + cd .. + done + + # process the data + python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --config ${CONFIG_PATH} +} diff --git a/data/lammps_input_example.lammps b/data/lammps_input_example.lammps deleted file mode 100755 index c2f77445..00000000 --- a/data/lammps_input_example.lammps +++ /dev/null @@ -1,31 +0,0 @@ -log log.si-${T}-${S}.lammps - -units metal -atom_style atomic -atom_modify map array - -lattice diamond 5.43 -region simbox block 0 ${S} 0 ${S} 0 ${S} -create_box 1 simbox -create_atoms 1 region simbox - -#read_dump ${DUMP} ${STEP} x y z vx vy vz fx fy fz box yes replace no purge yes add yes - -mass 1 28.0855 - -group Si type 1 - -pair_style sw -pair_coeff * * si.sw Si - -velocity all create ${T} 62177 - -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz - -thermo_style yaml -thermo 1 -#==========================Output files======================== - -fix 1 all nvt temp ${T} ${T} 0.01 -run ${STEP} -unfix 1 diff --git a/data/parse_lammps.sh b/data/parse_lammps.sh deleted file mode 100755 index bcbb2079..00000000 --- a/data/parse_lammps.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash - -EXP_DIR="lammps_scripts/Si/si-custom/" -DUMP_FILENAME="dump.si-300-1.yaml" -THERMO_FILENAME="thermo_log.yaml" -OUTPUT_NAME="demo.parquet" - -python crystal_diffusion/data/parse_lammps_outputs.py \ - --dump_file ${EXP_DIR}/${DUMP_FILENAME} \ - --thermo_file ${EXP_DIR}/${THERMO_FILENAME} \ - --output_name ${EXP_DIR}/${OUTPUT_NAME} diff --git a/data/process_lammps_data.py b/data/process_lammps_data.py index a55f0beb..12713b3e 100644 --- a/data/process_lammps_data.py +++ b/data/process_lammps_data.py @@ -1,10 +1,16 @@ """Create the processed data.""" import argparse +import logging import tempfile -from crystal_diffusion.data.diffusion.data_loader import ( +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) -from crystal_diffusion.utils.logging_utils import setup_analysis_logger +from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ + setup_analysis_logger +from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ + _get_hyperparameters + +logger = logging.getLogger(__name__) def main(): @@ -12,12 +18,20 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument('--data', help='path to a LAMMPS data set', required=True) parser.add_argument('--processed_datadir', help='path to the processed data directory', required=True) - parser.add_argument('--max_atom', help='maximum number of atoms', required=True) + parser.add_argument("--config", help="config file with dataloader hyper-parameters, such as " + "batch_size, elements, ... - in yaml format") args = parser.parse_args() lammps_run_dir = args.data processed_dataset_dir = args.processed_datadir - data_params = LammpsLoaderParameters(batch_size=128, num_workers=0, max_atom=int(args.max_atom)) + hyper_params = _get_hyperparameters(config_file_path=args.config) + + logger.info("Starting process_lammps_data.py script with arguments") + logger.info(f" --data : {args.data}") + logger.info(f" --processed_datadir : {args.processed_datadir}") + logger.info(f" --config: {args.config}") + + data_params = LammpsLoaderParameters(**hyper_params) with tempfile.TemporaryDirectory() as tmp_work_dir: data_module = LammpsForDiffusionDataModule(lammps_run_dir=lammps_run_dir, diff --git a/data/run_lammps_example.sh b/data/run_lammps_example.sh deleted file mode 100644 index 60ea1792..00000000 --- a/data/run_lammps_example.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -TEMPERATURE=300 -BOX_SIZE=1 - -lmp < lammps_input_example.lammps -v STEP 10 -v T $TEMPERATURE -v S $BOX_SIZE - -# extract the thermodynamic outputs in a yaml file -egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > log.yaml diff --git a/data/si_diffusion_1x1x1/create_data.sh b/data/si_diffusion_1x1x1/create_data.sh deleted file mode 100755 index 881bce76..00000000 --- a/data/si_diffusion_1x1x1/create_data.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -TEMPERATURE=300 -BOX_SIZE=1 -MAX_ATOM=8 -STEP=10000 -CROP=10000 -NTRAIN_RUN=10 -NVALID_RUN=5 - -NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) - -# Generate the data -for SEED in $(seq 1 $NRUN); do - if [ "$SEED" -le $NTRAIN_RUN ]; then - MODE="train" - else - MODE="valid" - fi - echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." - mkdir -p "${MODE}_run_${SEED}" - cd "${MODE}_run_${SEED}" - lmp -echo none -screen none < ../in.si.lammps -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED - - # extract the thermodynamic outputs in a yaml file - egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml - - mkdir -p "uncropped_outputs" - mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ - mv thermo_log.yaml uncropped_outputs/ - - python ../../crop_lammps_outputs.py \ - --lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \ - --lammps_thermo "uncropped_outputs/thermo_log.yaml" \ - --crop $CROP \ - --output_dir ./ - - cd .. -done - -# process the data -python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --max_atom ${MAX_ATOM} diff --git a/data/si_diffusion_1x1x1_large/create_data.sh b/data/si_diffusion_1x1x1_large/create_data.sh deleted file mode 100755 index f6f4f105..00000000 --- a/data/si_diffusion_1x1x1_large/create_data.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -TEMPERATURE=300 -BOX_SIZE=1 -MAX_ATOM=8 -STEP=10000 -CROP=10000 -NTRAIN_RUN=10 -NVALID_RUN=5 -NTRAIN_RUN_EXTRA=40 - -NRUN=$(($NTRAIN_RUN + $NVALID_RUN + $NTRAIN_RUN_EXTRA)) - -# Generate the data -for SEED in $(seq 1 $NRUN); do - if [ "$SEED" -le $NTRAIN_RUN ]; then - MODE="train" - elif [ "$SEED" -le $(($NTRAIN_RUN + $NVALID_RUN)) ]; then - MODE="valid" - else - MODE="train" - fi - echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." - mkdir -p "${MODE}_run_${SEED}" - cd "${MODE}_run_${SEED}" - lmp -echo none -screen none < ../in.si.lammps -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED - - # extract the thermodynamic outputs in a yaml file - egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml - - mkdir -p "uncropped_outputs" - mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ - mv thermo_log.yaml uncropped_outputs/ - - FILE_LENGTH=$((20 * $STEP)) - tail -n $FILE_LENGTH "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" > "lammps_dump.yaml" - { sed -n '2,3p' uncropped_outputs/thermo_log.yaml; tail -n $(($STEP + 1)) uncropped_outputs/thermo_log.yaml | - head -n $STEP; } > lammps_thermo.yaml - cd .. -done - -# process the data -python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --max_atom ${MAX_ATOM} diff --git a/data/si_diffusion_1x1x1_single_example/create_data.sh b/data/si_diffusion_1x1x1_single_example/create_data.sh deleted file mode 100755 index e4ea82f7..00000000 --- a/data/si_diffusion_1x1x1_single_example/create_data.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -#================================================================================ -# This script creates a 'fake' dataset composed of a single example repeated -# multiple times. -#================================================================================ - -TEMPERATURE=0 -BOX_SIZE=1 -MAX_ATOM=8 -STEP=4048 -CROP=1 # Crop 1 to make sure there is exactly 4048 examples in the final dataset. -NTRAIN_RUN=1 -NVALID_RUN=1 - -NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) - -# Generate the data -for SEED in $(seq 1 $NRUN); do - if [ "$SEED" -le $NTRAIN_RUN ]; then - MODE="train" - else - MODE="valid" - fi - echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." - mkdir -p "${MODE}_run_${SEED}" - cd "${MODE}_run_${SEED}" - lmp -echo none -screen none < ../in.si.lammps -v STEP $STEP -v S $BOX_SIZE -v T $TEMPERATURE - - # extract the thermodynamic outputs in a yaml file - egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml - - mkdir -p "uncropped_outputs" - mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ - mv thermo_log.yaml uncropped_outputs/ - - python ../../crop_lammps_outputs.py \ - --lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \ - --lammps_thermo "uncropped_outputs/thermo_log.yaml" \ - --crop $CROP \ - --output_dir ./ - - cd .. -done - -# process the data -python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --max_atom ${MAX_ATOM} diff --git a/data/si_diffusion_1x1x1_single_example/in.si.lammps b/data/si_diffusion_1x1x1_single_example/in.si.lammps deleted file mode 100755 index 4941cb17..00000000 --- a/data/si_diffusion_1x1x1_single_example/in.si.lammps +++ /dev/null @@ -1,29 +0,0 @@ -# This configuration file creates the SAME EQUILIBRIUM POSITIONS multiple times. This is for debugging. -log log.lammps - -units metal -atom_style atomic -atom_modify map array - -lattice diamond 5.43 -region simbox block 0 ${S} 0 ${S} 0 ${S} -create_box 1 simbox -create_atoms 1 region simbox - -mass 1 28.0855 - -group Si type 1 - -pair_style sw -pair_coeff * * ../../si.sw Si - -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz - - - -thermo_style yaml -thermo 1 - -#==========================Output files======================== - -run ${STEP} diff --git a/data/si_diffusion_2x2x2/create_data.sh b/data/si_diffusion_2x2x2/create_data.sh deleted file mode 100755 index b859aab2..00000000 --- a/data/si_diffusion_2x2x2/create_data.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -TEMPERATURE=300 -BOX_SIZE=2 -MAX_ATOM=64 -STEP=10000 -CROP=10000 -NTRAIN_RUN=10 -NVALID_RUN=5 - -NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) - -# Generate the data -for SEED in $(seq 1 $NRUN); do - if [ "$SEED" -le $NTRAIN_RUN ]; then - MODE="train" - else - MODE="valid" - fi - echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." - mkdir -p "${MODE}_run_${SEED}" - cd "${MODE}_run_${SEED}" - lmp -echo none -screen none < ../in.si.lammps -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED - - # extract the thermodynamic outputs in a yaml file - egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml - - mkdir -p "uncropped_outputs" - mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ - mv thermo_log.yaml uncropped_outputs/ - - python ../../crop_lammps_outputs.py \ - --lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \ - --lammps_thermo "uncropped_outputs/thermo_log.yaml" \ - --crop $CROP \ - --output_dir ./ - - cd .. -done - -# process the data -python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --max_atom ${MAX_ATOM} diff --git a/data/si_diffusion_2x2x2_single_example/create_data.sh b/data/si_diffusion_2x2x2_single_example/create_data.sh deleted file mode 100755 index ab391fa8..00000000 --- a/data/si_diffusion_2x2x2_single_example/create_data.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -#================================================================================ -# This script creates a 'fake' dataset composed of a single example repeated -# multiple times. -#================================================================================ - -TEMPERATURE=0 -BOX_SIZE=2 -MAX_ATOM=64 -STEP=4048 -CROP=1 # Crop 1 to make sure there is exactly 4048 examples in the final dataset. -NTRAIN_RUN=1 -NVALID_RUN=1 - -NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) - -# Generate the data -for SEED in $(seq 1 $NRUN); do - if [ "$SEED" -le $NTRAIN_RUN ]; then - MODE="train" - else - MODE="valid" - fi - echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." - mkdir -p "${MODE}_run_${SEED}" - cd "${MODE}_run_${SEED}" - lmp -echo none -screen none < ../in.si.lammps -v STEP $STEP -v S $BOX_SIZE -v T $TEMPERATURE - - # extract the thermodynamic outputs in a yaml file - egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml - - mkdir -p "uncropped_outputs" - mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ - mv thermo_log.yaml uncropped_outputs/ - - python ../../crop_lammps_outputs.py \ - --lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \ - --lammps_thermo "uncropped_outputs/thermo_log.yaml" \ - --crop $CROP \ - --output_dir ./ - - cd .. -done - -# process the data -python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --max_atom ${MAX_ATOM} diff --git a/data/si_diffusion_2x2x2_single_example/in.si.lammps b/data/si_diffusion_2x2x2_single_example/in.si.lammps deleted file mode 100755 index 4941cb17..00000000 --- a/data/si_diffusion_2x2x2_single_example/in.si.lammps +++ /dev/null @@ -1,29 +0,0 @@ -# This configuration file creates the SAME EQUILIBRIUM POSITIONS multiple times. This is for debugging. -log log.lammps - -units metal -atom_style atomic -atom_modify map array - -lattice diamond 5.43 -region simbox block 0 ${S} 0 ${S} 0 ${S} -create_box 1 simbox -create_atoms 1 region simbox - -mass 1 28.0855 - -group Si type 1 - -pair_style sw -pair_coeff * * ../../si.sw Si - -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz - - - -thermo_style yaml -thermo 1 - -#==========================Output files======================== - -run ${STEP} diff --git a/data/si_diffusion_3x3x3/create_data.sh b/data/si_diffusion_3x3x3/create_data.sh deleted file mode 100755 index 56277b71..00000000 --- a/data/si_diffusion_3x3x3/create_data.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -TEMPERATURE=300 -BOX_SIZE=3 -MAX_ATOM=216 -STEP=10000 -CROP=10000 -NTRAIN_RUN=10 -NVALID_RUN=5 - -NRUN=$(($NTRAIN_RUN + $NVALID_RUN)) - -# Generate the data -for SEED in $(seq 1 $NRUN); do - if [ "$SEED" -le $NTRAIN_RUN ]; then - MODE="train" - else - MODE="valid" - fi - echo "Creating LAMMPS data for ${MODE}_run_${SEED}..." - mkdir -p "${MODE}_run_${SEED}" - cd "${MODE}_run_${SEED}" - lmp -echo none -screen none < ../in.si.lammps -v STEP $(($STEP + $CROP)) -v T $TEMPERATURE -v S $BOX_SIZE -v SEED $SEED - - # extract the thermodynamic outputs in a yaml file - egrep '^(keywords:|data:$|---$|\.\.\.$| - \[)' log.lammps > thermo_log.yaml - - mkdir -p "uncropped_outputs" - mv "dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" uncropped_outputs/ - mv thermo_log.yaml uncropped_outputs/ - - python ../../crop_lammps_outputs.py \ - --lammps_yaml "uncropped_outputs/dump.si-${TEMPERATURE}-${BOX_SIZE}.yaml" \ - --lammps_thermo "uncropped_outputs/thermo_log.yaml" \ - --crop $CROP \ - --output_dir ./ - - cd .. -done - -# process the data -python ../process_lammps_data.py --data "./" --processed_datadir "./processed/" --max_atom ${MAX_ATOM} diff --git a/data/si_diffusion_3x3x3/in.si.lammps b/data/si_diffusion_3x3x3/in.si.lammps deleted file mode 100755 index 17f20e42..00000000 --- a/data/si_diffusion_3x3x3/in.si.lammps +++ /dev/null @@ -1,29 +0,0 @@ -log log.lammps - -units metal -atom_style atomic -atom_modify map array - -lattice diamond 5.43 -region simbox block 0 ${S} 0 ${S} 0 ${S} -create_box 1 simbox -create_atoms 1 region simbox - -mass 1 28.0855 - -group Si type 1 - -pair_style sw -pair_coeff * * ../../si.sw Si - -velocity all create ${T} ${SEED} - -dump 1 all yaml 1 dump.si-${T}-${S}.yaml id type x y z fx fy fz - -thermo_style yaml -thermo 1 -#==========================Output files======================== - -fix 1 all nvt temp ${T} ${T} 0.01 -run ${STEP} -unfix 1 diff --git a/data/si.sw b/data/stillinger_weber_coefficients/Si.sw old mode 100755 new mode 100644 similarity index 100% rename from data/si.sw rename to data/stillinger_weber_coefficients/Si.sw diff --git a/data/stillinger_weber_coefficients/SiGe.sw b/data/stillinger_weber_coefficients/SiGe.sw new file mode 100644 index 00000000..0a0176e0 --- /dev/null +++ b/data/stillinger_weber_coefficients/SiGe.sw @@ -0,0 +1,14 @@ + +# v2: Epitaxial growth of Si1−xGex on Si(100)2 × 1: A molecular-dynamics study +# epsilon, sigma, a, lambda, gamma, costheta0 A, B, p, q, tol +Si Si Si 3.472 2.095 1.80 21.0 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 +Ge Ge Ge 3.085 2.181 1.80 31.0 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 + +Si Ge Ge 3.273 2.138 1.80 25.5 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 +Ge Si Si 3.273 2.138 1.80 25.5 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 + +Si Ge Si 3.371 2.138 1.80 23.1 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 +Si Si Ge 3.371 2.138 1.80 23.1 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 + +Ge Si Ge 3.178 2.138 1.80 28.1 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 +Ge Ge Si 3.178 2.138 1.80 28.1 1.20 -0.333333333333 7.050 0.6022 4.0 0.0 0.0 diff --git a/examples/config_files/diffusion/config_diffusion_egnn.yaml b/examples/config_files/diffusion/config_diffusion_egnn.yaml index 4d04b0b4..b53931f5 100644 --- a/examples/config_files/diffusion/config_diffusion_egnn.yaml +++ b/examples/config_files/diffusion/config_diffusion_egnn.yaml @@ -10,17 +10,22 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: [Si] + # data data: batch_size: 128 num_workers: 8 - max_atom: 64 + max_atom: 8 # architecture spatial_dimension: 3 model: + loss: + coordinates_algorithm: mse score_network: architecture: egnn + num_atom_types: 1 n_layers: 4 coordinate_hidden_dimensions_size: 128 coordinate_n_hidden_dimensions: 4 @@ -35,7 +40,7 @@ model: tanh: False edges: fully_connected noise: - total_time_steps: 1000 + total_time_steps: 100 sigma_min: 0.0001 sigma_max: 0.2 corrector_step_epsilon: 2.0e-7 @@ -65,23 +70,24 @@ model_checkpoint: # Sampling from the generative model diffusion_sampling: noise: - total_time_steps: 1000 + total_time_steps: 100 sigma_min: 0.0001 sigma_max: 0.2 corrector_step_epsilon: 2.0e-7 sampling: algorithm: predictor_corrector + num_atom_types: 1 sample_batchsize: 128 spatial_dimension: 3 number_of_corrector_steps: 1 - number_of_atoms: 64 + number_of_atoms: 8 number_of_samples: 32 record_samples: False - cell_dimensions: [10.86, 10.86, 10.86] + cell_dimensions: [5.43, 5.43, 5.43] metrics: compute_energies: True compute_structure_factor: True - structure_factor_max_distance: 10.0 + structure_factor_max_distance: 5.0 sampling_visualization: record_every_n_epochs: 1 @@ -90,6 +96,10 @@ sampling_visualization: record_energies: True record_structure: True +oracle: + name: lammps + sw_coeff_filename: Si.sw + logging: - comet diff --git a/examples/config_files/diffusion/config_diffusion_mace.yaml b/examples/config_files/diffusion/config_diffusion_mace.yaml index 92e3f784..c1199cd0 100644 --- a/examples/config_files/diffusion/config_diffusion_mace.yaml +++ b/examples/config_files/diffusion/config_diffusion_mace.yaml @@ -10,6 +10,8 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: [Si] + # data data: batch_size: 512 @@ -20,9 +22,10 @@ data: spatial_dimension: 3 model: loss: - algorithm: mse + coordinates_algorithm: mse score_network: architecture: diffusion_mace + num_atom_types: 1 number_of_atoms: 8 r_max: 5.0 num_bessel: 8 @@ -79,7 +82,7 @@ diffusion_sampling: sigma_min: 0.001 # default value sigma_max: 0.5 # default value sampling: - algorithm: ode + algorithm: predictor_corrector spatial_dimension: 3 number_of_atoms: 8 number_of_samples: 16 @@ -87,6 +90,10 @@ diffusion_sampling: record_samples: True cell_dimensions: [5.43, 5.43, 5.43] +oracle: + name: lammps + sw_coeff_filename: Si.sw + logging: # - csv - tensorboard diff --git a/examples/config_files/diffusion/config_diffusion_mace_orion.yaml b/examples/config_files/diffusion/config_diffusion_mace_orion.yaml index a1ec43c0..75fd612d 100644 --- a/examples/config_files/diffusion/config_diffusion_mace_orion.yaml +++ b/examples/config_files/diffusion/config_diffusion_mace_orion.yaml @@ -10,6 +10,8 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: [Si] + # data data: batch_size: 512 @@ -20,9 +22,10 @@ data: spatial_dimension: 3 model: loss: - algorithm: mse + coordinates_algorithm: mse score_network: architecture: diffusion_mace + num_atom_types: 1 number_of_atoms: 8 r_max: 5.0 num_bessel: 'orion~choices([128, 256, 512])' @@ -79,6 +82,7 @@ diffusion_sampling: sigma_min: 0.001 # default value sigma_max: 0.5 # default value sampling: + num_atom_types: 1 spatial_dimension: 3 number_of_corrector_steps: 1 number_of_atoms: 8 diff --git a/examples/config_files/diffusion/config_diffusion_mlp.yaml b/examples/config_files/diffusion/config_diffusion_mlp.yaml index 3fc18e24..fee1a6cd 100644 --- a/examples/config_files/diffusion/config_diffusion_mlp.yaml +++ b/examples/config_files/diffusion/config_diffusion_mlp.yaml @@ -10,6 +10,8 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: ["Si"] + # data data: batch_size: 1024 @@ -20,12 +22,14 @@ data: spatial_dimension: 3 model: loss: - algorithm: mse + coordinates_algorithm: mse score_network: architecture: mlp + num_atom_types: 1 number_of_atoms: 8 n_hidden_dimensions: 2 - embedding_dimensions_size: 16 + noise_embedding_dimensions_size: 16 + atom_type_embedding_dimensions_size: 16 hidden_dimensions_size: 64 conditional_prob: 0.0 conditional_gamma: 2 @@ -43,6 +47,7 @@ diffusion_sampling: sigma_max: 0.1 sampling: algorithm: predictor_corrector + num_atom_types: 1 spatial_dimension: 3 number_of_atoms: 8 number_of_samples: 16 @@ -88,6 +93,11 @@ loss_monitoring: number_of_bins: 50 sample_every_n_epochs: 25 +oracle: + name: lammps + sw_coeff_filename: Si.sw + + logging: # - comet - tensorboard diff --git a/examples/config_files/diffusion/config_diffusion_mlp_orion.yaml b/examples/config_files/diffusion/config_diffusion_mlp_orion.yaml index c1b7d82e..29613838 100644 --- a/examples/config_files/diffusion/config_diffusion_mlp_orion.yaml +++ b/examples/config_files/diffusion/config_diffusion_mlp_orion.yaml @@ -10,6 +10,8 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: [Si] + # data data: batch_size: 1024 @@ -20,9 +22,10 @@ data: spatial_dimension: 3 model: loss: - algorithm: mse + coordinates_algorithm: mse score_network: architecture: mlp + num_atom_types: 1 number_of_atoms: 8 n_hidden_dimensions: 'orion~choices([1, 2, 3, 4])' hidden_dimensions_size: 'orion~choices([16, 32, 64])' @@ -67,7 +70,9 @@ diffusion_sampling: sigma_min: 0.001 # default value sigma_max: 0.5 # default value sampling: + algorithm: predictor_corrector spatial_dimension: 3 + num_atom_types: 1 number_of_corrector_steps: 1 number_of_atoms: 8 number_of_samples: 16 diff --git a/examples/config_files/diffusion/config_mace_equivariant_head.yaml b/examples/config_files/diffusion/config_mace_equivariant_head.yaml index 9d2bb7a1..10f71129 100644 --- a/examples/config_files/diffusion/config_mace_equivariant_head.yaml +++ b/examples/config_files/diffusion/config_mace_equivariant_head.yaml @@ -9,6 +9,8 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: [Si] + # data data: batch_size: 1024 @@ -19,10 +21,11 @@ data: spatial_dimension: 3 model: loss: - algorithm: mse + coordinates_algorithm: mse score_network: architecture: mace - number_of_atoms: 8 + num_atom_types: 1 + number_of_atoms: 8 r_max: 5.0 num_bessel: 8 num_polynomial_cutoff: 5 @@ -76,6 +79,7 @@ diffusion_sampling: sigma_max: 0.5 # default value sampling: algorithm: predictor_corrector + num_atom_types: 1 spatial_dimension: 3 number_of_corrector_steps: 1 number_of_atoms: 8 diff --git a/examples/config_files/diffusion/config_mace_mlp_head.yaml b/examples/config_files/diffusion/config_mace_mlp_head.yaml index c235edf9..7add8acb 100644 --- a/examples/config_files/diffusion/config_mace_mlp_head.yaml +++ b/examples/config_files/diffusion/config_mace_mlp_head.yaml @@ -9,6 +9,8 @@ accumulate_grad_batches: 1 # make this number of forward passes before doing a # results will not be reproducible) seed: 1234 +elements: [Si] + # data data: batch_size: 512 @@ -19,9 +21,10 @@ data: spatial_dimension: 3 model: loss: - algorithm: mse + coordinates_algorithm: mse score_network: architecture: mace + num_atom_types: 1 use_pretrained: None pretrained_weights_path: ./ number_of_atoms: 8 @@ -77,6 +80,7 @@ diffusion_sampling: sigma_max: 0.5 # default value sampling: algorithm: predictor_corrector + num_atom_types: 1 spatial_dimension: 3 number_of_corrector_steps: 1 number_of_atoms: 8 diff --git a/examples/drawing_samples/draw_samples.py b/examples/drawing_samples/draw_samples.py deleted file mode 100644 index ab0bf130..00000000 --- a/examples/drawing_samples/draw_samples.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Draw Samples. - -This script draws samples from a checkpoint. - -THIS SCRIPT IS AN EXAMPLE. IT SHOULD BE MODIFIED DEPENDING ON USER PREFERENCES. -""" -import logging -from pathlib import Path - -import numpy as np -import torch -from crystal_diffusion.generators.instantiate_generator import \ - instantiate_generator -from crystal_diffusion.generators.predictor_corrector_position_generator import \ - PredictorCorrectorSamplingParameters -from crystal_diffusion.models.position_diffusion_lightning_model import \ - PositionDiffusionLightningModel -from crystal_diffusion.oracle.energies import compute_oracle_energies -from crystal_diffusion.utils.logging_utils import setup_analysis_logger -from src.crystal_diffusion.samplers.variance_sampler import NoiseParameters -from src.crystal_diffusion.samples.sampling import create_batch_of_samples - -logger = logging.getLogger(__name__) -setup_analysis_logger() - -checkpoint_path = ("/network/scratch/r/rousseab/experiments/sept21_egnn_2x2x2/run4/" - "output/best_model/best_model-epoch=024-step=019550.ckpt") -samples_dir = Path( - "/network/scratch/r/rousseab/experiments/sept21_egnn_2x2x2/run4_samples/samples" -) -samples_dir.mkdir(exist_ok=True) - -device = torch.device("cuda") - - -spatial_dimension = 3 -number_of_atoms = 64 -atom_types = np.ones(number_of_atoms, dtype=int) - -acell = 10.86 -box = np.diag([acell, acell, acell]) - -number_of_samples = 128 -total_time_steps = 1000 -number_of_corrector_steps = 1 - -noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, - corrector_step_epsilon=2e-7, - sigma_min=0.0001, - sigma_max=0.2, -) - -sampling_parameters = PredictorCorrectorSamplingParameters( - number_of_corrector_steps=number_of_corrector_steps, - spatial_dimension=spatial_dimension, - number_of_atoms=number_of_atoms, - number_of_samples=number_of_samples, - cell_dimensions=[acell, acell, acell], - record_samples=True, -) - - -if __name__ == "__main__": - logger.info("Loading checkpoint...") - pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) - pl_model.eval() - - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - logger.info("Instantiate generator...") - position_generator = instantiate_generator( - sampling_parameters=sampling_parameters, - noise_parameters=noise_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, - ) - - logger.info("Drawing samples...") - with torch.no_grad(): - samples_batch = create_batch_of_samples( - generator=position_generator, - sampling_parameters=sampling_parameters, - device=device, - ) - - sample_output_path = str(samples_dir / "diffusion_samples.pt") - position_generator.sample_trajectory_recorder.write_to_pickle(sample_output_path) - logger.info("Done Generating Samples") - - logger.info("Compute energy from Oracle") - sample_energies = compute_oracle_energies(samples_batch) - - energy_output_path = str(samples_dir / "diffusion_energies.pt") - with open(energy_output_path, "wb") as fd: - torch.save(sample_energies, fd) diff --git a/examples/local/diffusion/run_diffusion.sh b/examples/local/diffusion/run_diffusion.sh index ceec8b5f..4f96731a 100755 --- a/examples/local/diffusion/run_diffusion.sh +++ b/examples/local/diffusion/run_diffusion.sh @@ -1,16 +1,16 @@ #!/bin/bash -# This example assumes that the dataset 'si_diffusion_small' is present locally in the DATA folder. -# It is also assumed that the user has a Comet account for logging experiments. +# This example assumes that the dataset 'Si_diffusion_1x1x1' is present locally in the DATA folder. CONFIG=../../config_files/diffusion/config_diffusion_mlp.yaml -DATA_DIR=../../../data/si_diffusion_1x1x1 +DATA_DIR=../../../data/Si_diffusion_1x1x1 PROCESSED_DATA=${DATA_DIR}/processed DATA_WORK_DIR=${DATA_DIR}/cache/ OUTPUT=output/run1 -python ../../../crystal_diffusion/train_diffusion.py \ +python ../../../src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py \ + --accelerator "cpu" \ --config $CONFIG \ --data $DATA_DIR \ --processed_datadir $PROCESSED_DATA \ diff --git a/experiments/analysis/analytic_score/analytical_score_sampling_and_plotting.py b/experiments/analysis/analytic_score/analytical_score_sampling_and_plotting.py index c28bd3b1..8f0db1a1 100644 --- a/experiments/analysis/analytic_score/analytical_score_sampling_and_plotting.py +++ b/experiments/analysis/analytic_score/analytical_score_sampling_and_plotting.py @@ -23,7 +23,9 @@ PredictorCorrectorSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + ExplodingVariance +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger @@ -68,6 +70,8 @@ total_time_steps=total_time_steps, sigma_min=0.001, sigma_max=0.5 ) + exploding_variance = ExplodingVariance(noise_parameters) + score_network_parameters = AnalyticalScoreNetworkParameters( number_of_atoms=number_of_atoms, spatial_dimension=spatial_dimension, @@ -127,7 +131,7 @@ # Plot the ODE parameters logger.info("Plotting ODE parameters") times = torch.linspace(0, 1, 1001) - sigmas = position_generator._get_exploding_variance_sigma(times) + sigmas = exploding_variance.get_sigma(times) ode_prefactor = position_generator._get_ode_prefactor(sigmas) fig0 = plt.figure(figsize=PLEASANT_FIG_SIZE) diff --git a/experiments/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py b/experiments/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py index 2a016f5f..21176e2f 100644 --- a/experiments/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py +++ b/experiments/analysis/analytic_score/exploring_langevin_generator/generate_sample_energies.py @@ -15,10 +15,10 @@ PredictorCorrectorSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetworkParameters, TargetScoreBasedAnalyticalScoreNetwork) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ get_energy_and_forces_from_lammps -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( get_positions_from_coordinates, map_relative_coordinates_to_unit_cell) from experiments.analysis.analytic_score.exploring_langevin_generator import \ diff --git a/experiments/analysis/analytic_score/perfect_score_loss_analysis.py b/experiments/analysis/analytic_score/perfect_score_loss_analysis.py index 085e7c3b..2c07f331 100644 --- a/experiments/analysis/analytic_score/perfect_score_loss_analysis.py +++ b/experiments/analysis/analytic_score/perfect_score_loss_analysis.py @@ -1,3 +1,7 @@ +"""Perfect Score Loss Analysis. + +TODO: this file has not been verified after a major refactor. The code below might be broken. +""" import logging import tempfile @@ -15,16 +19,16 @@ PLOT_STYLE_PATH from diffusion_for_multi_scale_molecular_dynamics.callbacks.loss_monitoring_callback import \ LossMonitoringCallback -from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import \ - PredictorCorrectorDiffusionSamplingCallback from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.loss import ( - MSELossParameters, create_loss_calculator) +from diffusion_for_multi_scale_molecular_dynamics.loss import \ + create_loss_calculator +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ + MSELossParameters +from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import ( + AXLDiffusionLightningModel, AXLDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import \ OptimizerParameters -from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import ( - PositionDiffusionLightningModel, PositionDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.scheduler import \ CosineAnnealingLRSchedulerParameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( @@ -32,12 +36,14 @@ TargetScoreBasedAnalyticalScoreNetwork) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( CARTESIAN_FORCES, RELATIVE_COORDINATES) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ + RelativeCoordinatesNoiser from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ get_energy_and_forces_from_lammps -from diffusion_for_multi_scale_molecular_dynamics.samplers.noisy_relative_coordinates_sampler import \ - NoisyRelativeCoordinatesSampler -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell from experiments.analysis.analytic_score.utils import (get_exact_samples, @@ -46,14 +52,14 @@ logger = logging.getLogger(__name__) -class AnalyticalScorePositionDiffusionLightningModel(PositionDiffusionLightningModel): +class AnalyticalScorePositionDiffusionLightningModel(AXLDiffusionLightningModel): """Analytical Score Position Diffusion Lightning Model. Overload the base class so that we can properly feed in an analytical score network. This should not be in the main code as the analytical score is not a real model. """ - def __init__(self, hyper_params: PositionDiffusionParameters): + def __init__(self, hyper_params: AXLDiffusionParameters): """Init method. This initializes the class. @@ -79,8 +85,8 @@ def __init__(self, hyper_params: PositionDiffusionParameters): ) self.loss_calculator = create_loss_calculator(hyper_params.loss_parameters) - self.noisy_relative_coordinates_sampler = NoisyRelativeCoordinatesSampler() - self.variance_sampler = ExplodingVarianceSampler(hyper_params.noise_parameters) + self.relative_coordinates_noiser = RelativeCoordinatesNoiser() + self.variance_sampler = NoiseScheduler(hyper_params.noise_parameters, num_classes=2) def on_validation_start(self) -> None: """On validation start.""" @@ -164,12 +170,6 @@ def on_validation_start(self) -> None: record_samples=False, ) - diffusion_sampling_callback = PredictorCorrectorDiffusionSamplingCallback( - noise_parameters=noise_parameters, - sampling_parameters=sampling_parameters, - output_directory=output_dir / experiment_name, - ) - if use_equilibrium: exact_samples = einops.repeat( equilibrium_relative_coordinates, "n d -> b n d", b=dataset_size @@ -216,7 +216,7 @@ def on_validation_start(self) -> None: variance_parameter=model_variance_parameter, ) - diffusion_params = PositionDiffusionParameters( + diffusion_params = AXLDiffusionParameters( score_network_parameters=score_network_parameters, loss_parameters=MSELossParameters(), optimizer_parameters=dummy_optimizer_parameters, @@ -228,7 +228,7 @@ def on_validation_start(self) -> None: model = AnalyticalScorePositionDiffusionLightningModel(diffusion_params) trainer = pl.Trainer( - callbacks=[loss_monitoring_callback, diffusion_sampling_callback], + callbacks=[loss_monitoring_callback], max_epochs=1, log_every_n_steps=1, fast_dev_run=False, diff --git a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py index 1ee22cc2..2ea37960 100644 --- a/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py +++ b/experiments/analysis/analytic_score/repaint/repaint_with_analytic_score.py @@ -11,15 +11,14 @@ ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger from diffusion_for_multi_scale_molecular_dynamics.utils.structure_utils import \ create_structure -from experiments.analysis.analytic_score import (get_samples_harmonic_energy, - get_silicon_supercell, - get_unit_cells) +from experiments.analysis.analytic_score.utils import ( + get_samples_harmonic_energy, get_silicon_supercell, get_unit_cells) logger = logging.getLogger(__name__) setup_analysis_logger() diff --git a/experiments/analysis/exploding_variance_analysis.py b/experiments/analysis/exploding_variance_analysis.py index 197df4d6..687d719a 100644 --- a/experiments/analysis/exploding_variance_analysis.py +++ b/experiments/analysis/exploding_variance_analysis.py @@ -10,8 +10,10 @@ from diffusion_for_multi_scale_molecular_dynamics import ANALYSIS_RESULTS_DIR from diffusion_for_multi_scale_molecular_dynamics.analysis import ( PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + ExplodingVarianceSampler from diffusion_for_multi_scale_molecular_dynamics.score.wrapped_gaussian_score import \ get_sigma_normalized_score diff --git a/experiments/analysis/plot_q_matrices.py b/experiments/analysis/plot_q_matrices.py new file mode 100644 index 00000000..253fa691 --- /dev/null +++ b/experiments/analysis/plot_q_matrices.py @@ -0,0 +1,52 @@ +from matplotlib import pyplot as plt + +from diffusion_for_multi_scale_molecular_dynamics.analysis import ( + PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ + setup_analysis_logger + +setup_analysis_logger() + +plt.style.use(PLOT_STYLE_PATH) + +num_classes = 3 + +if __name__ == '__main__': + + fig = plt.figure(figsize=PLEASANT_FIG_SIZE) + + fig.suptitle("Transition Probabilities") + ax1 = fig.add_subplot(131) + ax2 = fig.add_subplot(132) + ax3 = fig.add_subplot(133) + + for total_time_steps in [1000, 100, 10]: + noise_parameters = NoiseParameters(total_time_steps=total_time_steps) + sampler = NoiseScheduler(noise_parameters, num_classes=num_classes) + noise, _ = sampler.get_all_sampling_parameters() + times = noise.time + indices = noise.indices + q_matrices = noise.q_matrix + q_bar_matrices = noise.q_bar_matrix + + betas = q_matrices[:, 0, -1] + beta_bars = q_bar_matrices[:, 0, -1] + ratio = beta_bars[:-1] / beta_bars[1:] + ax1.plot(times, betas, label=f'T = {total_time_steps}') + ax2.plot(times, beta_bars, label=f'T = {total_time_steps}') + ax3.plot(times[1:], ratio, label=f'T = {total_time_steps}') + + ax1.set_ylabel(r'$\beta_t$') + ax2.set_ylabel(r'$\bar\beta_{t}$') + ax3.set_ylabel(r'$\frac{\bar\beta_{t-1}}{\bar\beta_{t}}$') + for ax in [ax1, ax2, ax3]: + ax.set_xlabel(r'$\frac{t}{T}$') + ax.legend(loc=0) + ax.set_xlim(times[-1] + 0.1, times[0] - 0.1) + + fig.tight_layout() + plt.show() diff --git a/experiments/atom_types_only_experiments/create_visualization.py b/experiments/atom_types_only_experiments/create_visualization.py new file mode 100644 index 00000000..23d8b98d --- /dev/null +++ b/experiments/atom_types_only_experiments/create_visualization.py @@ -0,0 +1,52 @@ +import numpy as np +import torch + +from diffusion_for_multi_scale_molecular_dynamics import ROOT_DIR +from diffusion_for_multi_scale_molecular_dynamics.analysis.sample_trajectory_analyser import \ + SampleTrajectoryAnalyser +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ + setup_analysis_logger +from diffusion_for_multi_scale_molecular_dynamics.utils.ovito_utils import \ + create_cif_files + +setup_analysis_logger() + +base_path = ROOT_DIR / "../experiments/atom_types_only_experiments/experiments" +data_path = base_path / "output/run1/trajectory_samples" +pickle_path = data_path / "trajectories_sample_epoch=999.pt" +visualization_artifacts_path = data_path / "trajectory_cif_files" + +elements = ["Si", "Ge"] +num_classes = len(elements) + 1 +element_types = ElementTypes(elements) + +trajectory_indices = np.arange(10) + + +if __name__ == "__main__": + + analyser = SampleTrajectoryAnalyser(pickle_path, num_classes) + time_indices, trajectory_axl = analyser.extract_axl("composition_i") + + reverse_order = np.argsort(time_indices)[::-1] + + # Torch can't deal with indices in reverse order + a = trajectory_axl.A + new_a = torch.from_numpy(a.numpy()[:, reverse_order]) + x = trajectory_axl.X + new_x = torch.from_numpy(x.numpy()[:, reverse_order]) + lattice = trajectory_axl.L + new_l = torch.from_numpy(lattice.numpy()[:, reverse_order]) + + reverse_time_order_trajectory_axl = AXL(A=new_a, X=new_x, L=new_l) + + for trajectory_index in trajectory_indices: + create_cif_files( + elements=elements, + visualization_artifacts_path=visualization_artifacts_path, + trajectory_index=trajectory_index, + trajectory_axl_compositions=reverse_time_order_trajectory_axl, + ) diff --git a/experiments/atom_types_only_experiments/experiments/config_mlp.yaml b/experiments/atom_types_only_experiments/experiments/config_mlp.yaml new file mode 100644 index 00000000..f71d921e --- /dev/null +++ b/experiments/atom_types_only_experiments/experiments/config_mlp.yaml @@ -0,0 +1,117 @@ +#================================================================================ +# Configuration file for a diffusion experiment where only atom-types change. +# =========================================================================== +# The data is inspired by SiGe 1x1x1. +# +# It is assumed that this config file will be used in a pseudo-experiment +# where the main code is patched so that only atom types will change. +# +#================================================================================ +exp_name: atom_types_only_PSEUDO +run_name: run1 +max_epoch: 1000 +log_every_n_steps: 1 +gradient_clipping: 0.0 +accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step + +elements: [Si, Ge] + +# set to null to avoid setting a seed (can speed up GPU computation, but +# results will not be reproducible) +seed: 1234 + +# Data: a fake dataloader will recreate the same example over and over. +data: + batch_size: 1024 # batch size for everyone + train_batch_size: 1024 # overloaded to mean 'size of training dataset' + valid_batch_size: 1024 # overloaded to mean 'size of validation dataset' + num_workers: 0 + max_atom: 8 + +# architecture +spatial_dimension: 3 + +model: + loss: + coordinates_algorithm: mse + atom_types_ce_weight: 1.0 + atom_types_lambda_weight: 1.0 + relative_coordinates_lambda_weight: 0.0 + lattice_lambda_weight: 0.0 + score_network: + architecture: mlp + num_atom_types: 2 + number_of_atoms: 8 + n_hidden_dimensions: 3 + noise_embedding_dimensions_size: 32 + time_embedding_dimensions_size: 32 + atom_type_embedding_dimensions_size: 8 + hidden_dimensions_size: 64 + conditional_prob: 0.0 + conditional_gamma: 2 + condition_embedding_size: 4 + noise: + total_time_steps: 100 + sigma_min: 0.0001 + sigma_max: 0.2 + +# optimizer and scheduler +optimizer: + name: adamw + learning_rate: 0.0001 + weight_decay: 5.0e-8 + + +scheduler: + name: CosineAnnealingLR + T_max: 1000 + eta_min: 0.0 + +# early stopping +early_stopping: + metric: validation_epoch_loss + mode: min + patience: 1000 + +model_checkpoint: + monitor: validation_epoch_loss + mode: min + + +# Sampling from the generative model +diffusion_sampling: + noise: + total_time_steps: 100 + sigma_min: 0.0001 + sigma_max: 0.2 + corrector_step_epsilon: 2.0e-7 + sampling: + algorithm: predictor_corrector + num_atom_types: 2 + number_of_atoms: 8 + sample_batchsize: 10 + spatial_dimension: 3 + number_of_corrector_steps: 0 + one_atom_type_transition_per_step: False + atom_type_greedy_sampling: False + atom_type_transition_in_corrector: False + number_of_samples: 10 + record_samples: True + cell_dimensions: [5.542, 5.542, 5.542] + metrics: + compute_energies: True + compute_structure_factor: False + +sampling_visualization: + record_every_n_epochs: 1 + first_record_epoch: 999 + record_trajectories: True + record_energies: False + record_structure: False + +oracle: + name: lammps + sw_coeff_filename: SiGe.sw + +logging: + - tensorboard diff --git a/experiments/atom_types_only_experiments/experiments/run_diffusion.sh b/experiments/atom_types_only_experiments/experiments/run_diffusion.sh new file mode 100755 index 00000000..0e59b5b7 --- /dev/null +++ b/experiments/atom_types_only_experiments/experiments/run_diffusion.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +export OMP_PATH="/opt/homebrew/opt/libomp/include/" +export PYTORCH_ENABLE_MPS_FALLBACK=1 + +# This example assumes that the dataset 'Si_diffusion_1x1x1' is present locally in the DATA folder. + + +CONFIG=config_mlp.yaml +DATA_DIR=./ +PROCESSED_DATA=${DATA_DIR} +DATA_WORK_DIR=${DATA_DIR} + +OUTPUT=./output/run1 + +python ../pseudo_train_diffusion.py \ + --accelerator "cpu" \ + --config $CONFIG \ + --data $DATA_DIR \ + --processed_datadir $PROCESSED_DATA \ + --dataset_working_dir $DATA_WORK_DIR \ + --output $OUTPUT # > log.txt 2>&1 diff --git a/experiments/atom_types_only_experiments/patches/equilibrium_structure.py b/experiments/atom_types_only_experiments/patches/equilibrium_structure.py new file mode 100644 index 00000000..a630641d --- /dev/null +++ b/experiments/atom_types_only_experiments/patches/equilibrium_structure.py @@ -0,0 +1,49 @@ +from pathlib import Path + +import numpy as np +from pymatgen.core import Lattice, Structure +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + + +def create_equilibrium_sige_structure(): + """Create the SiGe 1x1x1 equilibrium structure.""" + conventional_cell_a = 5.542 + primitive_cell_a = conventional_cell_a / np.sqrt(2.0) + lattice = Lattice.from_parameters( + a=primitive_cell_a, + b=primitive_cell_a, + c=primitive_cell_a, + alpha=60.0, + beta=60.0, + gamma=60.0, + ) + + species = ["Si", "Ge"] + coordinates = np.array([[0.0, 0.0, 0.0], [0.25, 0.25, 0.25]]) + + primitive_structure = Structure( + lattice=lattice, species=species, coords=coordinates, coords_are_cartesian=False + ) + conventional_structure = ( + SpacegroupAnalyzer(primitive_structure) + .get_symmetrized_structure() + .to_conventional() + ) + + # Shift the relative coordinates a bit for easier visualization + shift = np.array([0.375, 0.375, 0.375]) + new_coordinates = (conventional_structure.frac_coords + shift) % 1.0 + + structure = Structure( + lattice=conventional_structure.lattice, + species=conventional_structure.species, + coords=new_coordinates, + coords_are_cartesian=False, + ) + return structure + + +if __name__ == "__main__": + output_file_path = Path(__file__).parent / "equilibrium_sige.cif" + structure = create_equilibrium_sige_structure() + structure.to(output_file_path) diff --git a/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py b/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py new file mode 100644 index 00000000..007cea40 --- /dev/null +++ b/experiments/atom_types_only_experiments/patches/fixed_position_data_loader.py @@ -0,0 +1,117 @@ +import logging +from pathlib import Path +from typing import Optional + +import pytorch_lightning as pl +import torch +from equilibrium_structure import create_equilibrium_sige_structure +from torch_geometric.data import DataLoader + +from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import \ + LammpsLoaderParameters +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + ATOM_TYPES, CARTESIAN_FORCES, RELATIVE_COORDINATES) + +logger = logging.getLogger(__name__) + + +class FixedPositionDataModule(pl.LightningDataModule): + """Data module class that is meant to imitate LammpsForDiffusionDataModule.""" + + def __init__( + self, + lammps_run_dir: str, # dummy + processed_dataset_dir: str, + hyper_params: LammpsLoaderParameters, + working_cache_dir: Optional[str] = None, # dummy + ): + """Init method.""" + logger.debug("FixedPositionDataModule!") + super().__init__() + + assert hyper_params.batch_size, "batch_size must be specified" + assert hyper_params.train_batch_size, "train_batch_size must be specified" + assert hyper_params.valid_batch_size, "valid_batch_size must be specified" + + self.batch_size = hyper_params.batch_size + self.train_size = hyper_params.train_batch_size + self.valid_size = hyper_params.valid_batch_size + + self.num_workers = hyper_params.num_workers + self.max_atom = hyper_params.max_atom # number of atoms to pad tensors + + self.element_types = ElementTypes(hyper_params.elements) + + def setup(self, stage: Optional[str] = None): + """Setup method.""" + structure = create_equilibrium_sige_structure() + + relative_coordinates = torch.from_numpy(structure.frac_coords).to(torch.float) + + atom_types = torch.tensor( + [self.element_types.get_element_id(a.name) for a in structure.species] + ) + box = torch.tensor(structure.lattice.abc) + + row = { + "natom": len(atom_types), + "box": box, + RELATIVE_COORDINATES: relative_coordinates, + ATOM_TYPES: atom_types, + CARTESIAN_FORCES: torch.zeros_like(relative_coordinates), + "potential_energy": 0.0, + } + + self.train_dataset = [row for _ in range(self.train_size)] + self.valid_dataset = [row for _ in range(self.valid_size)] + + def train_dataloader(self) -> DataLoader: + """Create the training dataloader using the training data parser.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + """Create the validation dataloader using the validation data parser.""" + return DataLoader( + self.valid_dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + """Creates the testing dataloader using the testing data parser.""" + raise NotImplementedError("Test set is not defined at the moment.") + + def clean_up(self): + """Nothing to clean.""" + pass + + +if __name__ == "__main__": + + elements = ["Si", "Ge"] + processed_dataset_dir = Path("/experiments/atom_types_only_experiments") + + hyper_params = LammpsLoaderParameters( + batch_size=64, + train_batch_size=1024, + valid_batch_size=1024, + num_workers=8, + max_atom=8, + elements=elements, + ) + + data_module = FixedPositionDataModule( + lammps_run_dir="dummy", + processed_dataset_dir=processed_dataset_dir, + hyper_params=hyper_params, + ) + + data_module.setup() diff --git a/experiments/atom_types_only_experiments/patches/identity_noiser.py b/experiments/atom_types_only_experiments/patches/identity_noiser.py new file mode 100644 index 00000000..d2e4b01b --- /dev/null +++ b/experiments/atom_types_only_experiments/patches/identity_noiser.py @@ -0,0 +1,23 @@ +import logging + +import torch + +from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ + RelativeCoordinatesNoiser + +logger = logging.getLogger(__name__) + + +class IdentityNoiser(RelativeCoordinatesNoiser): + """Identity Noiser. + + This class can be used as a stand-in that returns the identity (ie, no noising). + """ + + @staticmethod + def get_noisy_relative_coordinates_sample( + real_relative_coordinates: torch.Tensor, sigmas: torch.Tensor + ) -> torch.Tensor: + """Get noisy relative coordinates sample.""" + logger.debug("Identity Noiser! Return input as output.") + return real_relative_coordinates diff --git a/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py b/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py new file mode 100644 index 00000000..cdd5e063 --- /dev/null +++ b/experiments/atom_types_only_experiments/patches/identity_relative_coordinates_langevin_generator.py @@ -0,0 +1,64 @@ +import logging + +import einops +import torch +from equilibrium_structure import create_equilibrium_sige_structure + +from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ + LangevinGenerator +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ + PredictorCorrectorSamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ + ScoreNetwork +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters + +logger = logging.getLogger(__name__) + + +class IdentityRelativeCoordinatesUpdateLangevinGenerator(LangevinGenerator): + """Identity Relative Coordinates Update Langevin Generator.""" + def __init__( + self, + noise_parameters: NoiseParameters, + sampling_parameters: PredictorCorrectorSamplingParameters, + axl_network: ScoreNetwork, + ): + """Init method.""" + super().__init__(noise_parameters, sampling_parameters, axl_network) + + structure = create_equilibrium_sige_structure() + self.fixed_relative_coordinates = torch.from_numpy(structure.frac_coords).to( + torch.float + ) + + def initialize( + self, number_of_samples: int, device: torch.device = torch.device("cpu") + ): + """Initialize method.""" + logger.debug("Initialize with fixed relative coordinates.") + init_composition = super().initialize(number_of_samples, device=device) + + fixed_x = einops.repeat( + self.fixed_relative_coordinates, + "natoms space -> nsamples natoms space", + nsamples=number_of_samples, + ).to(init_composition.X) + + fixed_init_composition = AXL( + A=init_composition.A, X=fixed_x, L=init_composition.L + ) + + return fixed_init_composition + + def _relative_coordinates_update( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + ) -> torch.Tensor: + """Relative coordinates update.""" + return relative_coordinates diff --git a/experiments/atom_types_only_experiments/plot_atom_type_probabilities.py b/experiments/atom_types_only_experiments/plot_atom_type_probabilities.py new file mode 100644 index 00000000..6692a737 --- /dev/null +++ b/experiments/atom_types_only_experiments/plot_atom_type_probabilities.py @@ -0,0 +1,165 @@ +import einops +from matplotlib import pyplot as plt +from tqdm import tqdm + +from diffusion_for_multi_scale_molecular_dynamics import ROOT_DIR +from diffusion_for_multi_scale_molecular_dynamics.analysis import \ + PLOT_STYLE_PATH +from diffusion_for_multi_scale_molecular_dynamics.analysis.sample_trajectory_analyser import \ + SampleTrajectoryAnalyser +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import ( + class_index_to_onehot, get_probability_at_previous_time_step) +from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ + setup_analysis_logger +from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ + broadcast_batch_matrix_tensor_to_all_dimensions + +setup_analysis_logger() + +plt.style.use(PLOT_STYLE_PATH) + +base_path = ROOT_DIR / "../experiments/atom_types_only_experiments/experiments" +data_path = base_path / "output/run1/trajectory_samples" +pickle_path = data_path / "trajectories_sample_epoch=999.pt" + +elements = ["Si", "Ge"] + +element_types = ElementTypes(elements) + +num_classes = len(elements) + 1 +if __name__ == "__main__": + + analyser = SampleTrajectoryAnalyser(pickle_path, num_classes=num_classes) + + time_indices, predictions_axl = analyser.extract_axl(axl_key="model_predictions_i") + _, composition_axl = analyser.extract_axl(axl_key="composition_i") + + nsamples, ntimes, natoms = composition_axl.A.shape + + batched_predictions = einops.rearrange( + predictions_axl.A, "samples time ... -> (samples time) ..." + ) + batched_at = einops.rearrange( + composition_axl.A, "samples time ... -> (samples time) ..." + ) + batched_at_onehot = class_index_to_onehot(batched_at, num_classes=num_classes) + + final_shape = (ntimes, nsamples, natoms) + + q_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=analyser.noise.q_matrix, final_shape=final_shape + ) + batched_q_matrices = einops.rearrange( + q_matrices, "times samples ... -> (samples times) ..." + ) + + q_bar_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=analyser.noise.q_bar_matrix, final_shape=final_shape + ) + batched_q_bar_matrices = einops.rearrange( + q_bar_matrices, "times samples ... -> (samples times) ..." + ) + + q_bar_tm1_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=analyser.noise.q_bar_tm1_matrix, final_shape=final_shape + ) + batched_q_bar_tm1_matrices = einops.rearrange( + q_bar_tm1_matrices, "times samples ... -> (samples times) ..." + ) + + batched_probabilities = get_probability_at_previous_time_step( + batched_predictions, + batched_at_onehot, + batched_q_matrices, + batched_q_bar_matrices, + batched_q_bar_tm1_matrices, + small_epsilon=1.0e-12, + probability_at_zeroth_timestep_are_logits=True, + ) + + probabilities = einops.rearrange( + batched_probabilities, + "(samples times) ... -> samples times ...", + samples=nsamples, + times=ntimes, + ) + + raw_probabilities = einops.rearrange( + batched_predictions.softmax(dim=-1), + "(samples times) ... -> samples times ...", + samples=nsamples, + times=ntimes, + ) + + output_dir = base_path / "images" + output_dir.mkdir(parents=True, exist_ok=True) + + masked_atom_type = num_classes - 1 + + list_colors = ["green", "blue", "red"] + list_elements = [] + list_element_idx = [] + + for element_id in element_types.element_ids: + element_types.get_element(element_id) + list_elements.append(element_types.get_element(element_id)) + list_element_idx.append(element_id) + + list_elements.append("MASK") + list_element_idx.append(masked_atom_type) + + for traj_idx in tqdm(range(10), "TRAJ"): + + fig = plt.figure(figsize=(14.4, 6.6)) + + fig.suptitle("Prediction Probability") + ax1 = fig.add_subplot(241) + ax2 = fig.add_subplot(242) + ax3 = fig.add_subplot(243) + ax4 = fig.add_subplot(244) + ax5 = fig.add_subplot(245) + ax6 = fig.add_subplot(246) + ax7 = fig.add_subplot(247) + ax8 = fig.add_subplot(248) + list_ax = [ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8] + + for atom_idx, ax in enumerate(list_ax): + ax.set_title(f"Atom {atom_idx}") + + mask = composition_axl.A[traj_idx, :, atom_idx] == masked_atom_type + unmask_time = time_indices[mask].min() + + ax.vlines(unmask_time, -0.1, 1.1, lw=2, color="k", label="Unmasking Time") + list_elements.append("MASK") + list_element_idx.append(masked_atom_type) + + for element_idx, element, color in zip( + list_element_idx, list_elements, list_colors + ): + p = probabilities[traj_idx, :, atom_idx, element_idx] + ax.semilogy(time_indices, p, c=color, label=f"{element}", alpha=0.5) + + for element_idx, element, color in zip( + list_element_idx[:-1], list_elements[:-1], list_colors[:-1] + ): + raw_p = raw_probabilities[traj_idx, :, atom_idx, element_idx] + ax.semilogy( + time_indices, + raw_p, + "--", + lw=2, + c=color, + label=f"RAW {element}", + alpha=0.25, + ) + + ax.set_xlabel("Time Index") + ax.set_ylabel("Probability") + ax.set_xlim(time_indices[-1], time_indices[0]) + + ax1.legend(loc=0) + fig.tight_layout() + fig.savefig(output_dir / f"traj_{traj_idx}.png") + plt.close(fig) diff --git a/experiments/atom_types_only_experiments/pseudo_train_diffusion.py b/experiments/atom_types_only_experiments/pseudo_train_diffusion.py new file mode 100644 index 00000000..f08531bd --- /dev/null +++ b/experiments/atom_types_only_experiments/pseudo_train_diffusion.py @@ -0,0 +1,35 @@ +import sys # noqa +from unittest.mock import patch # noqa + +from diffusion_for_multi_scale_molecular_dynamics import ROOT_DIR # noqa + +sys.path.append(str(ROOT_DIR / "../experiments/atom_types_only_experiments/patches")) + +from patches.fixed_position_data_loader import FixedPositionDataModule # noqa +from patches.identity_noiser import IdentityNoiser # noqa +from patches.identity_relative_coordinates_langevin_generator import \ + IdentityRelativeCoordinatesUpdateLangevinGenerator # noqa + +from diffusion_for_multi_scale_molecular_dynamics.train_diffusion import \ + main as train_diffusion_main # noqa + +if __name__ == "__main__": + # We must patch 'where the class is looked up', not where it is defined. + # See: https://docs.python.org/3/library/unittest.mock.html#where-to-patch + + # Patch the dataloader to always use the same atomic relative coordinates. + target1 = "diffusion_for_multi_scale_molecular_dynamics.train_diffusion.LammpsForDiffusionDataModule" + + # Patch the noiser to never change the relative coordinates" + target2 = ("diffusion_for_multi_scale_molecular_dynamics.models." + "axl_diffusion_lightning_model.RelativeCoordinatesNoiser") + + # Patch the generator to never change the relative coordinates" + target3 = "diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator.LangevinGenerator" + + with ( + patch(target=target1, new=FixedPositionDataModule), + patch(target=target2, new=IdentityNoiser), + patch(target=target3, new=IdentityRelativeCoordinatesUpdateLangevinGenerator), + ): + train_diffusion_main() diff --git a/experiments/dataset_analysis/energy_consistency_analysis.py b/experiments/dataset_analysis/energy_consistency_analysis.py index 35762c04..ec797b20 100644 --- a/experiments/dataset_analysis/energy_consistency_analysis.py +++ b/experiments/dataset_analysis/energy_consistency_analysis.py @@ -15,10 +15,10 @@ from tqdm import tqdm from diffusion_for_multi_scale_molecular_dynamics import DATA_DIR -from diffusion_for_multi_scale_molecular_dynamics.analysis import \ - PLOT_STYLE_PATH -from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import ( - LOGGER_FIGSIZE, SamplingVisualizationCallback) +from diffusion_for_multi_scale_molecular_dynamics.analysis import ( + PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) +from diffusion_for_multi_scale_molecular_dynamics.callbacks.sampling_visualization_callback import \ + SamplingVisualizationCallback from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ @@ -92,7 +92,7 @@ ) plt.show() - fig2 = plt.figure(figsize=LOGGER_FIGSIZE) + fig2 = plt.figure(figsize=PLEASANT_FIG_SIZE) ax2 = fig2.add_subplot(111) errors = list_oracle_energies - list_dataset_potential_energies diff --git a/experiments/diffusion_mace_harmonic_data/ad_hoc_experiments_with_various_score_networks.py b/experiments/diffusion_mace_harmonic_data/ad_hoc_experiments_with_various_score_networks.py index 1403ffa4..69df67ca 100644 --- a/experiments/diffusion_mace_harmonic_data/ad_hoc_experiments_with_various_score_networks.py +++ b/experiments/diffusion_mace_harmonic_data/ad_hoc_experiments_with_various_score_networks.py @@ -29,10 +29,10 @@ MaceEquivariantScorePredictionHeadParameters from diffusion_for_multi_scale_molecular_dynamics.namespace import ( CARTESIAN_FORCES, RELATIVE_COORDINATES) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from experiments.analysis.analytic_score import (get_exact_samples, - get_relative_harmonic_energy) +from experiments.analysis.analytic_score.utils import ( + get_exact_samples, get_relative_harmonic_energy) from experiments.diffusion_mace_harmonic_data.analysis_callbacks import \ HarmonicEnergyDiffusionSamplingCallback diff --git a/experiments/diffusion_mace_harmonic_data/analysis_callbacks.py b/experiments/diffusion_mace_harmonic_data/analysis_callbacks.py index a7576c66..9b007c14 100644 --- a/experiments/diffusion_mace_harmonic_data/analysis_callbacks.py +++ b/experiments/diffusion_mace_harmonic_data/analysis_callbacks.py @@ -16,9 +16,10 @@ SamplingVisualizationCallback from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import \ SamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from experiments.analysis.analytic_score import get_relative_harmonic_energy +from experiments.analysis.analytic_score.utils import \ + get_relative_harmonic_energy logger = logging.getLogger(__name__) diff --git a/experiments/diffusion_mace_harmonic_data/overfit_diffusion_mace.py b/experiments/diffusion_mace_harmonic_data/overfit_diffusion_mace.py index 03b2e519..de3a2ff1 100644 --- a/experiments/diffusion_mace_harmonic_data/overfit_diffusion_mace.py +++ b/experiments/diffusion_mace_harmonic_data/overfit_diffusion_mace.py @@ -20,16 +20,18 @@ DiffusionMACEScoreNetwork, DiffusionMACEScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.noisy_relative_coordinates_sampler import \ - NoisyRelativeCoordinatesSampler -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + ExplodingVarianceSampler +from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ + RelativeCoordinatesNoiser from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ broadcast_batch_tensor_to_all_dimensions -from experiments.analysis.analytic_score import (get_exact_samples, - get_unit_cells) +from experiments.analysis.analytic_score.utils import (get_exact_samples, + get_unit_cells) torch.set_default_dtype(torch.float64) @@ -114,7 +116,7 @@ def training_step(self, batch, batch_idx): max_epochs = 1000 acell = 5.5 -noisy_relative_coordinates_sampler = NoisyRelativeCoordinatesSampler() +relative_coordinates_noiser = RelativeCoordinatesNoiser() noise_parameters = NoiseParameters(total_time_steps=100, sigma_min=0.001, sigma_max=0.5) variance_sampler = ExplodingVarianceSampler(noise_parameters) @@ -150,7 +152,7 @@ def training_step(self, batch, batch_idx): ) sigmas = torch.ones_like(sigmas) - xt = noisy_relative_coordinates_sampler.get_noisy_relative_coordinates_sample( + xt = relative_coordinates_noiser.get_noisy_relative_coordinates_sample( x0, sigmas ) diff --git a/experiments/generators/sde_generator_sanity_check.py b/experiments/generators/sde_generator_sanity_check.py index ff140d37..2a71ce50 100644 --- a/experiments/generators/sde_generator_sanity_check.py +++ b/experiments/generators/sde_generator_sanity_check.py @@ -14,7 +14,7 @@ ExplodingVarianceSDEPositionGenerator, SDESamplingParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.analytical_score_network import ( AnalyticalScoreNetworkParameters, TargetScoreBasedAnalyticalScoreNetwork) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell diff --git a/experiments/sampling_sota_model/repaint_with_sota_score.py b/experiments/sampling_sota_model/repaint_with_sota_score.py index 935781b3..8e8d7dad 100644 --- a/experiments/sampling_sota_model/repaint_with_sota_score.py +++ b/experiments/sampling_sota_model/repaint_with_sota_score.py @@ -14,10 +14,10 @@ ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ load_diffusion_model +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ get_energy_and_forces_from_lammps -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger diff --git a/experiments/sampling_sota_model/sota_score_sampling_and_plotting.py b/experiments/sampling_sota_model/sota_score_sampling_and_plotting.py index be33daab..007ed48e 100644 --- a/experiments/sampling_sota_model/sota_score_sampling_and_plotting.py +++ b/experiments/sampling_sota_model/sota_score_sampling_and_plotting.py @@ -21,10 +21,12 @@ PredictorCorrectorSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ load_diffusion_model +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + ExplodingVariance +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ get_energy_and_forces_from_lammps -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger @@ -82,6 +84,7 @@ noise_parameters = NoiseParameters( total_time_steps=total_time_steps, sigma_min=0.001, sigma_max=0.5 ) + exploding_variance = ExplodingVariance(noise_parameters) if sampling_algorithm == "ode": ode_sampling_parameters = ODESamplingParameters( @@ -147,7 +150,7 @@ # Plot the ODE parameters logger.info("Plotting ODE parameters") times = torch.linspace(0, 1, 1001) - sigmas = position_generator._get_exploding_variance_sigma(times) + sigmas = exploding_variance.get_sigma(times) ode_prefactor = position_generator._get_ode_prefactor(sigmas) fig0 = plt.figure(figsize=PLEASANT_FIG_SIZE) diff --git a/experiments/score_stability_analysis/draw_samples_from_equilibrium.py b/experiments/score_stability_analysis/draw_samples_from_equilibrium.py index aa54a67a..8ec1f134 100644 --- a/experiments/score_stability_analysis/draw_samples_from_equilibrium.py +++ b/experiments/score_stability_analysis/draw_samples_from_equilibrium.py @@ -19,7 +19,7 @@ PositionDiffusionLightningModel from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger diff --git a/experiments/score_stability_analysis/plot_hessian_eigenvalues.py b/experiments/score_stability_analysis/plot_hessian_eigenvalues.py index 543a9deb..98c74a44 100644 --- a/experiments/score_stability_analysis/plot_hessian_eigenvalues.py +++ b/experiments/score_stability_analysis/plot_hessian_eigenvalues.py @@ -12,14 +12,15 @@ PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import \ PositionDiffusionLightningModel -from diffusion_for_multi_scale_molecular_dynamics.samplers.exploding_variance import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ ExplodingVariance -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ setup_analysis_logger -from experiments import get_normalized_score_function from experiments.analysis.analytic_score.utils import get_silicon_supercell +from experiments.score_stability_analysis.util import \ + get_normalized_score_function plt.style.use(PLOT_STYLE_PATH) diff --git a/experiments/score_stability_analysis/plot_score_norm.py b/experiments/score_stability_analysis/plot_score_norm.py deleted file mode 100644 index 706e6bbd..00000000 --- a/experiments/score_stability_analysis/plot_score_norm.py +++ /dev/null @@ -1,119 +0,0 @@ -import logging - -import einops -import matplotlib.pyplot as plt -import numpy as np -import torch -from tqdm import tqdm - -from diffusion_for_multi_scale_molecular_dynamics.analysis import ( - PLEASANT_FIG_SIZE, PLOT_STYLE_PATH) -from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import \ - PositionDiffusionLightningModel -from diffusion_for_multi_scale_molecular_dynamics.samplers.exploding_variance import \ - ExplodingVariance -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell -from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import \ - setup_analysis_logger -from experiments import create_fixed_time_normalized_score_function -from experiments.analysis.analytic_score.utils import get_silicon_supercell - -plt.style.use(PLOT_STYLE_PATH) - -logger = logging.getLogger(__name__) -setup_analysis_logger() - - -checkpoint_path = ( - "/home/mila/r/rousseab/scratch/experiments/oct2_egnn_1x1x1/run1/" - "output/last_model/last_model-epoch=049-step=039100.ckpt" -) - -spatial_dimension = 3 -number_of_atoms = 8 -atom_types = np.ones(number_of_atoms, dtype=int) - -acell = 5.43 -basis_vectors = torch.diag(torch.tensor([acell, acell, acell])) - -total_time_steps = 1000 -noise_parameters = NoiseParameters( - total_time_steps=total_time_steps, - sigma_min=0.0001, - sigma_max=0.2, -) - -device = torch.device("cuda") -if __name__ == "__main__": - variance_calculator = ExplodingVariance(noise_parameters) - - logger.info("Loading checkpoint...") - pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) - pl_model.eval() - - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - - for parameter in sigma_normalized_score_network.parameters(): - parameter.requires_grad_(False) - - equilibrium_relative_coordinates = torch.from_numpy( - get_silicon_supercell(supercell_factor=1) - ).to(torch.float32) - - direction = torch.zeros_like(equilibrium_relative_coordinates) - - # Move a single atom - # direction[0, 0] = 1.0 - # list_delta = torch.linspace(-0.5, 0.5, 101) - - # Put two particles on top of each other - dv = equilibrium_relative_coordinates[0] - equilibrium_relative_coordinates[1] - direction[0] = -0.5 * dv - direction[1] = 0.5 * dv - list_delta = torch.linspace(0.0, 2.0, 201) - - relative_coordinates = [] - for delta in list_delta: - relative_coordinates.append( - equilibrium_relative_coordinates + delta * direction - ) - relative_coordinates = map_relative_coordinates_to_unit_cell( - torch.stack(relative_coordinates) - ).to(device) - - list_t = torch.tensor([0.8, 0.7, 0.5, 0.3, 0.1, 0.01]) - list_sigmas = variance_calculator.get_sigma(list_t) - list_norms = [] - for t in tqdm(list_t, "norms"): - vector_field_fn = create_fixed_time_normalized_score_function( - sigma_normalized_score_network, - noise_parameters, - time=t, - basis_vectors=basis_vectors, - ) - - normalized_scores = vector_field_fn(relative_coordinates) - flat_normalized_scores = einops.rearrange( - normalized_scores, " b n s -> b (n s)" - ) - list_norms.append(flat_normalized_scores.norm(dim=-1).cpu()) - - fig = plt.figure(figsize=PLEASANT_FIG_SIZE) - fig.suptitle("Normalized Score Norm Along Specific Direction") - ax1 = fig.add_subplot(111) - ax1.set_xlabel(r"$\delta$") - ax1.set_ylabel(r"$|{\bf n}({\bf x}, t)|$") - - for t, sigma, norms in zip(list_t, list_sigmas, list_norms): - ax1.plot( - list_delta, norms, "-", label=f"t = {t: 3.2f}, $\\sigma$ = {sigma: 5.2e}" - ) - - ax1.legend(loc=0) - - fig.tight_layout() - - plt.show() diff --git a/experiments/score_stability_analysis/util.py b/experiments/score_stability_analysis/util.py index 373cc761..73871053 100644 --- a/experiments/score_stability_analysis/util.py +++ b/experiments/score_stability_analysis/util.py @@ -1,4 +1,3 @@ -import itertools from typing import Callable import einops @@ -8,10 +7,10 @@ ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import ( CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.exploding_variance import \ - ExplodingVariance -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler def get_normalized_score_function( @@ -20,7 +19,7 @@ def get_normalized_score_function( basis_vectors: torch.Tensor, ) -> Callable: """Get normalizd score function.""" - variance_calculator = ExplodingVariance(noise_parameters) + variance_calculator = NoiseScheduler(noise_parameters) def normalized_score_function( relative_coordinates: torch.Tensor, times: torch.Tensor @@ -48,20 +47,3 @@ def normalized_score_function( return sigma_normalized_scores return normalized_score_function - - -def get_cubic_point_group_symmetries(): - """Get cubic point group symmetries.""" - permutations = [ - torch.diag(torch.ones(3))[[idx]] for idx in itertools.permutations([0, 1, 2]) - ] - sign_changes = [ - torch.diag(torch.tensor(diag)) - for diag in itertools.product([-1.0, 1.0], repeat=3) - ] - symmetries = [] - for permutation in permutations: - for sign_change in sign_changes: - symmetries.append(permutation @ sign_change) - - return symmetries diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/generator_sample_analysis_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/generator_sample_analysis_utils.py index 14c98513..d5da9beb 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/generator_sample_analysis_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/generator_sample_analysis_utils.py @@ -7,7 +7,7 @@ get_adj_matrix from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters @@ -64,7 +64,7 @@ def get_interatomic_distances( Returns: distances : all distances up to cutoff. """ - shifted_adjacency_matrix, shifts, batch_indices = get_adj_matrix( + shifted_adjacency_matrix, shifts, _, _ = get_adj_matrix( positions=cartesian_positions, basis_vectors=basis_vectors, radial_cutoff=radial_cutoff, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py new file mode 100644 index 00000000..039d960c --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/analysis/sample_trajectory_analyser.py @@ -0,0 +1,85 @@ +import logging +from collections import defaultdict +from pathlib import Path +from typing import Tuple + +import einops +import numpy as np +import torch + +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler + +logger = logging.getLogger(__name__) + + +class SampleTrajectoryAnalyser: + """Sample Trajectory Analyser. + + This class reads in a trajectory recording pickle and processes the data to make it easy to analyse. + """ + def __init__(self, pickle_path: Path, num_classes: int): + """Init method. + + Args: + pickle_path: path to recording pickle. + num_classes: number of classes (including the MASK class). + """ + logger.info("Reading data from pickle file.") + data = torch.load(pickle_path, map_location=torch.device("cpu")) + logger.info("Done reading data.") + + noise_parameters = NoiseParameters(**data['noise_parameters']) + sampler = NoiseScheduler(noise_parameters, num_classes=num_classes) + self.noise, _ = sampler.get_all_sampling_parameters() + + self.time_index_key = 'time_step_index' + self.axl_keys = ['composition_i', 'composition_im1', 'model_predictions_i'] + + self._predictor_data = data["predictor_step"] + + del data + + def extract_axl(self, axl_key: str) -> Tuple[np.ndarray, AXL]: + """Extract AXL. + + Args: + axl_key: name of field to be extracted + + Returns: + time_indices: an array containing the time indices of the AXL. + axl: the axl described in the axl_key, where the fields have dimension [nsample, ntimes, ...] + """ + # The recording might have taken place over multiple batches. Combine corresponding compositions. + assert axl_key in self.axl_keys, f"Unknown axl key '{axl_key}'" + multiple_batch = defaultdict(list) + + logger.info("Iterating over entries") + list_time_indices = [] + for entry in self._predictor_data: + time_index = entry["time_step_index"] + list_time_indices.append(time_index) + axl = entry[axl_key] + multiple_batch[time_index].append(axl) + + time_indices = np.sort(np.unique(np.array(list_time_indices))) + + logger.info("Stacking multiple batch over time") + list_stacked_axl = [] + for time_index in time_indices: + list_axl = multiple_batch[time_index] + stacked_axl = AXL( + A=torch.vstack([axl.A for axl in list_axl]), + X=torch.vstack([axl.X for axl in list_axl]), + L=torch.vstack([axl.L for axl in list_axl]), + ) + list_stacked_axl.append(stacked_axl) + + logger.info("Rearrange dimensions") + a = einops.rearrange([axl.A for axl in list_stacked_axl], "time batch ... -> batch time ...") + x = einops.rearrange([axl.X for axl in list_stacked_axl], "time batch ... -> batch time ...") + lattice = einops.rearrange([axl.L for axl in list_stacked_axl], "time batch ... -> batch time ...") + return time_indices, AXL(A=a, X=x, L=lattice) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/loss_monitoring_callback.py b/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/loss_monitoring_callback.py index 897b2fa2..b58228f8 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/loss_monitoring_callback.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/callbacks/loss_monitoring_callback.py @@ -67,8 +67,10 @@ def on_validation_batch_end( # Compute the square errors per atoms batched_squared_errors = ( ( - outputs["predicted_normalized_scores"] - - outputs["target_normalized_conditional_scores"] + outputs[ + "unreduced_loss" + ].X # prediction normalized scores for coordinates + - outputs["target_coordinates_normalized_conditional_scores"] ) ** 2 ).sum(dim=-1) @@ -76,7 +78,7 @@ def on_validation_batch_end( # Average over space dimensions, where the sigmas are the same. self.all_weighted_losses.append( - outputs["unreduced_loss"].mean(dim=-1).flatten() + outputs["unreduced_loss"].X.mean(dim=-1).flatten() ) def on_validation_epoch_end(self, trainer, pl_module): diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_loader.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_loader.py index df6784c4..5fbc9d17 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_loader.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_loader.py @@ -14,8 +14,10 @@ from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_preprocess import \ LammpsProcessorForDiffusion +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( + NULL_ELEMENT, ElementTypes) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES) + ATOM_TYPES, CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES) logger = logging.getLogger(__name__) @@ -31,6 +33,7 @@ class LammpsLoaderParameters: num_workers: int = 0 max_atom: int = 64 spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. + elements: list[str] # the elements that can exist. class LammpsForDiffusionDataModule(pl.LightningDataModule): @@ -63,6 +66,8 @@ def __init__( self.max_atom = hyper_params.max_atom # number of atoms to pad tensors self.spatial_dim = hyper_params.spatial_dimension + self.element_types = ElementTypes(hyper_params.elements) + if hyper_params.batch_size is None: assert ( hyper_params.valid_batch_size is not None @@ -86,7 +91,9 @@ def __init__( @staticmethod def dataset_transform( - x: Dict[typing.AnyStr, typing.Any], spatial_dim: int = 3 + x: Dict[typing.AnyStr, typing.Any], + element_types: ElementTypes, + spatial_dim: int = 3, ) -> Dict[str, torch.Tensor]: """Format the tensors for the Datasets library. @@ -96,6 +103,7 @@ def dataset_transform( Args: x: raw columns from the processed data files. Should contain natom, box, type, position and relative_positions. + element_types: object that knows the relationship between elements and their integer ids. spatial_dim (optional): number of spatial dimensions. Defaults to 3. Returns: @@ -111,9 +119,14 @@ def dataset_transform( ) # size: (batchsize, spatial dimension) for pos in [CARTESIAN_POSITIONS, RELATIVE_COORDINATES, CARTESIAN_FORCES]: transformed_x[pos] = torch.as_tensor(x[pos]).view(bsize, -1, spatial_dim) - transformed_x["type"] = torch.as_tensor( - x["type"] + + element_ids = [] + for row in x["element"]: + element_ids.append(list(map(element_types.get_element_id, row))) + transformed_x[ATOM_TYPES] = torch.as_tensor( + element_ids ).long() # size: (batchsize, max atom) + transformed_x["potential_energy"] = torch.as_tensor( x["potential_energy"] ) # size: (batchsize, ) @@ -139,9 +152,12 @@ def pad_samples( raise ValueError( f"Hyper-parameter max_atom is smaller than an example in the dataset with {natom} atoms." ) - x["type"] = F.pad( - torch.as_tensor(x["type"]).long(), (0, max_atom - natom), "constant", -1 - ) + + padded_elements = max_atom * [NULL_ELEMENT] + for idx, element in enumerate(x["element"]): + padded_elements[idx] = element + x["element"] = padded_elements + for pos in [CARTESIAN_POSITIONS, RELATIVE_COORDINATES, CARTESIAN_FORCES]: x[pos] = F.pad( torch.as_tensor(x[pos]).float(), @@ -180,6 +196,7 @@ def setup(self, stage: Optional[str] = None): ) # map() are applied once, not in-place. # The keyword argument "batched" can accelerate by working with batches, not useful for padding + self.train_dataset = self.train_dataset.map( partial( self.pad_samples, max_atom=self.max_atom, spatial_dim=self.spatial_dim @@ -195,10 +212,18 @@ def setup(self, stage: Optional[str] = None): # set_transform is applied on-the-fly and is less costly upfront. Works with batches, so we can't use it for # padding self.train_dataset.set_transform( - partial(self.dataset_transform, spatial_dim=self.spatial_dim) + partial( + self.dataset_transform, + element_types=self.element_types, + spatial_dim=self.spatial_dim, + ) ) self.valid_dataset.set_transform( - partial(self.dataset_transform, spatial_dim=self.spatial_dim) + partial( + self.dataset_transform, + element_types=self.element_types, + spatial_dim=self.spatial_dim, + ) ) def train_dataloader(self) -> DataLoader: diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_preprocess.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_preprocess.py index 6db8195a..a80f64e1 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_preprocess.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/diffusion/data_preprocess.py @@ -183,15 +183,18 @@ def parse_lammps_run(self, run_dir: str) -> Optional[pd.DataFrame]: warnings.warn("Skipping this run.", UserWarning) return None - # the dataframe contains the following columns: id (list of atom indices), type (list of int representing - # atom type, x (list of x cartesian coordinates for each atom), y, z, fx (list forces in direction x for each - # atom), potential_energy (1 float). + # the dataframe contains the following columns: + # - id : list of atom indices + # - element : list of strings representing atom element + # - x, y, z : lists of cartesian coordinates for each atom + # - fx, fy, fz : lists force components for each atom + # - potential_energy : 1 float. # Each row is a different MD step / usable example for diffusion model # TODO consider filtering out samples with large forces and MD steps that are too similar # TODO large force and similar are to be defined - df = df[["type", "x", "y", "z", "box", "potential_energy", "fx", "fy", "fz"]] + df = df[["element", "x", "y", "z", "box", "potential_energy", "fx", "fy", "fz"]] df = self.get_x_relative(df) # add relative coordinates - df["natom"] = df["type"].apply( + df["natom"] = df["element"].apply( lambda x: len(x) ) # count number of atoms in a structure @@ -201,11 +204,12 @@ def parse_lammps_run(self, run_dir: str) -> Optional[pd.DataFrame]: df[CARTESIAN_FORCES] = df.apply( partial(self._flatten_positions_in_row, keys=["fx", "fy", "fz"]), axis=1 ) + return df[ [ "natom", "box", - "type", + "element", "potential_energy", CARTESIAN_POSITIONS, RELATIVE_COORDINATES, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/element_types.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/element_types.py new file mode 100644 index 00000000..0a845739 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/element_types.py @@ -0,0 +1,75 @@ +from typing import Dict, List + +NULL_ELEMENT = "NULL_ELEMENT_FOR_PADDING" +NULL_ELEMENT_ID = -1 + + +class ElementTypes: + """Element Types. + + This class manages the relationship between strings that identify elements (Si, Ge, Li, etc...) + and their integer indices. + """ + + def __init__(self, elements: List[str]): + """Init method. + + Args: + elements: list all the elements that could be present in the data. + """ + self.validate_elements(elements) + self._elements = sorted(elements) + self._ids = list(range(len(self._elements))) + + self._element_to_id_map: Dict[str, int] = { + k: v for k, v in zip(self._elements, self._ids) + } + self._id_to_element_map: Dict[int, str] = { + k: v for k, v in zip(self._ids, self._elements) + } + + self._element_to_id_map[NULL_ELEMENT] = NULL_ELEMENT_ID + self._id_to_element_map[NULL_ELEMENT_ID] = NULL_ELEMENT + + @staticmethod + def validate_elements(elements: List[str]): + """Validate elements.""" + assert NULL_ELEMENT not in elements, f"The element '{NULL_ELEMENT}' is reserved and should not be used." + assert len(set(elements)) == len(elements), "Each entry in the elements list should be unique." + + @property + def number_of_atom_types(self) -> int: + """Number of atom types.""" + return len(self._elements) + + @property + def elements(self) -> List[str]: + """The sorted elements.""" + return self._elements + + @property + def element_ids(self) -> List[int]: + """The sorted elements.""" + return self._ids + + def get_element(self, element_id: int) -> str: + """Get element. + + Args: + element_id : integer index. + + Returns: + element: string representing the element + """ + return self._id_to_element_map[element_id] + + def get_element_id(self, element: str) -> int: + """Get element id. + + Args: + element: string representing the element + + Returns: + element_id : integer index. + """ + return self._element_to_id_map[element] diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/data/parse_lammps_outputs.py b/src/diffusion_for_multi_scale_molecular_dynamics/data/parse_lammps_outputs.py index a7f44196..9d1405d4 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/data/parse_lammps_outputs.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/data/parse_lammps_outputs.py @@ -62,8 +62,8 @@ def parse_lammps_dump(lammps_dump: str) -> Dict[str, Any]: Returns: data: a dictionary with all the relevant data. """ - expected_keywords = ["id", "type", "x", "y", "z", "fx", "fy", "fz"] - datatypes = 2 * [np.int64] + 6 * [np.float64] + expected_keywords = ["id", "element", "x", "y", "z", "fx", "fy", "fz"] + datatypes = [np.int64] + [str] + 6 * [np.float64] pd_data = defaultdict(list) with open(lammps_dump, "r") as stream: diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/position_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py similarity index 67% rename from src/diffusion_for_multi_scale_molecular_dynamics/generators/position_generator.py rename to src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py index b102227a..1c4c1989 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/position_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/axl_generator.py @@ -4,6 +4,8 @@ import torch +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL + @dataclass(kw_only=True) class SamplingParameters: @@ -11,6 +13,7 @@ class SamplingParameters: algorithm: str spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. + num_atom_types: int # number of atom types excluding MASK number_of_atoms: ( int # the number of atoms that must be generated in a sampled configuration. ) @@ -20,22 +23,24 @@ class SamplingParameters: sample_batchsize: Optional[int] = None cell_dimensions: List[ float - ] # unit cell dimensions; the unit cell is assumed to be an orthogonal box. + ] # unit cell dimensions; the unit cell is assumed to be an orthogonal box. TODO replace with AXL-L record_samples: bool = ( False # should the predictor and corrector steps be recorded to a file ) + record_samples_corrector_steps: bool = False + record_atom_type_update: bool = False # record the information pertaining to generating atom types. -class PositionGenerator(ABC): - """This defines the interface for position generators.""" +class AXLGenerator(ABC): + """This defines the interface for AXL (atom types, reduced coordinates and lattice) generators.""" @abstractmethod def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor - ) -> torch.Tensor: + ) -> AXL: """Sample. - This method draws a position sample. + This method draws a configuration sample. Args: number_of_samples : number of samples to draw. @@ -44,11 +49,11 @@ def sample( Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] Returns: - samples: relative coordinates samples. + AXL samples: samples as AXL namedtuple with atom types, reduced coordinates and lattice vectors. """ pass @abstractmethod - def initialize(self, number_of_samples: int): + def initialize(self, number_of_samples: int, device: torch.device) -> AXL: """This method must initialize the samples from the fully noised distribution.""" pass diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py index 31ad891a..99db71ce 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/constrained_langevin_generator.py @@ -6,16 +6,15 @@ from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ PredictorCorrectorSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.samplers.noisy_relative_coordinates_sampler import \ - NoisyRelativeCoordinatesSampler -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell +from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ + RelativeCoordinatesNoiser @dataclass(kw_only=True) @@ -40,16 +39,14 @@ def __init__( self, noise_parameters: NoiseParameters, sampling_parameters: ConstrainedLangevinGeneratorParameters, - sigma_normalized_score_network: ScoreNetwork, + axl_network: ScoreNetwork, ): """Init method.""" - super().__init__( - noise_parameters, sampling_parameters, sigma_normalized_score_network - ) + super().__init__(noise_parameters, sampling_parameters, axl_network) self.constraint_relative_coordinates = torch.from_numpy( sampling_parameters.constrained_relative_coordinates - ) + ) # TODO constraint the atom type as well assert ( len(self.constraint_relative_coordinates.shape) == 2 @@ -70,15 +67,22 @@ def __init__( self.constraint_mask = torch.zeros(self.number_of_atoms, dtype=bool) self.constraint_mask[:number_of_constraints] = True - self.noisy_relative_coordinates_sampler = NoisyRelativeCoordinatesSampler() + self.relative_coordinates_noiser = RelativeCoordinatesNoiser() - def _apply_constraint(self, x: torch.Tensor, device: torch.device) -> None: - """This method applies the coordinate constraint in place on the input configuration.""" + def _apply_constraint(self, composition: AXL, device: torch.device) -> AXL: + """This method applies the coordinate constraint on the input configuration.""" + x = composition.X x[:, self.constraint_mask] = self.constraint_relative_coordinates.to(device) + updated_axl = AXL( + A=composition.A, + X=x, + L=composition.L, + ) + return updated_axl def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor - ) -> torch.Tensor: + ) -> AXL: """Sample. This method draws samples, imposing the satisfaction of positional constraints. @@ -90,7 +94,7 @@ def sample( Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] Returns: - samples: relative coordinates samples. + samples: composition samples as AXL namedtuple (atom types, reduced coordinates, lattice vectors) """ assert unit_cell.size() == ( number_of_samples, @@ -103,42 +107,43 @@ def sample( # Initialize a configuration that satisfy the constraint, but is otherwise random. # Since the noising process is 'atom-per-atom', the non-constrained position should have no impact. - x0_known = map_relative_coordinates_to_unit_cell( - self.initialize(number_of_samples) - ).to(device) - self._apply_constraint(x0_known, device) + composition0_known = self.initialize(number_of_samples, device) + # this is an AXL objet - x_ip1 = map_relative_coordinates_to_unit_cell( - self.initialize(number_of_samples) - ).to(device) - forces = torch.zeros_like(x_ip1) + composition0_known = self._apply_constraint(composition0_known, device) - broadcasting = torch.ones( + composition_ip1 = self.initialize(number_of_samples, device) + forces = torch.zeros_like(composition_ip1.X) + + coordinates_broadcasting = torch.ones( number_of_samples, self.number_of_atoms, self.spatial_dimension ).to(device) for i in tqdm(range(self.number_of_discretization_steps - 1, -1, -1)): sigma_i = self.noise.sigma[i] - broadcast_sigmas_i = sigma_i * broadcasting + broadcast_sigmas_i = sigma_i * coordinates_broadcasting # Noise an example satisfying the constraints from t_0 to t_i - x_i_known = self.noisy_relative_coordinates_sampler.get_noisy_relative_coordinates_sample( - x0_known, broadcast_sigmas_i + x_i_known = ( + self.relative_coordinates_noiser.get_noisy_relative_coordinates_sample( + composition0_known.X, broadcast_sigmas_i + ) ) # Denoise from t_{i+1} to t_i - x_i = map_relative_coordinates_to_unit_cell( - self.predictor_step(x_ip1, i + 1, unit_cell, forces) + composition_i = self.predictor_step( + composition_ip1, i + 1, unit_cell, forces ) # Combine the known and unknown + x_i = composition_i.X x_i[:, self.constraint_mask] = x_i_known[:, self.constraint_mask] + composition_i = AXL(A=composition_i.A, X=x_i, L=composition_i.L) for _ in range(self.number_of_corrector_steps): - x_i = map_relative_coordinates_to_unit_cell( - self.corrector_step(x_i, i, unit_cell, forces) - ) - x_ip1 = x_i + composition_i = self.corrector_step(composition_i, i, unit_cell, forces) + + composition_ip1 = composition_i # apply the constraint one last time - self._apply_constraint(x_i, device) + composition_i = self._apply_constraint(composition_i, device) - return x_i + return composition_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py index dfdaf083..af897328 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/instantiate_generator.py @@ -1,21 +1,21 @@ +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ + SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.ode_position_generator import \ - ExplodingVarianceODEPositionGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import \ - SamplingParameters + ExplodingVarianceODEAXLGenerator from diffusion_for_multi_scale_molecular_dynamics.generators.sde_position_generator import \ ExplodingVarianceSDEPositionGenerator from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters def instantiate_generator( sampling_parameters: SamplingParameters, noise_parameters: NoiseParameters, - sigma_normalized_score_network: ScoreNetwork, + axl_network: ScoreNetwork, ): """Instantiate generator.""" assert sampling_parameters.algorithm in [ @@ -29,19 +29,19 @@ def instantiate_generator( generator = LangevinGenerator( sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) case "ode": - generator = ExplodingVarianceODEPositionGenerator( + generator = ExplodingVarianceODEAXLGenerator( sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) case "sde": generator = ExplodingVarianceSDEPositionGenerator( sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) case _: raise NotImplementedError( diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py index e68dd2e2..2d68dc77 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/langevin_generator.py @@ -1,18 +1,28 @@ +import dataclasses +from typing import Tuple + +import einops import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import ( - PredictorCorrectorPositionGenerator, PredictorCorrectorSamplingParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import ( + PredictorCorrectorAXLGenerator, PredictorCorrectorSamplingParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) -from diffusion_for_multi_scale_molecular_dynamics.utils.sample_trajectory import ( - NoOpPredictorCorrectorSampleTrajectory, PredictorCorrectorSampleTrajectory) + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + map_relative_coordinates_to_unit_cell +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import ( + class_index_to_onehot, get_probability_at_previous_time_step) +from diffusion_for_multi_scale_molecular_dynamics.utils.sample_trajectory import \ + SampleTrajectory -class LangevinGenerator(PredictorCorrectorPositionGenerator): +class LangevinGenerator(PredictorCorrectorAXLGenerator): """Annealed Langevin Dynamics Generator. This class implements the annealed Langevin Dynamics generation of position samples, following @@ -24,154 +34,508 @@ def __init__( self, noise_parameters: NoiseParameters, sampling_parameters: PredictorCorrectorSamplingParameters, - sigma_normalized_score_network: ScoreNetwork, + axl_network: ScoreNetwork, ): """Init method.""" super().__init__( number_of_discretization_steps=noise_parameters.total_time_steps, number_of_corrector_steps=sampling_parameters.number_of_corrector_steps, spatial_dimension=sampling_parameters.spatial_dimension, + num_atom_types=sampling_parameters.num_atom_types, ) - self.noise_parameters = noise_parameters - sampler = ExplodingVarianceSampler(noise_parameters) + sampler = NoiseScheduler(noise_parameters, num_classes=self.num_classes) self.noise, self.langevin_dynamics = sampler.get_all_sampling_parameters() self.number_of_atoms = sampling_parameters.number_of_atoms - self.sigma_normalized_score_network = sigma_normalized_score_network + self.masked_atom_type_index = self.num_classes - 1 + self.axl_network = axl_network + self.small_epsilon = sampling_parameters.small_epsilon - if sampling_parameters.record_samples: - self.sample_trajectory_recorder = PredictorCorrectorSampleTrajectory() - else: - self.sample_trajectory_recorder = NoOpPredictorCorrectorSampleTrajectory() + self.one_atom_type_transition_per_step = ( + sampling_parameters.one_atom_type_transition_per_step + ) + self.atom_type_greedy_sampling = sampling_parameters.atom_type_greedy_sampling + self.atom_type_transition_in_corrector = ( + sampling_parameters.atom_type_transition_in_corrector + ) + + self.record = sampling_parameters.record_samples + self.record_corrector = sampling_parameters.record_samples_corrector_steps + self.record_atom_type_update = sampling_parameters.record_atom_type_update + + if self.record_corrector or self.record_atom_type_update: + assert ( + self.record + ), "Corrector steps or atom_type_update can only be recorded if record_samples is True." - def initialize(self, number_of_samples: int): + if self.record: + self.sample_trajectory_recorder = SampleTrajectory() + self.sample_trajectory_recorder.record(key="noise", entry=self.noise) + self.sample_trajectory_recorder.record( + key="noise_parameters", entry=dataclasses.asdict(noise_parameters) + ) + self.sample_trajectory_recorder.record( + key="sampling_parameters", entry=dataclasses.asdict(sampling_parameters) + ) + + def initialize( + self, number_of_samples: int, device: torch.device = torch.device("cpu") + ): """This method must initialize the samples from the fully noised distribution.""" + # all atoms are initialized as masked + atom_types = ( + torch.ones(number_of_samples, self.number_of_atoms).long().to(device) + * self.masked_atom_type_index + ) + # relative coordinates are sampled from the uniform distribution relative_coordinates = torch.rand( number_of_samples, self.number_of_atoms, self.spatial_dimension - ) - return relative_coordinates + ).to(device) + lattice_vectors = torch.zeros_like(relative_coordinates).to( + device + ) # TODO placeholder + init_composition = AXL(A=atom_types, X=relative_coordinates, L=lattice_vectors) + return init_composition def _draw_gaussian_sample(self, number_of_samples): return torch.randn( number_of_samples, self.number_of_atoms, self.spatial_dimension ) - def _get_sigma_normalized_scores( + def _draw_gumbel_sample(self, number_of_samples): + return -torch.log( + -torch.log( + torch.rand( + number_of_samples, self.number_of_atoms, self.num_classes + ).clip(min=self.small_epsilon) + ) + ) + + def _draw_binary_sample(self, number_of_samples): + # this is used to determine if a MASK sample should be demasked or not in greedy sampling + return torch.rand(number_of_samples, self.number_of_atoms) + + def _get_model_predictions( self, - x: torch.Tensor, + composition: AXL, time: float, - noise: float, - unit_cell: torch.Tensor, + sigma_noise: float, + unit_cell: torch.Tensor, # TODO replace with AXL-L cartesian_forces: torch.Tensor, - ) -> torch.Tensor: - """Get sigma normalized scores. + ) -> AXL: + """Get the outputs of an axl-network. Args: - x : relative coordinates, of shape [number_of_samples, number_of_atoms, spatial_dimension] + composition : AXL composition with: + atom types, of shape [number of samples, number_of_atoms] + relative coordinates, of shape [number_of_samples, number_of_atoms, spatial_dimension] + lattice vectors, of shape [number_of_samples, spatial_dimension * (spatial_dimension - 1)] # TODO check time : time at which to evaluate the score - noise: the diffusion sigma parameter corresponding to the time at which to evaluate the score + sigma_noise: the diffusion sigma parameter corresponding to the time at which to evaluate the score unit_cell: unit cell definition in Angstrom of shape [number_of_samples, spatial_dimension, spatial_dimension] cartesian_forces: forces to condition the sampling from. Shape [number_of_samples, number_of_atoms, spatial_dimension] Returns: - sigma normalized score: sigma x Score(x, t). + axl network output: + atom type: logits of p(a_0 | a_t). + relative coordinates: sigma normalized score: sigma x Score(x, t). + lattice: TODO. """ - number_of_samples = x.shape[0] + number_of_samples = composition.X.shape[0] - time_tensor = time * torch.ones(number_of_samples, 1).to(x) - noise_tensor = noise * torch.ones(number_of_samples, 1).to(x) + time_tensor = time * torch.ones(number_of_samples, 1).to(composition.X) + sigma_noise_tensor = sigma_noise * torch.ones(number_of_samples, 1).to( + composition.X + ) augmented_batch = { - NOISY_RELATIVE_COORDINATES: x, + NOISY_AXL_COMPOSITION: composition, TIME: time_tensor, - NOISE: noise_tensor, - UNIT_CELL: unit_cell, + NOISE: sigma_noise_tensor, + UNIT_CELL: unit_cell, # TODO replace with AXL-L CARTESIAN_FORCES: cartesian_forces, } # TODO do not hard-code conditional to False - need to be able to condition sampling - predicted_normalized_scores = self.sigma_normalized_score_network( - augmented_batch, conditional=False + model_predictions = self.axl_network(augmented_batch, conditional=False) + return model_predictions + + def _relative_coordinates_update( + self, + relative_coordinates: torch.Tensor, + sigma_normalized_scores: torch.Tensor, + sigma_i: torch.Tensor, + score_weight: torch.Tensor, + gaussian_noise_weight: torch.Tensor, + ) -> torch.Tensor: + r"""Generic update for the relative coordinates. + + This is useful for both the predictor and the corrector step. The score weight and gaussian weight noise differs + in these two settings. + + Args: + relative_coordinates: starting coordinates. Dimension: [number_of_samples, number_of_atoms, + spatial_dimension] + + sigma_normalized_scores: output of the model - an estimate of the normalized + score :math:`\sigma \nabla log p(x)`. + Dimension: [number_of_samples, number_of_atoms, spatial_dimension] + sigma_i: noise parameter for variance exploding noise scheduler. Dimension: [number_of_samples] + score_weight: prefactor in front of the normalized score update. Should be g2_i in the predictor step and + eps_i in the corrector step. Dimension: [number_of_samples] + gaussian_noise_weight: prefactor in front of the random noise update. Should be g_i in the predictor step + and sqrt_2eps_i in the corrector step. Dimension: [number_of_samples] + + Returns: + updated_coordinates: relative coordinates after the update. Dimension: [number_of_samples, number_of_atoms, + spatial_dimension]. + """ + number_of_samples = relative_coordinates.shape[0] + z = self._draw_gaussian_sample(number_of_samples).to(relative_coordinates) + updated_coordinates = ( + relative_coordinates + + score_weight * sigma_normalized_scores / sigma_i + + gaussian_noise_weight * z + ) + # map back to the range [0, 1) + updated_coordinates = map_relative_coordinates_to_unit_cell(updated_coordinates) + return updated_coordinates + + def _atom_types_update( + self, + predicted_logits: torch.Tensor, + atom_types_i: torch.LongTensor, + q_matrices_i: torch.Tensor, + q_bar_matrices_i: torch.Tensor, + q_bar_tm1_matrices_i: torch.Tensor, + atom_type_greedy_sampling: bool, + one_atom_type_transition_per_step: bool, + ) -> torch.LongTensor: + """Generic update of the atom types. + + This should be used in the predictor step only. + + Args: + predicted_logits: output of the model - an estimate of p(a_0 | a_t). Dimension: + [number_of_samples, number_of_atoms, num_classes]. + atom_types_i: indices of the atom types at timestep i. Dimension: + [number_of_samples, number_of_atoms] + q_matrices_i: one-step transition matrix. Dimension: [number_of_samples, number_of_atoms, num_classes, + num_classes]. + q_bar_matrices_i: cumulative transition matrix at time step i. Dimension: [number_of_samples, + number_of_atoms, num_classes, num_classes]. + q_bar_tm1_matrices_i: cumulative transition matrix at time step 'i - 1'. Dimension: [number_of_samples, + number_of_atoms, num_classes, num_classes]. + atom_type_greedy_sampling: boolean flag that sets whether the atom types should be selected greedily. + one_atom_type_transition_per_step: boolean flag that sets whether a single atom type transition can + occur per time step. + + Returns: + atom_types_im1: updated atom type indices. Dimension: [number_of_samples, number_of_atoms] + """ + number_of_samples = predicted_logits.shape[0] + gumbel_random_variable = self._draw_gumbel_sample(number_of_samples).to( + predicted_logits.device + ) + one_hot_atom_types_i = class_index_to_onehot( + atom_types_i, num_classes=self.num_classes + ) + one_step_transition_probs = get_probability_at_previous_time_step( + probability_at_zeroth_timestep=predicted_logits, + one_hot_probability_at_current_timestep=one_hot_atom_types_i, + q_matrices=q_matrices_i, + q_bar_matrices=q_bar_matrices_i, + q_bar_tm1_matrices=q_bar_tm1_matrices_i, + small_epsilon=self.small_epsilon, + probability_at_zeroth_timestep_are_logits=True, + ) # p(a_{t-1} | a_t) as a [num_samples, num_atoms, num_classes] tensor + + if atom_type_greedy_sampling: + # if we use greedy sampling, we will update the transition probabilities for the MASK token. + # For a_i = MASK, we define "greedy sampling" as first determining if a_{i-1} should also be MASK based on + # p(a_{i-1} = MASK | a_i = MASK). If a_{i-1} should be unmasked, its atom type is selected as the one with + # the highest probability (i.e., no stochastic sampling). Stochasticity is removed by setting the relevant + # row of gumbel_random_variable to zero. + one_step_transition_probs, gumbel_random_variable = ( + self._adjust_atom_types_probabilities_for_greedy_sampling( + one_step_transition_probs, atom_types_i, gumbel_random_variable + ) + ) + + # Use the Gumbel-softmax trick to sample atomic types. + # We also keep the associated values in memory, so we can compare which transitions are the most likely. + # Dimensions: [num_samples, num_atoms]. + max_gumbel_values, sampled_atom_types = torch.max( + torch.log(one_step_transition_probs + self.small_epsilon) + + gumbel_random_variable, + dim=-1, ) - return predicted_normalized_scores + + if one_atom_type_transition_per_step: + # force a single transition for each sample + atom_types_im1 = self._get_updated_atom_types_for_one_transition_per_step( + atom_types_i, max_gumbel_values, sampled_atom_types + ) + else: + atom_types_im1 = sampled_atom_types + + if self.record_atom_type_update: + # Keep the record on the CPU + entry = dict( + predicted_logits=predicted_logits.detach().cpu(), + one_step_transition_probabilities=one_step_transition_probs.detach().cpu(), + gumbel_sample=gumbel_random_variable.cpu(), + a_i=atom_types_i.cpu(), + a_im1=atom_types_im1.cpu(), + ) + + self.sample_trajectory_recorder.record(key="atom_type_update", entry=entry) + + return atom_types_im1 + + def _get_updated_atom_types_for_one_transition_per_step( + self, + current_atom_types: torch.Tensor, + max_gumbel_values: torch.Tensor, + sampled_atom_types: torch.Tensor, + ): + """Get updated atom types for one transition per step. + + Assuming the Gumbel softmax trick was used to create a new sample of atom types, this method + restrict the transitions from the current atom types to only the most likely one per sample. + + Args: + current_atom_types: current indices of the atom types. Dimension: [number_of_samples, number_of_atoms] + max_gumbel_values: maximum Gumbel softmax values. Dimension: [number_of_samples, number_of_atoms] + sampled_atom_types: indices of the atom types resulting from the gumbel softmax sampling. + Dimension: [number_of_samples, number_of_atoms] + + Returns: + updated_atom_types: atom types resulting from only making one transition per sample on current_atom_types. + Dimension: [number_of_samples, number_of_atoms] + """ + number_of_samples = current_atom_types.shape[0] + sample_indices = torch.arange(number_of_samples) + + # Boolean mask of dimensions [number_of_samples, number_of_atoms] + atoms_have_changed_types = sampled_atom_types != current_atom_types + + # Identify the most likely transition amongst the proposed changes. + max_gumbel_values_restricted_to_proposed_changes = torch.where( + atoms_have_changed_types, max_gumbel_values, -torch.inf + ) + most_likely_transition_atom_indices = torch.argmax( + max_gumbel_values_restricted_to_proposed_changes, dim=-1 + ) + + # Restrict transitions to only the most likely ones. + updated_atom_types = current_atom_types.clone() + updated_atom_types[sample_indices, most_likely_transition_atom_indices] = ( + sampled_atom_types[sample_indices, most_likely_transition_atom_indices] + ) + + return updated_atom_types + + def _adjust_atom_types_probabilities_for_greedy_sampling( + self, + one_step_transition_probs: torch.Tensor, + atom_types_i: torch.LongTensor, + gumbel_random_variable: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update the transition probabilities and the gumbel random variables to allow greedy sampling. + + At time step i, for every atom in a sample, we sample a random number. If it is larger than the probability of + that atom being in the MASK class, then we will sample greedily a new atom type (i.e. the most likely). To do + that, we simply replace the probability of the MASK class to zero and the gumbel noise u to zero. For non-MASK + atoms, we do nothing. For samples with only MASK atoms, we also do nothing. + + Args: + one_step_transition_probs: class distributions at time t-1 given distribution at time t. p(a_{t-1} | a_t) + atom_types_i: indices of atom types at time i. Dimension: [number_of_samples, number_of_atoms] + gumbel_random_variable: gumbel noise used for sampling. + Dimension: [number_of_samples, number_of_atoms, num_classes] + + Returns: + one_step_transition_probs: probabilities are updated so a MASK to non-MASK transition can happen + u: set to a constant for samples with at least 1 non-MASK atom + """ + # check which samples have at least 1 non-MASK atom + all_masked = torch.all( + atom_types_i == self.masked_atom_type_index, dim=-1 + ) # dim: number_of_samples, + + # we can only do greedy sampling for atoms that are masked + atom_is_masked = atom_types_i == self.masked_atom_type_index + + # we will first erase the probability of staying MASK for some atoms randomly by drawing from a binary + # distribution given by one_step_transition_probs[:, :, -1] i.e. the probabilities related to the MASK class. + # sample to override the MASK probability as the most likely + binary_sample = self._draw_binary_sample(atom_types_i.shape[0]).to( + device=atom_types_i.device + ) + unmask_this_atom = binary_sample > one_step_transition_probs[:, :, -1] + # if we override the MASK probability & there's already a non-MASK sample & that atom is masked, + # use a greedy sampling for that atom + do_greedy_sampling = torch.logical_and( + ~all_masked.view(-1, 1), + unmask_this_atom, + ) + do_greedy_sampling = torch.logical_and(do_greedy_sampling, atom_is_masked) + # replace the probability of getting a mask for those by 0 - so that state cannot be sampled + one_step_transition_probs[:, :, -1] = torch.where( + do_greedy_sampling, 0, one_step_transition_probs[:, :, -1] + ) + + # replace u with a constant for samples with a non-MASK token present - this ensures a greedy sampling + # In the current choice of \beta_t = 1 / (T-t+1), a greedy sampling will always select the MASK type if that + # probability is not set to zero - except at the last generation step. This might not hold if the \beta schedule + # is modified. + gumbel_random_variable = torch.where( + all_masked.view(-1, 1, 1), gumbel_random_variable, 0.0 + ) + return one_step_transition_probs, gumbel_random_variable def predictor_step( self, - x_i: torch.Tensor, + composition_i: AXL, index_i: int, - unit_cell: torch.Tensor, + unit_cell: torch.Tensor, # TODO replace with AXL-L cartesian_forces: torch.Tensor, - ) -> torch.Tensor: + ) -> AXL: """Predictor step. Args: - x_i : sampled relative coordinates, at time step i. + composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. index_i : index of the time step. unit_cell: sampled unit cell at time step i. cartesian_forces: forces conditioning the sampling process Returns: - x_im1 : sampled relative coordinates, at time step i - 1. + composition_im1 : sampled composition, at time step i - 1. """ assert ( 1 <= index_i <= self.number_of_discretization_steps ), "The predictor step can only be invoked for index_i between 1 and the total number of discretization steps." - number_of_samples = x_i.shape[0] - z = self._draw_gaussian_sample(number_of_samples).to(x_i) + number_of_samples = composition_i.X.shape[0] + number_of_atoms = composition_i.X.shape[1] idx = index_i - 1 # python starts indices at zero - t_i = self.noise.time[idx].to(x_i) - g_i = self.noise.g[idx].to(x_i) - g2_i = self.noise.g_squared[idx].to(x_i) - sigma_i = self.noise.sigma[idx].to(x_i) - sigma_score_i = self._get_sigma_normalized_scores( - x_i, t_i, sigma_i, unit_cell, cartesian_forces + t_i = self.noise.time[idx].to(composition_i.X) + g_i = self.noise.g[idx].to(composition_i.X) + g2_i = self.noise.g_squared[idx].to(composition_i.X) + sigma_i = self.noise.sigma[idx].to(composition_i.X) + + # Broadcast the q matrices to the expected dimensions. + q_matrices_i = einops.repeat( + self.noise.q_matrix[idx].to(composition_i.X), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, ) - x_im1 = x_i + g2_i / sigma_i * sigma_score_i + g_i * z - self.sample_trajectory_recorder.record_unit_cell(unit_cell=unit_cell) - self.sample_trajectory_recorder.record_predictor_step( - i_index=index_i, - time=t_i, - sigma=sigma_i, - x_i=x_i, - x_im1=x_im1, - scores=sigma_score_i, + q_bar_matrices_i = einops.repeat( + self.noise.q_bar_matrix[idx].to(composition_i.X), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, ) - return x_im1 + q_bar_tm1_matrices_i = einops.repeat( + self.noise.q_bar_tm1_matrix[idx].to(composition_i.X), + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + model_predictions_i = self._get_model_predictions( + composition_i, t_i, sigma_i, unit_cell, cartesian_forces + ) + + # Even if the global flag 'one_atom_type_transition_per_step' is set to True, a single atomic transition + # cannot be used at the last time step because it is necessary for all atoms to be unmasked at the end + # of the trajectory. Here, we use 'first' and 'last' with respect to a denoising trajectory, where + # the "first" time step is at index_i = T and the "last" time step is index_i = 1. + this_is_last_time_step = idx == 0 + one_atom_type_transition_per_step = ( + self.one_atom_type_transition_per_step and not this_is_last_time_step + ) + + a_im1 = self._atom_types_update( + model_predictions_i.A, + composition_i.A, + q_matrices_i, + q_bar_matrices_i, + q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, + ) + + if this_is_last_time_step: + assert (a_im1 != self.masked_atom_type_index).all(), \ + "There remains MASKED atoms at the last time step: review code, there must be a bug or invalid input." + + x_im1 = self._relative_coordinates_update( + composition_i.X, model_predictions_i.X, sigma_i, g2_i, g_i + ) + + composition_im1 = AXL( + A=a_im1, X=x_im1, L=unit_cell + ) # TODO : Deal with L correctly + + if self.record: + # TODO : Deal with L correctly + composition_i_for_recording = AXL( + A=composition_i.A, X=composition_i.X, L=unit_cell + ) + # Keep the record on the CPU + entry = dict(time_step_index=index_i) + list_keys = ["composition_i", "composition_im1", "model_predictions_i"] + list_axl = [ + composition_i_for_recording, + composition_im1, + model_predictions_i, + ] + + for key, axl in zip(list_keys, list_axl): + record_axl = AXL( + A=axl.A.detach().cpu(), + X=axl.X.detach().cpu(), + L=axl.L.detach().cpu(), + ) + entry[key] = record_axl + self.sample_trajectory_recorder.record(key="predictor_step", entry=entry) + + return composition_im1 def corrector_step( self, - x_i: torch.Tensor, + composition_i: AXL, index_i: int, - unit_cell: torch.Tensor, + unit_cell: torch.Tensor, # TODO replace with AXL-L cartesian_forces: torch.Tensor, - ) -> torch.Tensor: + ) -> AXL: """Corrector Step. + Note this is not affecting the atom types. Only the reduced coordinates and lattice vectors. + Args: - x_i : sampled relative coordinates, at time step i. + composition_i : sampled composition (atom types, relative coordinates, lattice vectors), at time step i. index_i : index of the time step. - unit_cell: sampled unit cell at time step i. + unit_cell: sampled unit cell at time step i. # TODO replace with AXL-L cartesian_forces: forces conditioning the sampling Returns: - corrected x_i : sampled relative coordinates, after corrector step. + corrected_composition_i : sampled composition, after corrector step. """ assert 0 <= index_i <= self.number_of_discretization_steps - 1, ( "The corrector step can only be invoked for index_i between 0 and " "the total number of discretization steps minus 1." ) - - number_of_samples = x_i.shape[0] - z = self._draw_gaussian_sample(number_of_samples).to(x_i) - # The Langevin dynamics array are indexed with [0,..., N-1] - eps_i = self.langevin_dynamics.epsilon[index_i].to(x_i) - sqrt_2eps_i = self.langevin_dynamics.sqrt_2_epsilon[index_i].to(x_i) + eps_i = self.langevin_dynamics.epsilon[index_i].to(composition_i.X) + sqrt_2eps_i = self.langevin_dynamics.sqrt_2_epsilon[index_i].to(composition_i.X) if index_i == 0: # TODO: we are extrapolating here; the score network will never have seen this time step... @@ -179,24 +543,69 @@ def corrector_step( self.noise_parameters.sigma_min ) # no need to change device, this is a float t_i = 0.0 # same for device - this is a float + idx = index_i else: idx = index_i - 1 # python starts indices at zero - sigma_i = self.noise.sigma[idx].to(x_i) - t_i = self.noise.time[idx].to(x_i) + sigma_i = self.noise.sigma[idx].to(composition_i.X) + t_i = self.noise.time[idx].to(composition_i.X) - sigma_score_i = self._get_sigma_normalized_scores( - x_i, t_i, sigma_i, unit_cell, cartesian_forces + model_predictions_i = self._get_model_predictions( + composition_i, t_i, sigma_i, unit_cell, cartesian_forces ) - corrected_x_i = x_i + eps_i / sigma_i * sigma_score_i + sqrt_2eps_i * z + corrected_x_i = self._relative_coordinates_update( + composition_i.X, model_predictions_i.X, sigma_i, eps_i, sqrt_2eps_i + ) + + if self.atom_type_transition_in_corrector: + q_matrices_i = self.noise.q_matrix[idx].to(composition_i.X) + q_bar_matrices_i = self.noise.q_bar_matrix[idx].to(composition_i.X) + q_bar_tm1_matrices_i = self.noise.q_bar_tm1_matrix[idx].to(composition_i.X) + # atom types update + corrected_a_i = self._atom_types_update( + model_predictions_i.A, + composition_i.A, + q_matrices_i, + q_bar_matrices_i, + q_bar_tm1_matrices_i, + atom_type_greedy_sampling=self.atom_type_greedy_sampling, + one_atom_type_transition_per_step=self.one_atom_type_transition_per_step, + ) + else: + corrected_a_i = composition_i.A - self.sample_trajectory_recorder.record_corrector_step( - i_index=index_i, - time=t_i, - sigma=sigma_i, - x_i=x_i, - corrected_x_i=corrected_x_i, - scores=sigma_score_i, + corrected_composition_i = AXL( + A=corrected_a_i, + X=corrected_x_i, + L=unit_cell, # TODO replace with AXL-L ) - return corrected_x_i + if self.record_corrector: + # TODO : Deal with L correctly + composition_i_for_recording = AXL( + A=composition_i.A, X=composition_i.X, L=unit_cell + ) + # Keep the record on the CPU + entry = dict(time_step_index=index_i) + list_keys = [ + "composition_i", + "corrected_composition_i", + "model_predictions_i", + ] + list_axl = [ + composition_i_for_recording, + corrected_composition_i, + model_predictions_i, + ] + + for key, axl in zip(list_keys, list_axl): + record_axl = AXL( + A=axl.A.detach().cpu(), + X=axl.X.detach().cpu(), + L=axl.L.detach().cpu(), + ) + entry[key] = record_axl + + self.sample_trajectory_recorder.record(key="corrector_step", entry=entry) + + return corrected_composition_i diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/load_sampling_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/load_sampling_parameters.py index 57d841e8..99f1ccfe 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/load_sampling_parameters.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/load_sampling_parameters.py @@ -1,10 +1,10 @@ from typing import Any, AnyStr, Dict +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ + SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.ode_position_generator import \ ODESamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import \ - SamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ PredictorCorrectorSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.sde_position_generator import \ SDESamplingParameters diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/ode_position_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/ode_position_generator.py index d8f1bc2c..8de850ba 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/ode_position_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/ode_position_generator.py @@ -1,3 +1,4 @@ +import dataclasses import logging from dataclasses import dataclass from typing import Callable @@ -7,18 +8,20 @@ import torchode as to from torchode import Solution -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import ( - PositionGenerator, SamplingParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import ( + AXLGenerator, SamplingParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + VarianceScheduler +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell -from diffusion_for_multi_scale_molecular_dynamics.utils.sample_trajectory import ( - NoOpODESampleTrajectory, ODESampleTrajectory) +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( + map_axl_composition_to_unit_cell, map_relative_coordinates_to_unit_cell) +from diffusion_for_multi_scale_molecular_dynamics.utils.sample_trajectory import \ + SampleTrajectory logger = logging.getLogger(__name__) @@ -36,7 +39,7 @@ class ODESamplingParameters(SamplingParameters): ) -class ExplodingVarianceODEPositionGenerator(PositionGenerator): +class ExplodingVarianceODEAXLGenerator(AXLGenerator): """Exploding Variance ODE Position Generator. This class generates position samples by solving an ordinary differential equation (ODE). @@ -47,58 +50,49 @@ def __init__( self, noise_parameters: NoiseParameters, sampling_parameters: ODESamplingParameters, - sigma_normalized_score_network: ScoreNetwork, + axl_network: ScoreNetwork, ): """Init method. Args: noise_parameters : the diffusion noise parameters. sampling_parameters: the parameters needed for sampling. - sigma_normalized_score_network : the score network to use for drawing samples. + axl_network : the model to use for drawing samples that predicts an AXL: + atom types: predicts p(a_0 | a_t) + relative coordinates: predicts the sigma normalized score + lattice: placeholder # TODO """ self.t0 = 0.0 # The "initial diffusion time", corresponding to the physical distribution. self.tf = 1.0 # The "final diffusion time", corresponding to the uniform distribution. self.noise_parameters = noise_parameters - self.sigma_normalized_score_network = sigma_normalized_score_network + self.exploding_variance = VarianceScheduler(noise_parameters) + + self.axl_network = axl_network assert ( self.noise_parameters.total_time_steps >= 2 ), "There must at least be two time steps in the noise parameters to define the limits t0 and tf." self.number_of_atoms = sampling_parameters.number_of_atoms self.spatial_dimension = sampling_parameters.spatial_dimension + self.num_classes = ( + sampling_parameters.num_atom_types + 1 + ) # add 1 for the MASK class self.absolute_solver_tolerance = sampling_parameters.absolute_solver_tolerance self.relative_solver_tolerance = sampling_parameters.relative_solver_tolerance - self.record_samples = sampling_parameters.record_samples - - if self.record_samples: - self.sample_trajectory_recorder = ODESampleTrajectory() - else: - self.sample_trajectory_recorder = NoOpODESampleTrajectory() - - def _get_exploding_variance_sigma(self, times): - """Get Exploding Variance Sigma. - - In the 'exploding variance' scheme, the noise is defined by - - sigma(t) = sigma_min^{1- t} x sigma_max^{t} + self.record = sampling_parameters.record_samples - Args: - times : diffusion time - - Returns: - sigmas: value of the noise parameter. - """ - sigmas = ( - self.noise_parameters.sigma_min ** (1.0 - times) - * self.noise_parameters.sigma_max**times - ) - return sigmas + if self.record: + self.sample_trajectory_recorder = SampleTrajectory() + self.sample_trajectory_recorder.record(key="noise_parameters", + entry=dataclasses.asdict(noise_parameters)) + self.sample_trajectory_recorder.record(key="sampling_parameters", + entry=dataclasses.asdict(sampling_parameters)) - def _get_ode_prefactor(self, sigmas): + def _get_ode_prefactor(self, times): """Get ODE prefactor. - The ODE is given by + The ODE for the relative coordinates is given by dx = [-1/2 g(t)^2 x Score] dt with g(t)^2 = d sigma(t)^2 / dt @@ -114,24 +108,21 @@ def _get_ode_prefactor(self, sigmas): Prefactor = d sigma(t) / dt Args: - sigmas : the values of the noise parameters. + times: the values of the time. Returns: ode prefactor: the prefactor in the ODE. """ - log_ratio = torch.log( - torch.tensor( - self.noise_parameters.sigma_max / self.noise_parameters.sigma_min - ) - ) - ode_prefactor = log_ratio * sigmas - return ode_prefactor + return self.exploding_variance.get_sigma_time_derivative(times) - def generate_ode_term(self, unit_cell: torch.Tensor) -> Callable: + def generate_ode_term( + self, unit_cell: torch.Tensor, atom_types: torch.LongTensor + ) -> Callable: """Generate the ode_term needed to compute the ODE solution.""" def ode_term( - times: torch.Tensor, flat_relative_coordinates: torch.Tensor + times: torch.Tensor, + flat_relative_coordinates: torch.Tensor, ) -> torch.Tensor: """ODE term. @@ -141,13 +132,14 @@ def ode_term( Args: times : ODE times, dimension [batch_size] - flat_relative_coordinates : features for every time step, dimension [batch_size, number of features]. + flat_relative_coordinates : relative coordinates features for every time step, dimension + [batch_size, number of features]. Returns: rhs: the right-hand-side of the corresponding ODE. """ - sigmas = self._get_exploding_variance_sigma(times) - ode_prefactor = self._get_ode_prefactor(sigmas) + sigmas = self.exploding_variance.get_sigma(times) + ode_prefactor = self._get_ode_prefactor(times) relative_coordinates = einops.rearrange( flat_relative_coordinates, @@ -157,19 +149,21 @@ def ode_term( ) batch = { - NOISY_RELATIVE_COORDINATES: map_relative_coordinates_to_unit_cell( - relative_coordinates + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, + X=map_relative_coordinates_to_unit_cell(relative_coordinates), + L=unit_cell, # TODO ), NOISE: sigmas.unsqueeze(-1), TIME: times.unsqueeze(-1), - UNIT_CELL: unit_cell, + UNIT_CELL: unit_cell, # TODO replace with AXL-L CARTESIAN_FORCES: torch.zeros_like( relative_coordinates ), # TODO: handle forces correctly. } # Shape [batch_size, number of atoms, spatial dimension] - sigma_normalized_scores = self.sigma_normalized_score_network(batch) + sigma_normalized_scores = self.axl_network(batch).X flat_sigma_normalized_scores = einops.rearrange( sigma_normalized_scores, "batch natom space -> batch (natom space)" ) @@ -180,28 +174,28 @@ def ode_term( def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor - ) -> torch.Tensor: + ) -> AXL: """Sample. - This method draws a position sample. + This method draws an AXL sample. Args: number_of_samples : number of samples to draw. device: device to use (cpu, cuda, etc.). Should match the PL model location. - unit_cell: unit cell definition in Angstrom. + unit_cell: unit cell definition in Angstrom. # TODO replace with AXL-L Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] Returns: - samples: relative coordinates samples. + samples: samples as AXL composition """ - ode_term = self.generate_ode_term(unit_cell) + initial_composition = map_axl_composition_to_unit_cell( + self.initialize(number_of_samples, device), device + ) - initial_relative_coordinates = map_relative_coordinates_to_unit_cell( - self.initialize(number_of_samples) - ).to(device) + ode_term = self.generate_ode_term(unit_cell, atom_types=initial_composition.A) y0 = einops.rearrange( - initial_relative_coordinates, "batch natom space -> batch (natom space)" + initial_composition.X, "batch natom space -> batch (natom space)" ) evaluation_times = torch.linspace( @@ -228,7 +222,7 @@ def sample( sol = jit_solver.solve(to.InitialValueProblem(y0=y0, t_eval=t_eval)) logger.info("ODE solver Finished.") - if self.record_samples: + if self.record: self.record_sample(ode_term, sol, evaluation_times, unit_cell) # sol.ys has dimensions [number of samples, number of times, number of features] @@ -242,7 +236,11 @@ def sample( space=self.spatial_dimension, ) - return map_relative_coordinates_to_unit_cell(relative_coordinates) + updated_composition = AXL( + A=initial_composition.A, X=relative_coordinates, L=initial_composition.L + ) + + return map_axl_composition_to_unit_cell(updated_composition, device) def record_sample( self, @@ -266,15 +264,14 @@ def record_sample( """ number_of_samples = sol.ys.shape[0] - self.sample_trajectory_recorder.record_unit_cell(unit_cell) record_relative_coordinates = einops.rearrange( sol.ys, "batch times (natom space) -> batch times natom space", natom=self.number_of_atoms, space=self.spatial_dimension, ) - sigmas = self._get_exploding_variance_sigma(evaluation_times) - ode_prefactor = self._get_ode_prefactor(sigmas) + sigmas = self.exploding_variance.get_sigma(evaluation_times) + ode_prefactor = self._get_ode_prefactor(evaluation_times) list_flat_normalized_scores = [] for time_idx, (time, gamma) in enumerate(zip(evaluation_times, ode_prefactor)): times = time * torch.ones(number_of_samples).to(sol.ys) @@ -290,18 +287,30 @@ def record_sample( natom=self.number_of_atoms, space=self.spatial_dimension, ) - self.sample_trajectory_recorder.record_ode_solution( - times=evaluation_times, - sigmas=sigmas, - relative_coordinates=record_relative_coordinates, - normalized_scores=record_normalized_scores, - stats=sol.stats, - status=sol.status, - ) - def initialize(self, number_of_samples: int): + entry = dict(times=evaluation_times, + sigmas=sigmas, + relative_coordinates=record_relative_coordinates, + normalized_scores=record_normalized_scores, + unit_cell=unit_cell, + stats=sol.stats, + status=sol.status) + self.sample_trajectory_recorder.record(key='ode', entry=entry) + + def initialize( + self, number_of_samples: int, device: torch.device = torch.device("cpu") + ): """This method must initialize the samples from the fully noised distribution.""" relative_coordinates = torch.rand( number_of_samples, self.number_of_atoms, self.spatial_dimension + ).to(device) + atom_types = ( + torch.zeros(number_of_samples, self.number_of_atoms).long().to(device) ) - return relative_coordinates + lattice_vectors = torch.zeros( + number_of_samples, self.spatial_dimension * (self.spatial_dimension - 1) + ).to( + device + ) # TODO placeholder + init_composition = AXL(A=atom_types, X=relative_coordinates, L=lattice_vectors) + return init_composition diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_position_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py similarity index 56% rename from src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_position_generator.py rename to src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py index f8c8a582..2757ac2b 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_position_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/predictor_corrector_axl_generator.py @@ -5,10 +5,9 @@ import torch from tqdm import tqdm -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import ( - PositionGenerator, SamplingParameters) -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import ( + AXLGenerator, SamplingParameters) +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL logger = logging.getLogger(__name__) @@ -19,22 +18,28 @@ class PredictorCorrectorSamplingParameters(SamplingParameters): algorithm: str = "predictor_corrector" number_of_corrector_steps: int = 1 + small_epsilon: float = 1e-8 + one_atom_type_transition_per_step: bool = True + atom_type_greedy_sampling: bool = True + atom_type_transition_in_corrector: bool = False -class PredictorCorrectorPositionGenerator(PositionGenerator): - """This defines the interface for predictor-corrector position generators.""" +class PredictorCorrectorAXLGenerator(AXLGenerator): + """Defines the interface for predictor-corrector AXL (atom types, relative coordinates and lattice) generators.""" def __init__( self, number_of_discretization_steps: int, number_of_corrector_steps: int, spatial_dimension: int, + num_atom_types: int, **kwargs, ): """Init method.""" + # T = 1 is a dangerous and meaningless edge case. assert ( - number_of_discretization_steps > 0 - ), "The number of discretization steps should be larger than zero" + number_of_discretization_steps > 1 + ), "The number of discretization steps should be larger than one" assert ( number_of_corrector_steps >= 0 ), "The number of corrector steps should be non-negative" @@ -42,10 +47,11 @@ def __init__( self.number_of_discretization_steps = number_of_discretization_steps self.number_of_corrector_steps = number_of_corrector_steps self.spatial_dimension = spatial_dimension + self.num_classes = num_atom_types + 1 # account for the MASK class def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor - ) -> torch.Tensor: + ) -> AXL: """Sample. This method draws a sample using the PC sampler algorithm. @@ -53,11 +59,11 @@ def sample( Args: number_of_samples : number of samples to draw. device: device to use (cpu, cuda, etc.). Should match the PL model location. - unit_cell: unit cell definition in Angstrom. + unit_cell: unit cell definition in Angstrom. # TODO replace with AXL-L Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] Returns: - samples: relative coordinates samples. + samples: AXL samples (atom types, relative coordinates, lattice vectors) """ assert unit_cell.size() == ( number_of_samples, @@ -66,66 +72,64 @@ def sample( ), ( "Unit cell passed to sample should be of size (number of sample, spatial dimension, spatial dimension" + f"Got {unit_cell.size()}" - ) + ) # TODO replace with AXL-L - x_ip1 = map_relative_coordinates_to_unit_cell( - self.initialize(number_of_samples) - ).to(device) - forces = torch.zeros_like(x_ip1) + composition_ip1 = self.initialize(number_of_samples, device) + + forces = torch.zeros_like(composition_ip1.X) for i in tqdm(range(self.number_of_discretization_steps - 1, -1, -1)): - x_i = map_relative_coordinates_to_unit_cell( - self.predictor_step(x_ip1, i + 1, unit_cell, forces) + composition_i = self.predictor_step( + composition_ip1, i + 1, unit_cell, forces ) for _ in range(self.number_of_corrector_steps): - x_i = map_relative_coordinates_to_unit_cell( - self.corrector_step(x_i, i, unit_cell, forces) - ) - x_ip1 = x_i - return x_i + composition_i = self.corrector_step(composition_i, i, unit_cell, forces) + composition_ip1 = composition_i + return composition_i @abstractmethod def predictor_step( self, - x_ip1: torch.Tensor, + composition_ip1: AXL, ip1: int, - unit_cell: torch.Tensor, + unit_cell: torch.Tensor, # TODO replace with AXL-L cartesian_forces: torch.Tensor, - ) -> torch.Tensor: + ) -> AXL: """Predictor step. It is assumed that there are N predictor steps, with index "i" running from N-1 to 0. Args: - x_ip1 : sampled relative coordinates at step "i + 1". + composition_ip1 : sampled AXL composition (atom types, relative coordinates and lattice vectors) at step + "i + 1". ip1 : index "i + 1" - unit_cell: sampled unit cell at time step "i + 1". + unit_cell: sampled unit cell at time step "i + 1". TODO replace with AXL-L cartesian_forces: forces conditioning the diffusion process Returns: - x_i : sampled relative coordinates after the predictor step. + composition_i : sampled AXL composition after the predictor step. """ pass @abstractmethod def corrector_step( self, - x_i: torch.Tensor, + composition_i: AXL, i: int, - unit_cell: torch.Tensor, + unit_cell: torch.Tensor, # TODO replace with AXL-L cartesian_forces: torch.Tensor, - ) -> torch.Tensor: + ) -> AXL: """Corrector step. It is assumed that there are N predictor steps, with index "i" running from N-1 to 0. For each value of "i", there are M corrector steps. Args: - x_i : sampled relative coordinates at step "i". + composition_i : sampled AXL composition (atom types, relative coordinates and lattice vectors) at step "i". i : index "i" OF THE PREDICTOR STEP. - unit_cell: sampled unit cell at time step i. + unit_cell: sampled unit cell at time step i. # TODO replace with AXL-L cartesian_forces: forces conditioning the diffusion process Returns: - x_i_out : sampled relative coordinates after the corrector step. + corrected_composition_i : sampled composition after the corrector step. """ pass diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py b/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py index 4b9dcab7..3531b9aa 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/generators/sde_position_generator.py @@ -1,3 +1,4 @@ +import dataclasses import logging from dataclasses import dataclass @@ -5,18 +6,20 @@ import torch import torchsde -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import ( - PositionGenerator, SamplingParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import ( + AXLGenerator, SamplingParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + VarianceScheduler +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( + map_axl_composition_to_unit_cell, map_relative_coordinates_to_unit_cell) from diffusion_for_multi_scale_molecular_dynamics.utils.sample_trajectory import \ - SDESampleTrajectory + SampleTrajectory logger = logging.getLogger(__name__) @@ -51,8 +54,9 @@ def __init__( self, noise_parameters: NoiseParameters, sampling_parameters: SDESamplingParameters, - sigma_normalized_score_network: ScoreNetwork, - unit_cells: torch.Tensor, + axl_network: ScoreNetwork, + atom_types: torch.LongTensor, # TODO review formalism - this is treated as constant through the SDE solver + unit_cells: torch.Tensor, # TODO replace with AXL-L initial_diffusion_time: torch.Tensor, final_diffusion_time: torch.Tensor, ): @@ -64,7 +68,11 @@ def __init__( Args: noise_parameters: parameters defining the noise schedule. sampling_parameters : parameters defining the sampling procedure. - sigma_normalized_score_network : the score network to use for drawing samples. + axl_network : the model to use for drawing samples that predicts an AXL: + atom types: predicts p(a_0 | a_t) + relative coordinates: predicts the sigma normalized score + lattice: placeholder # TODO + atom_types: atom type indices. Tensor of dimensions [number_of_samples, natoms] unit_cells: unit cell definition in Angstrom. Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] initial_diffusion_time : initial diffusion time. Dimensionless tensor. @@ -73,34 +81,15 @@ def __init__( super().__init__() self.sde_type = sampling_parameters.sde_type self.noise_parameters = noise_parameters - self.sigma_normalized_score_network = sigma_normalized_score_network - self.unit_cells = unit_cells + self.exploding_variance = VarianceScheduler(noise_parameters) + self.axl_network = axl_network + self.atom_types = atom_types + self.unit_cells = unit_cells # TODO replace with AXL-L self.number_of_atoms = sampling_parameters.number_of_atoms self.spatial_dimension = sampling_parameters.spatial_dimension self.initial_diffusion_time = initial_diffusion_time self.final_diffusion_time = final_diffusion_time - def _get_exploding_variance_sigma( - self, diffusion_time: torch.Tensor - ) -> torch.Tensor: - """Get Exploding Variance Sigma. - - In the 'exploding variance' scheme, the noise is defined by - - sigma(t) = sigma_min^{1- t} x sigma_max^{t} - - Args: - diffusion_time : diffusion time - - Returns: - sigma: value of the noise parameter. - """ - sigma = ( - self.noise_parameters.sigma_min ** (1.0 - diffusion_time) - * self.noise_parameters.sigma_max**diffusion_time - ) - return sigma - def _get_diffusion_coefficient_g_squared( self, diffusion_time: torch.Tensor ) -> torch.Tensor: @@ -115,13 +104,7 @@ def _get_diffusion_coefficient_g_squared( Returns: coefficient_g : the coefficient g(t) """ - s_min = torch.tensor(self.noise_parameters.sigma_min) - ratio = torch.tensor( - self.noise_parameters.sigma_max / self.noise_parameters.sigma_min - ) - - g_squared = 2.0 * (s_min * ratio**diffusion_time) ** 2 * torch.log(ratio) - return g_squared + return self.exploding_variance.get_g_squared(diffusion_time) def _get_diffusion_time(self, sde_time: torch.Tensor) -> torch.Tensor: """Get diffusion time. @@ -149,15 +132,15 @@ def f( """ diffusion_time = self._get_diffusion_time(sde_time) - sigma_normalized_scores = self.get_sigma_normalized_score( - diffusion_time, flat_relative_coordinates - ) + sigma_normalized_scores = self.get_model_predictions( + diffusion_time, flat_relative_coordinates, self.atom_types + ).X # we are only using the sigma normalized score for the relative coordinates diffusion flat_sigma_normalized_scores = einops.rearrange( sigma_normalized_scores, "batch natom space -> batch (natom space)" ) g_squared = self._get_diffusion_coefficient_g_squared(diffusion_time) - sigma = self._get_exploding_variance_sigma(diffusion_time) + sigma = self.exploding_variance.get_sigma(diffusion_time) # Careful! The prefactor must account for the following facts: # - the SDE time is NEGATIVE the diffusion time; this introduces a minus sign dt_{diff} = -dt_{sde} # - what our model calculates is the NORMALIZED score (ie, Score x sigma). We must thus divide by sigma. @@ -165,9 +148,12 @@ def f( return prefactor * flat_sigma_normalized_scores - def get_sigma_normalized_score( - self, diffusion_time: torch.Tensor, flat_relative_coordinates: torch.Tensor - ) -> torch.Tensor: + def get_model_predictions( + self, + diffusion_time: torch.Tensor, + flat_relative_coordinates: torch.Tensor, + atom_types: torch.Tensor, + ) -> AXL: """Get sigma normalized score. This is a utility method to wrap around the computation of the sigma normalized score in this context, @@ -177,13 +163,16 @@ def get_sigma_normalized_score( diffusion_time : the diffusion time. Dimensionless tensor. flat_relative_coordinates : the flat relative coordinates. Dimension [batch_size, natoms x spatial_dimensions] + atom_types: indices for the atom types. Dimension [batch_size, natoms] Returns: - sigma_normalized_score: the sigma normalized score. - Dimension [batch_size, natoms, spatial_dimensions] + model predictions: AXL with + A: estimate of p(a_0|a_t). Dimension [batch_size, natoms, num_classes] + X: sigma normalized score. Dimension [batch_size, natoms, spatial_dimensions] + L: placeholder # TODO """ batch_size = flat_relative_coordinates.shape[0] - sigma = self._get_exploding_variance_sigma(diffusion_time) + sigma = self.exploding_variance.get_sigma(diffusion_time) sigmas = einops.repeat(sigma.unsqueeze(0), "1 -> batch 1", batch=batch_size) times = einops.repeat( diffusion_time.unsqueeze(0), "1 -> batch 1", batch=batch_size @@ -196,8 +185,10 @@ def get_sigma_normalized_score( space=self.spatial_dimension, ) batch = { - NOISY_RELATIVE_COORDINATES: map_relative_coordinates_to_unit_cell( - relative_coordinates + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, + X=map_relative_coordinates_to_unit_cell(relative_coordinates), + L=self.unit_cells, # TODO ), NOISE: sigmas, TIME: times, @@ -206,9 +197,9 @@ def get_sigma_normalized_score( relative_coordinates ), # TODO: handle forces correctly. } - # Shape [batch_size, number of atoms, spatial dimension] - sigma_normalized_scores = self.sigma_normalized_score_network(batch) - return sigma_normalized_scores + # Shape for the coordinates scores [batch_size, number of atoms, spatial dimension] + model_predictions = self.axl_network(batch) + return model_predictions def g(self, sde_time, y): """Diffusion function.""" @@ -219,7 +210,7 @@ def g(self, sde_time, y): return g_of_t * torch.ones_like(y) -class ExplodingVarianceSDEPositionGenerator(PositionGenerator): +class ExplodingVarianceSDEPositionGenerator(AXLGenerator): """Exploding Variance SDE Position Generator. This class generates position samples by solving a stochastic differential equation (SDE). @@ -230,54 +221,70 @@ def __init__( self, noise_parameters: NoiseParameters, sampling_parameters: SDESamplingParameters, - sigma_normalized_score_network: ScoreNetwork, + axl_network: ScoreNetwork, ): """Init method. Args: noise_parameters : the diffusion noise parameters. sampling_parameters: the parameters needed for sampling. - sigma_normalized_score_network : the score network to use for drawing samples. + axl_network: the score network to use for drawing samples. """ self.initial_diffusion_time = torch.tensor(0.0) self.final_diffusion_time = torch.tensor(1.0) self.noise_parameters = noise_parameters - self.sigma_normalized_score_network = sigma_normalized_score_network + self.axl_network = axl_network self.sampling_parameters = sampling_parameters self.number_of_atoms = sampling_parameters.number_of_atoms self.spatial_dimension = sampling_parameters.spatial_dimension self.absolute_solver_tolerance = sampling_parameters.absolute_solver_tolerance self.relative_solver_tolerance = sampling_parameters.relative_solver_tolerance - self.record_samples = sampling_parameters.record_samples - if self.record_samples: - self.sample_trajectory_recorder = SDESampleTrajectory() - - def get_sde(self, unit_cells: torch.Tensor) -> SDE: + self.record = sampling_parameters.record_samples + if self.record: + self.sample_trajectory_recorder = SampleTrajectory() + self.sample_trajectory_recorder.record(key="noise_parameters", + entry=dataclasses.asdict(noise_parameters)) + self.sample_trajectory_recorder.record(key="sampling_parameters", + entry=dataclasses.asdict(sampling_parameters)) + + def get_sde(self, unit_cells: torch.Tensor, atom_types: torch.LongTensor) -> SDE: """Get SDE.""" return SDE( noise_parameters=self.noise_parameters, sampling_parameters=self.sampling_parameters, - sigma_normalized_score_network=self.sigma_normalized_score_network, + axl_network=self.axl_network, + atom_types=atom_types, unit_cells=unit_cells, initial_diffusion_time=self.initial_diffusion_time, final_diffusion_time=self.final_diffusion_time, ) - def initialize(self, number_of_samples: int): + def initialize( + self, number_of_samples: int, device: torch.device = torch.device("cpu") + ): """This method must initialize the samples from the fully noised distribution.""" relative_coordinates = torch.rand( number_of_samples, self.number_of_atoms, self.spatial_dimension + ).to(device) + atom_types = ( + torch.zeros(number_of_samples, self.number_of_atoms).long().to(device) ) - return relative_coordinates + lattice_vectors = torch.zeros( + number_of_samples, self.spatial_dimension * (self.spatial_dimension - 1) + ).to( + device + ) # TODO placeholder + init_composition = AXL(A=atom_types, X=relative_coordinates, L=lattice_vectors) + return init_composition def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor - ) -> torch.Tensor: + ) -> AXL: """Sample. - This method draws a position sample. + This method draws an AXL sample. Args: number_of_samples : number of samples to draw. @@ -286,16 +293,17 @@ def sample( Tensor of dimensions [number_of_samples, spatial_dimension, spatial_dimension] Returns: - samples: relative coordinates samples. + samples: samples as AXL composition. """ - sde = self.get_sde(unit_cell) + initial_composition = map_axl_composition_to_unit_cell( + self.initialize(number_of_samples), device + ) + + sde = self.get_sde(unit_cell, atom_types=initial_composition.A) sde.to(device) - initial_relative_coordinates = map_relative_coordinates_to_unit_cell( - self.initialize(number_of_samples) - ).to(device) y0 = einops.rearrange( - initial_relative_coordinates, "batch natom space -> batch (natom space)" + initial_composition.X, "batch natom space -> batch (natom space)" ) sde_times = torch.linspace( @@ -323,7 +331,7 @@ def sample( ) logger.info("SDE solver Finished.") - if self.record_samples: + if self.record: self.record_sample(sde, ys, sde_times) # only the final sde time (ie, diffusion time t0) is the real sample. @@ -352,8 +360,6 @@ def record_sample(self, sde: SDE, ys: torch.Tensor, sde_times: torch.Tensor): Returns: None """ - self.sample_trajectory_recorder.record_unit_cell(sde.unit_cells) - list_normalized_scores = [] sigmas = [] evaluation_times = [] @@ -362,14 +368,16 @@ def record_sample(self, sde: SDE, ys: torch.Tensor, sde_times: torch.Tensor): sde_times.flip(dims=(0,)), ys.flip(dims=(0,)) ): diffusion_time = sde._get_diffusion_time(sde_time) - sigma = sde._get_exploding_variance_sigma(diffusion_time) + sigma = sde.exploding_variance.get_sigma(diffusion_time) sigmas.append(sigma) evaluation_times.append(diffusion_time) with torch.no_grad(): - normalized_scores = sde.get_sigma_normalized_score( - diffusion_time, flat_relative_coordinates - ) + normalized_scores = sde.get_model_predictions( + diffusion_time, + flat_relative_coordinates, + sde.atom_types, + ).X list_normalized_scores.append(normalized_scores) sigmas = torch.tensor(sigmas) @@ -387,9 +395,11 @@ def record_sample(self, sde: SDE, ys: torch.Tensor, sde_times: torch.Tensor): space=self.spatial_dimension, ) - self.sample_trajectory_recorder.record_sde_solution( - times=evaluation_times, - sigmas=sigmas, - relative_coordinates=record_relative_coordinates, - normalized_scores=record_normalized_scores, - ) + entry = dict(unit_cell=sde.unit_cells, + times=evaluation_times, + sigmas=sigmas, + relative_coordinates=record_relative_coordinates, + normalized_scores=record_normalized_scores + ) + + self.sample_trajectory_recorder.record(key='sde', entry=entry) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/loss/__init__.py b/src/diffusion_for_multi_scale_molecular_dynamics/loss/__init__.py new file mode 100644 index 00000000..ad5b35ac --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/loss/__init__.py @@ -0,0 +1,38 @@ +from diffusion_for_multi_scale_molecular_dynamics.loss.atom_type_loss_calculator import \ + D3PMLossCalculator +from diffusion_for_multi_scale_molecular_dynamics.loss.coordinates_loss_calculator import ( + MSELossCalculator, WeightedMSELossCalculator) +from diffusion_for_multi_scale_molecular_dynamics.loss.lattice_loss_calculator import \ + LatticeLossCalculator +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ + LossParameters +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL + +LOSS_BY_ALGO = dict(mse=MSELossCalculator, weighted_mse=WeightedMSELossCalculator) + + +def create_loss_calculator(loss_parameters: LossParameters) -> AXL: + """Create Loss Calculator. + + This is a factory method to create the loss calculator. + + Args: + loss_parameters : parameters defining the loss. + + Returns: + loss_calculator : the loss calculator for atom types, coordinates, lattice in an AXL namedtuple. + """ + algorithm = loss_parameters.coordinates_algorithm + assert ( + algorithm in LOSS_BY_ALGO.keys() + ), f"Algorithm {algorithm} is not implemented. Possible choices are {LOSS_BY_ALGO.keys()}" + + coordinates_loss = LOSS_BY_ALGO[algorithm](loss_parameters) + lattice_loss = LatticeLossCalculator # TODO placeholder + atom_loss = D3PMLossCalculator(loss_parameters) + + return AXL( + A=atom_loss, + X=coordinates_loss, + L=lattice_loss, + ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/loss/atom_type_loss_calculator.py b/src/diffusion_for_multi_scale_molecular_dynamics/loss/atom_type_loss_calculator.py new file mode 100644 index 00000000..d4687de5 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/loss/atom_type_loss_calculator.py @@ -0,0 +1,263 @@ +import torch + +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ + LossParameters +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + get_probability_at_previous_time_step + + +class D3PMLossCalculator(torch.nn.Module): + """Class to calculate the discrete diffusion loss.""" + + def __init__(self, loss_parameters: LossParameters): + """Initialize method.""" + super().__init__() + # weight of the cross-entropy component + self.ce_weight = loss_parameters.atom_types_ce_weight + self.eps = loss_parameters.atom_types_eps + + def cross_entropy_loss_term(self, + predicted_logits: torch.Tensor, + one_hot_real_atom_types: torch.Tensor) -> torch.Tensor: + r"""Compute the cross entropy component of the loss. + + This corresponds to this: + + .. math:: + + -\log \tilde p_\theta(a_{0} | a_{t}) + + Args: + predicted_logits: output of the score network estimating class logits + :math:`\tilde p(a_0 | a_t)` of dimension [batch_size, number_of_atoms, num_classes] where num_classes + includes the MASK token + one_hot_real_atom_types: real atom types :math:`a_0` in one-hot format of dimension + [batch_size, number_of_atoms, num_type_atoms, num_classes] + + Returns: + cross_entropy: the negative log-likelihood of the predictions for the actual class, of dimension + [batch_size, number_of_atoms, num_classes]. + """ + nll_term = -torch.nn.functional.log_softmax(predicted_logits, dim=-1) + # The last logit is -inf, which leads to p(a_{0} = MASK) = 0. This diverges and must be squashed. + nll_term[..., -1] = 0.0 + + # We must restrict the value of a0 to its actual value, which is done by multiplying by delta_{a0, actual_a0} + cross_entropy = one_hot_real_atom_types * nll_term + + return cross_entropy + + def variational_bound_loss_term( + self, + predicted_logits: torch.Tensor, + one_hot_real_atom_types: torch.Tensor, + one_hot_noisy_atom_types: torch.Tensor, + q_matrices: torch.Tensor, + q_bar_matrices: torch.Tensor, + q_bar_tm1_matrices: torch.Tensor, + time_indices: torch.Tensor + ) -> torch.Tensor: + r"""Compute the variational bound part of the loss. + + This corresponds to this: + + .. math:: + + t == 1 : -log(p_\theta(a_{0} | a_{1}) + t != 1 : D_{KL}[q(a_{t-1} | a_t, a_0) || p_\theta(a_{t-1} | a_{t})] + + Args: + predicted_logits: output of the score network estimating class logits + :math:`\tilde p(a_0 | a_t)` of dimension [batch_size, number_of_atoms, num_classes] where num_classes + includes the MASK token + one_hot_real_atom_types: real atom types :math:`a_0` in one-hot format of dimension + [batch_size, number_of_atoms, num_type_atoms, num_classes] + one_hot_noisy_atom_types: noisy atom types :math:`a_t` in one-hot format of dimension + [batch_size, number_of_atoms, num_type_atoms, num_classes] + q_matrices: one-step transition matrices :math:`Q_t` of dimension + [batch_size, number_of_atoms, num_type_atoms, num_classes] + q_bar_matrices: one-shot transition matrices :math:`\bar{Q}_t` of dimension + [batch_size, number_of_atoms, num_type_atoms, num_classes] + q_bar_tm1_matrices: one-shot transition matrices at previous step :math:`\bar{Q}_{t-1}` of dimension + [batch_size, number_of_atoms, num_type_atoms, num_classes]. An identity matrix is used for t=0. + time_indices: time indices sampled of dimension [batch_size] + + Returns: + torch.Tensor: unreduced variational bound loss of dimension [batch_size, number_of_atoms, num_classes] + """ + # The posterior probabilities + q_atm1_given_at_and_a0 = self.get_q_atm1_given_at_and_a0( + one_hot_a0=one_hot_real_atom_types, + one_hot_at=one_hot_noisy_atom_types, + q_matrices=q_matrices, + q_bar_matrices=q_bar_matrices, + q_bar_tm1_matrices=q_bar_tm1_matrices, + small_epsilon=self.eps, + ) + + # The predicted probabilities + p_atm1_given_at = self.get_p_atm1_given_at( + predicted_logits=predicted_logits, + one_hot_at=one_hot_noisy_atom_types, + q_matrices=q_matrices, + q_bar_matrices=q_bar_matrices, + q_bar_tm1_matrices=q_bar_tm1_matrices, + small_epsilon=self.eps, + ) + + # get the KL divergence between posterior and predicted probabilities + # do not reduce (average) yet as we will replace the samples with t=1 with a NLL loss + # input of kl_div should be log-probabilities. + # time_indices.view(-1, 1, 1) == 0, + + log_p = torch.log(p_atm1_given_at.clip(min=self.eps)) + kl_loss = torch.nn.functional.kl_div( + log_p, q_atm1_given_at_and_a0, reduction="none" + ) + + variational_bound_loss = kl_loss + + first_time_step_mask = time_indices == 0 + # We must restrict the value of a0 to its actual value, which is done by multiplying by delta_{a0, actual_a0} + variational_bound_loss[first_time_step_mask] = (-log_p[first_time_step_mask] + * one_hot_real_atom_types[first_time_step_mask]) + + return variational_bound_loss + + @classmethod + def get_q_atm1_given_at_and_a0( + cls, + one_hot_a0: torch.Tensor, + one_hot_at: torch.Tensor, + q_matrices: torch.Tensor, + q_bar_matrices: torch.Tensor, + q_bar_tm1_matrices: torch.Tensor, + small_epsilon: float, + ) -> torch.Tensor: + r"""Compute q(a_{t-1} | a_t, a_0). + + Args: + one_hot_a0: a one-hot representation of a class type at time step zero, as a tensor with dimension + [batch_size, number_of_atoms, num_classes] + one_hot_at: a one-hot representation of a class type at current time step, as a tensor with dimension + [batch_size, number_of_atoms, num_classes] + q_matrices: transition matrices at current time step :math:`{Q}_{t}` of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + q_bar_matrices: one-shot transition matrices at current time step :math:`\bar{Q}_{t}` of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + q_bar_tm1_matrices: one-shot transition matrices at previous time step :math:`\bar{Q}_{t-1}` of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + small_epsilon: minimum value for the denominator, to avoid division by zero. + + Returns: + probabilities over classes, of dimension [batch_size, num_classes, num_classes] + """ + q_atm1_given_at_and_0 = get_probability_at_previous_time_step( + probability_at_zeroth_timestep=one_hot_a0, + one_hot_probability_at_current_timestep=one_hot_at, + q_matrices=q_matrices, + q_bar_matrices=q_bar_matrices, + q_bar_tm1_matrices=q_bar_tm1_matrices, + small_epsilon=small_epsilon, + probability_at_zeroth_timestep_are_logits=False, + ) + return q_atm1_given_at_and_0 + + @classmethod + def get_p_atm1_given_at( + cls, + predicted_logits: torch.Tensor, + one_hot_at: torch.Tensor, + q_matrices: torch.Tensor, + q_bar_matrices: torch.Tensor, + q_bar_tm1_matrices: torch.Tensor, + small_epsilon: float, + ) -> torch.Tensor: + r"""Compute p(a_{t-1} | a_t). + + .. math:: + p_\theta(a_{t-1} | a_t) \propto \sum_{\tilde{a}_0} q(a_{t-1}, a_t | \tilde{a}_0)p_\theta(\tilde{a}_0, a_t) + + Args: + predicted_logits: output of the score network estimating an unnormalized + :math:`p(a_0 | a_t)` of dimension [batch_size, number_of_atoms, num_type_atoms] where num_type_atoms + includes the MASK token + one_hot_at: a one-hot representation of a class type at current time step, as a tensor with dimension + [batch_size, number_of_atoms, num_classes] + q_matrices: transition matrices at current time step :math:`{Q}_{t}` of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + q_bar_matrices: one-shot transition matrices at current time step :math:`\bar{Q}_{t}` of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + q_bar_tm1_matrices: one-shot transition matrices at previous time step :math:`\bar{Q}_{t-1}` of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + small_epsilon: minimum value for the denominator, to avoid division by zero. + + Returns: + one-step transition normalized probabilities of dimension [batch_size, num_classes, num_classes] + """ + p_atm1_at = get_probability_at_previous_time_step( + probability_at_zeroth_timestep=predicted_logits, + one_hot_probability_at_current_timestep=one_hot_at, + q_matrices=q_matrices, + q_bar_matrices=q_bar_matrices, + q_bar_tm1_matrices=q_bar_tm1_matrices, + small_epsilon=small_epsilon, + probability_at_zeroth_timestep_are_logits=True, + ) + return p_atm1_at + + def calculate_unreduced_loss( + self, + predicted_logits: torch.Tensor, + one_hot_real_atom_types: torch.Tensor, + one_hot_noisy_atom_types: torch.Tensor, + time_indices: torch.Tensor, + q_matrices: torch.Tensor, + q_bar_matrices: torch.Tensor, + q_bar_tm1_matrices: torch.Tensor, + ) -> torch.Tensor: + r"""Calculate unreduced loss. + + The loss is given by: + + .. math:: + + L_a = E_{a_0 ~ p_\textrm{data}} [ - E_{a_1 ~ p_{t=1| 0}} log p_\theta(a_0 | a_1) + + \sum_{t=2}^T E_{a_t ~ p_{t|0}} [ D_{KL}[q(a_{t-1} | a_t, a_0) || p_theta(a_{t-1} | a_{t}] ] + + \lambda_CE \sum_{t=1}^T -log p_\theta(a_0 | a_t)] + + Args: + predicted_logits: output of the score network logits for :math:`p(a_0 | a_t)` + of dimension [batch_size, number_of_atoms, num_classes] where num_classes includes the MASK token. + one_hot_real_atom_types: real atom types :math:`a_0` as one-hot vectors + of dimension [batch_size, number_of_atoms, num_type_atoms] + one_hot_noisy_atom_types: noisy atom types :math:`a_t` as one-hot vectors + of dimension [batch_size, number_of_atoms, num_type_atoms] + time_indices: time indices sampled of dimension [batch_size] + q_matrices: one-step transition matrices :math:`Q_t` of dimension + [batch_size, number_of_atoms, num_type_atoms, num_type_atoms] + q_bar_matrices: one-shot transition matrices :math:`\bar{Q}_t` of dimension + [batch_size, number_of_atoms, num_type_atoms, num_type_atoms] + q_bar_tm1_matrices: one-shot transition matrices at previous step :math:`\bar{Q}_{t-1}` of dimension + [batch_size, number_of_atoms, num_type_atoms, num_type_atoms]. An identity matrix is used for t=0 + + Returns: + unreduced_loss: a tensor of shape [batch_size, number_of_atoms, num_type_atoms]. It's mean is the loss. + """ + # if t == 1 (0 for python indexing convention), use the NLL term, otherwise use the KL term + vb_term = self.variational_bound_loss_term( + predicted_logits, + one_hot_real_atom_types, + one_hot_noisy_atom_types, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices + ) + + # -log tilde_p_\theta(a_0 | a_t) + ce_term = self.cross_entropy_loss_term(predicted_logits, one_hot_real_atom_types) + + d3pm_loss = vb_term + self.ce_weight * ce_term + + return d3pm_loss diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/loss.py b/src/diffusion_for_multi_scale_molecular_dynamics/loss/coordinates_loss_calculator.py similarity index 62% rename from src/diffusion_for_multi_scale_molecular_dynamics/models/loss.py rename to src/diffusion_for_multi_scale_molecular_dynamics/loss/coordinates_loss_calculator.py index 90daaf11..c0cdebbb 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/loss.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/loss/coordinates_loss_calculator.py @@ -1,39 +1,10 @@ -from dataclasses import dataclass -from typing import Any, Dict - import torch -from diffusion_for_multi_scale_molecular_dynamics.utils.configuration_parsing import \ - create_parameters_from_configuration_dictionary - - -@dataclass(kw_only=True) -class LossParameters: - """Specific Hyper-parameters for the loss function.""" - - algorithm: str - - -@dataclass(kw_only=True) -class MSELossParameters(LossParameters): - """Specific Hyper-parameters for the MSE loss function.""" +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import ( + LossParameters, MSELossParameters, WeightedMSELossParameters) - algorithm: str = "mse" - -@dataclass(kw_only=True) -class WeightedMSELossParameters(LossParameters): - """Specific Hyper-parameters for the weighted MSE loss function.""" - - algorithm: str = "weighted_mse" - # The default values are chosen to lead to a flat loss curve vs. sigma, based on preliminary experiments. - # These parameters have no effect if the algorithm is 'mse'. - # The default parameters are chosen such that weights(sigma=0.5) \sim 10^3 - sigma0: float = 0.2 - exponent: float = 23.0259 # ~ 10 ln(10) - - -class LossCalculator(torch.nn.Module): +class CoordinatesLossCalculator(torch.nn.Module): """Class to calculate the loss.""" def __init__(self, loss_parameters: LossParameters): @@ -63,7 +34,7 @@ def calculate_unreduced_loss( raise NotImplementedError -class MSELossCalculator(LossCalculator): +class MSELossCalculator(CoordinatesLossCalculator): """Class to calculate the MSE loss.""" def __init__(self, loss_parameters: MSELossParameters): @@ -147,50 +118,3 @@ def calculate_unreduced_loss( unreduced_loss = unreduced_mse_loss * weights return unreduced_loss - - -LOSS_PARAMETERS_BY_ALGO = dict( - mse=MSELossParameters, weighted_mse=WeightedMSELossParameters -) -LOSS_BY_ALGO = dict(mse=MSELossCalculator, weighted_mse=WeightedMSELossCalculator) - - -def create_loss_parameters(model_dictionary: Dict[str, Any]) -> LossParameters: - """Create loss parameters. - - Extract the relevant information from the general configuration dictionary. - - Args: - model_dictionary : model configuration dictionary. - - Returns: - loss_parameters: the loss parameters. - """ - default_dict = dict(algorithm="mse") - loss_config_dictionary = model_dictionary.get("loss", default_dict) - - loss_parameters = create_parameters_from_configuration_dictionary( - configuration=loss_config_dictionary, - identifier="algorithm", - options=LOSS_PARAMETERS_BY_ALGO, - ) - return loss_parameters - - -def create_loss_calculator(loss_parameters: LossParameters) -> LossCalculator: - """Create Loss Calculator. - - This is a factory method to create the loss calculator. - - Args: - loss_parameters : parameters defining the loss. - - Returns: - loss_calculator : the loss calculator. - """ - algorithm = loss_parameters.algorithm - assert ( - algorithm in LOSS_BY_ALGO.keys() - ), f"Algorithm {algorithm} is not implemented. Possible choices are {LOSS_BY_ALGO.keys()}" - - return LOSS_BY_ALGO[algorithm](loss_parameters) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/loss/lattice_loss_calculator.py b/src/diffusion_for_multi_scale_molecular_dynamics/loss/lattice_loss_calculator.py new file mode 100644 index 00000000..6e24984f --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/loss/lattice_loss_calculator.py @@ -0,0 +1,16 @@ +import torch + + +class LatticeLossCalculator(torch.nn.Module): + """Class to calculate the loss for the lattice vectors. + + Placeholder for now. + """ + + def __init__(self): + """Placeholder for now.""" + super().__init__() + + def calculate_unreduced_loss(self, *args): + """Placeholder for now.""" + return 0 diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/loss/loss_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/loss/loss_parameters.py new file mode 100644 index 00000000..1224aa21 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/loss/loss_parameters.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass +from typing import Any, Dict + +from diffusion_for_multi_scale_molecular_dynamics.utils.configuration_parsing import \ + create_parameters_from_configuration_dictionary + + +@dataclass(kw_only=True) +class LossParameters: + """Specific Hyper-parameters for the loss function.""" + atom_types_lambda_weight: float = 1.0 # weighting prefactor for atom-type loss. + relative_coordinates_lambda_weight: float = 1.0 # weighting prefactor for the coordinates loss. + lattice_lambda_weight: float = 1.0 # weighting prefactor for the lattice loss. + + coordinates_algorithm: str + atom_types_ce_weight: float = 0.001 # default value in google D3PM repo + atom_types_eps: float = 1e-8 # avoid divisions by zero + # https://github.com/google-research/google-research/blob/master/d3pm/images/config.py + + +@dataclass(kw_only=True) +class MSELossParameters(LossParameters): + """Specific Hyper-parameters for the MSE loss function.""" + + coordinates_algorithm: str = "mse" + + +@dataclass(kw_only=True) +class WeightedMSELossParameters(LossParameters): + """Specific Hyper-parameters for the weighted MSE loss function.""" + + coordinates_algorithm: str = "weighted_mse" + # The default values are chosen to lead to a flat loss curve vs. sigma, based on preliminary experiments. + # These parameters have no effect if the algorithm is 'mse'. + # The default parameters are chosen such that weights(sigma=0.5) \sim 10^3 + sigma0: float = 0.2 + exponent: float = 23.0259 # ~ 10 ln(10) + + +def create_loss_parameters(model_dictionary: Dict[str, Any]) -> LossParameters: + """Create loss parameters. + + Extract the relevant information from the general configuration dictionary. + + Args: + model_dictionary : model configuration dictionary. + + Returns: + loss_parameters: the loss parameters. + """ + default_dict = dict(algorithm="mse") + loss_config_dictionary = model_dictionary.get("loss", default_dict) + + loss_parameters = create_parameters_from_configuration_dictionary( + configuration=loss_config_dictionary, + identifier="coordinates_algorithm", + options=LOSS_PARAMETERS_BY_ALGO, + ) + return loss_parameters + + +LOSS_PARAMETERS_BY_ALGO = dict( + mse=MSELossParameters, weighted_mse=WeightedMSELossParameters +) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/position_diffusion_lightning_model.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py similarity index 55% rename from src/diffusion_for_multi_scale_molecular_dynamics/models/position_diffusion_lightning_model.py rename to src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py index 143b2999..9a5cece5 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/position_diffusion_lightning_model.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/axl_diffusion_lightning_model.py @@ -7,10 +7,12 @@ from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \ instantiate_generator +from diffusion_for_multi_scale_molecular_dynamics.loss import \ + create_loss_calculator +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ + LossParameters from diffusion_for_multi_scale_molecular_dynamics.metrics.kolmogorov_smirnov_metrics import \ KolmogorovSmirnovMetrics -from diffusion_for_multi_scale_molecular_dynamics.models.loss import ( - LossParameters, create_loss_calculator) from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import ( OptimizerParameters, load_optimizer) from diffusion_for_multi_scale_molecular_dynamics.models.scheduler import ( @@ -20,51 +22,65 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network_factory import \ create_score_network from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, CARTESIAN_POSITIONS, NOISE, NOISY_RELATIVE_COORDINATES, - RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.oracle.energies import \ - compute_oracle_energies -from diffusion_for_multi_scale_molecular_dynamics.samplers.noisy_relative_coordinates_sampler import \ - NoisyRelativeCoordinatesSampler -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) -from diffusion_for_multi_scale_molecular_dynamics.samples.diffusion_sampling_parameters import \ - DiffusionSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.samples.sampling import \ + ATOM_TYPES, AXL, AXL_COMPOSITION, AXL_NAME_DICT, CARTESIAN_FORCES, + CARTESIAN_POSITIONS, NOISE, NOISY_AXL_COMPOSITION, RELATIVE_COORDINATES, + TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from diffusion_for_multi_scale_molecular_dynamics.noisers.atom_types_noiser import \ + AtomTypesNoiser +from diffusion_for_multi_scale_molecular_dynamics.noisers.lattice_noiser import \ + LatticeNoiser +from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ + RelativeCoordinatesNoiser +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle import \ + OracleParameters +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle_factory import \ + create_energy_oracle +from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling import \ create_batch_of_samples +from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling_parameters import \ + DiffusionSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.score.wrapped_gaussian_score import \ get_sigma_normalized_score from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( get_positions_from_coordinates, map_relative_coordinates_to_unit_cell) +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot from diffusion_for_multi_scale_molecular_dynamics.utils.structure_utils import \ compute_distances_in_batch -from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ - broadcast_batch_tensor_to_all_dimensions +from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import ( + broadcast_batch_matrix_tensor_to_all_dimensions, + broadcast_batch_tensor_to_all_dimensions) logger = logging.getLogger(__name__) @dataclass(kw_only=True) -class PositionDiffusionParameters: - """Position Diffusion parameters.""" +class AXLDiffusionParameters: + """AXL (atom, relative coordinates, lattice) Diffusion parameters.""" score_network_parameters: ScoreNetworkParameters loss_parameters: LossParameters optimizer_parameters: OptimizerParameters scheduler_parameters: Optional[SchedulerParameters] = None noise_parameters: NoiseParameters - # convergence parameter for the Ewald-like sum of the perturbation kernel. + # convergence parameter for the Ewald-like sum of the perturbation kernel for coordinates. kmax_target_score: int = 4 diffusion_sampling_parameters: Optional[DiffusionSamplingParameters] = None + oracle_parameters: Optional[OracleParameters] = None -class PositionDiffusionLightningModel(pl.LightningModule): - """Position Diffusion Lightning Model. +class AXLDiffusionLightningModel(pl.LightningModule): + """AXL Diffusion Lightning Model. - This lightning model can train a score network predict the noise for relative coordinates. + This lightning model can train a score network to predict the noise for relative coordinates, atom types and lattice + vectors. """ - def __init__(self, hyper_params: PositionDiffusionParameters): + def __init__(self, hyper_params: AXLDiffusionParameters): """Init method. This initializes the class. @@ -72,23 +88,40 @@ def __init__(self, hyper_params: PositionDiffusionParameters): super().__init__() self.hyper_params = hyper_params + self.num_atom_types = hyper_params.score_network_parameters.num_atom_types self.save_hyperparameters( logger=False ) # It is not the responsibility of this class to log its parameters. - # we will model sigma x score - self.sigma_normalized_score_network = create_score_network( - hyper_params.score_network_parameters - ) + # the score network is expected to produce an output as an AXL namedtuple: + # atom: unnormalized estimate of p(a_0 | a_t) + # relative coordinates: estimate of \sigma \nabla_{x_t} p_{t|0}(x_t | x_0) + # lattices: TODO + self.axl_network = create_score_network(hyper_params.score_network_parameters) + # loss is an AXL object with one loss for each element (atom type, coordinate, lattice) self.loss_calculator = create_loss_calculator(hyper_params.loss_parameters) - self.noisy_relative_coordinates_sampler = NoisyRelativeCoordinatesSampler() - self.variance_sampler = ExplodingVarianceSampler(hyper_params.noise_parameters) + self.loss_weights = AXL(A=hyper_params.loss_parameters.atom_types_lambda_weight, + X=hyper_params.loss_parameters.relative_coordinates_lambda_weight, + L=hyper_params.loss_parameters.lattice_lambda_weight) + + # noisy samplers for atom types, coordinates and lattice vectors + self.noisers = AXL( + A=AtomTypesNoiser(), + X=RelativeCoordinatesNoiser(), + L=LatticeNoiser(), + ) + + self.noise_scheduler = NoiseScheduler( + hyper_params.noise_parameters, + num_classes=self.num_atom_types + 1, # add 1 for the MASK class + ) self.generator = None self.structure_ks_metric = None self.energy_ks_metric = None + self.oracle = None self.draw_samples = hyper_params.diffusion_sampling_parameters is not None if self.draw_samples: @@ -99,6 +132,9 @@ def __init__(self, hyper_params: PositionDiffusionParameters): self.structure_ks_metric = KolmogorovSmirnovMetrics() if self.metrics_parameters.compute_energies: self.energy_ks_metric = KolmogorovSmirnovMetrics() + assert self.hyper_params.oracle_parameters is not None, \ + "Energies cannot be computed without a configured energy oracle." + self.oracle = create_energy_oracle(self.hyper_params.oracle_parameters) def configure_optimizers(self): """Returns the combination of optimizer(s) and learning rate scheduler(s) to train with. @@ -145,32 +181,50 @@ def _generic_step( batch_idx: int, no_conditional: bool = False, ) -> Any: - """Generic step. + r"""Generic step. This "generic step" computes the loss for any of the possible lightning "steps". - The loss is defined as: - L = 1 / T int_0^T dt lambda(t) E_{x0 ~ p_data} E_{xt~ p_{t| 0}} - [|S_theta(xt, t) - nabla_{xt} log p_{t | 0} (xt | x0)|^2] + The loss is defined as a sum of 3 components: - Where - T : time range of the noising process - S_theta : score network - p_{t| 0} : perturbation kernel - nabla log p : the target score - lambda(t) : is arbitrary, but chosen for convenience. + .. math:: + L = L_x + L_a + L_L - In this implementation, we choose lambda(t) = sigma(t)^2 ( a standard choice from the literature), such + where :math:`L_x` is the loss for the coordinate diffusion, :math:`L_a` for the atom type diffusion and + :math:`L_L` for the lattice. + + The loss for the coordinate diffusion is defined as: + + .. math:: + L_x = 1 / T \int_0^T dt \lambda(t) E_{x0 ~ p_data} E_{xt~ p_{t| 0}} + [|S_\theta(xt, t) - \nabla_{xt} \log p_{t | 0} (xt | x0)|^2] + + Where + :math:`T` : time range of the noising process + :math:`S_\theta` : score network + :math:`p_{t|0}` : perturbation kernel + :math:`\nabla \log p` : the target score + :math:`\lambda(t)` : is arbitrary, but chosen for convenience. + + In this implementation, we choose :math:`\lambda(t) = \sigma(t)^2` (a standard choice from the literature), such that the score network and the target scores that are used are actually "sigma normalized" versions, ie, pre-multiplied by sigma. + For the atom type diffusion, the loss is defined as: + + .. math:: + L_a = E_{a_0 ~ p_\textrm{data}} [ \sum_{t=2}^T E_{a_t ~ p_{t|0} + [D_{KL}[q(a_{t-1} | a_t, a_0) || p_theta(a_{t-1} | a_{t}) - \lambda_CE log p_\theta(a_0 | a_t)] + - E_{a_1 ~ p_{t=1|0}} log p_\theta(a_0 | a_1) ] + The loss that is computed is a Monte Carlo estimate of L, where we sample a mini-batch of relative coordinates - configurations {x0}; each of these configurations is noised with a random t value, with corresponding - {sigma(t)} and {xt}. + configurations {x0} and atom types {a_0}; each of these configurations is noised with a random t value, + with corresponding {sigma(t)}, {xt}, {beta(t)} and {a(t)}. Note the :math:`beta(t)` is used to compute the true + posterior :math:`q(a_{t-1} | a_t, a_0)` and :math:`p_\theta(a_{t-1} | a_t)` in the atom type loss. Args: batch : a dictionary that should contain a data sample. - batch_idx : index of the batch + batch_idx : index of the batch no_conditional (optional): if True, do not use the conditional option of the forward. Used for validation. Returns: @@ -180,32 +234,78 @@ def _generic_step( assert ( RELATIVE_COORDINATES in batch ), f"The field '{RELATIVE_COORDINATES}' is missing from the input." + + assert ( + ATOM_TYPES in batch + ), f"The field '{ATOM_TYPES}' is missing from the input." + x0 = batch[RELATIVE_COORDINATES] shape = x0.shape assert len(shape) == 3, ( f"the shape of the RELATIVE_COORDINATES array should be [batch_size, number_of_atoms, spatial_dimensions]. " f"Got shape = {shape}." ) + + a0 = batch[ATOM_TYPES] batch_size = self._get_batch_size(batch) + atom_shape = a0.shape + assert len(atom_shape) == 2, ( + f"the shape of the ATOM_TYPES array should be [batch_size, number_of_atoms]. " + f"Got shape = {atom_shape}" + ) - noise_sample = self.variance_sampler.get_random_noise_sample(batch_size) + l0 = batch[ + "box" + ] # should be batch[UNIT_CELL] - see later comment with batch['box'] + # TODO assert on shape - # noise_sample.sigma has dimension [batch_size]. Broadcast these sigma values to be - # of shape [batch_size, number_of_atoms, spatial_dimension], which can be interpreted - # as [batch_size, (configuration)]. All the sigma values must be the same for a given configuration. + noise_sample = self.noise_scheduler.get_random_noise_sample(batch_size) + + # noise_sample.sigma has dimension [batch_size]. Broadcast these values to be of shape + # [batch_size, number_of_atoms, spatial_dimension] , which can be interpreted as + # [batch_size, (configuration)]. All the sigma values must be the same for a given configuration. sigmas = broadcast_batch_tensor_to_all_dimensions( batch_values=noise_sample.sigma, final_shape=shape ) + # we can now get noisy coordinates + xt = self.noisers.X.get_noisy_relative_coordinates_sample(x0, sigmas) + + # to get noisy atom types, we need to broadcast the transition matrices q, q_bar and q_bar_tm1 from size + # [batch_size, num_atom_types, num_atom_types] to [batch_size, number_of_atoms, num_atom_types, num_atom_types]. + # All the matrices must be the same for all atoms in a given configuration. + q_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=noise_sample.q_matrix, final_shape=atom_shape + ) + q_bar_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=noise_sample.q_bar_matrix, final_shape=atom_shape + ) - xt = self.noisy_relative_coordinates_sampler.get_noisy_relative_coordinates_sample( - x0, sigmas + q_bar_tm1_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=noise_sample.q_bar_tm1_matrix, final_shape=atom_shape ) - # The target is nabla log p_{t|0} (xt | x0): it is NOT the "score", but rather a "conditional" (on x0) score. - target_normalized_conditional_scores = self._get_target_normalized_score( - xt, x0, sigmas + # we also need the atom types to be one-hot vector and not a class index + a0_onehot = class_index_to_onehot(a0, self.num_atom_types + 1) + + at = self.noisers.A.get_noisy_atom_types_sample(a0_onehot, q_bar_matrices) + at_onehot = class_index_to_onehot(at, self.num_atom_types + 1) + + # TODO do the same for the lattice vectors + lt = self.noisers.L.get_noisy_lattice_vectors(l0) + + noisy_composition = AXL(A=at, X=xt, L=lt) # not one-hot + + original_composition = AXL(A=a0, X=x0, L=l0) + + # Get the loss targets + # Coordinates: The target is :math:`sigma(t) \nabla log p_{t|0} (xt | x0)` + # it is NOT the "score", but rather a "conditional" (on x0) score. + target_coordinates_normalized_conditional_scores = ( + self._get_coordinates_target_normalized_score(xt, x0, sigmas) ) + # for the atom types, the loss is constructed from the Q and Qbar matrices + # TODO get unit_cell from the noisy version and not a kwarg in batch (at least replace with namespace name) unit_cell = torch.diag_embed( batch["box"] ) # from (batch, spatial_dim) to (batch, spatial_dim, spatial_dim) @@ -213,40 +313,84 @@ def _generic_step( forces = batch[CARTESIAN_FORCES] augmented_batch = { - NOISY_RELATIVE_COORDINATES: xt, + NOISY_AXL_COMPOSITION: noisy_composition, TIME: noise_sample.time.reshape(-1, 1), NOISE: noise_sample.sigma.reshape(-1, 1), - UNIT_CELL: unit_cell, + UNIT_CELL: unit_cell, # TODO remove and take from AXL instead CARTESIAN_FORCES: forces, } use_conditional = None if no_conditional is False else False - predicted_normalized_scores = self.sigma_normalized_score_network( + model_predictions = self.axl_network( augmented_batch, conditional=use_conditional ) - - unreduced_loss = self.loss_calculator.calculate_unreduced_loss( - predicted_normalized_scores, - target_normalized_conditional_scores, + # this output is expected to be an AXL object + # X score network output: an estimate of the sigma normalized score for the coordinates, + # A score network output: an unnormalized estimate of p(a_0 | a_t) for the atom types + # TODO something for the lattice + + unreduced_loss_coordinates = self.loss_calculator.X.calculate_unreduced_loss( + model_predictions.X, + target_coordinates_normalized_conditional_scores, sigmas, ) - loss = torch.mean(unreduced_loss) + + unreduced_loss_atom_types = self.loss_calculator.A.calculate_unreduced_loss( + predicted_logits=model_predictions.A, + one_hot_real_atom_types=a0_onehot, + one_hot_noisy_atom_types=at_onehot, + time_indices=noise_sample.indices, + q_matrices=q_matrices, + q_bar_matrices=q_bar_matrices, + q_bar_tm1_matrices=q_bar_tm1_matrices, + ) + + # TODO placeholder - returns zero + unreduced_loss_lattice = self.loss_calculator.L.calculate_unreduced_loss( + model_predictions.L + ) + + aggregated_weighted_loss = ( + self.loss_weights.X * unreduced_loss_coordinates.mean( + dim=-1 + ) # batch, num_atoms, spatial_dimension + + self.loss_weights.L * unreduced_loss_lattice + + self.loss_weights.A * unreduced_loss_atom_types.mean(dim=-1) # batch, num_atoms, num_atom_types + ) + + weighted_loss = torch.mean(aggregated_weighted_loss) + + unreduced_loss = AXL( + A=unreduced_loss_atom_types.detach(), + X=unreduced_loss_coordinates.detach(), + L=torch.zeros_like( + unreduced_loss_coordinates + ).detach(), # TODO use unreduced_loss_lattice.detach(), + ) + + model_predictions_detached = AXL( + A=model_predictions.A.detach(), + X=model_predictions.X.detach(), + L=model_predictions.L.detach(), + ) output = dict( - unreduced_loss=unreduced_loss.detach(), - loss=loss, + unreduced_loss=unreduced_loss, + loss=weighted_loss, sigmas=sigmas, - predicted_normalized_scores=predicted_normalized_scores.detach(), - target_normalized_conditional_scores=target_normalized_conditional_scores, + model_predictions=model_predictions_detached, + target_coordinates_normalized_conditional_scores=target_coordinates_normalized_conditional_scores, ) - output[RELATIVE_COORDINATES] = x0 - output[NOISY_RELATIVE_COORDINATES] = augmented_batch[NOISY_RELATIVE_COORDINATES] + output[AXL_COMPOSITION] = original_composition + output[NOISY_AXL_COMPOSITION] = noisy_composition output[TIME] = augmented_batch[TIME] - output[UNIT_CELL] = augmented_batch[UNIT_CELL] + output[UNIT_CELL] = augmented_batch[ + UNIT_CELL + ] # TODO remove and use AXL instead return output - def _get_target_normalized_score( + def _get_coordinates_target_normalized_score( self, noisy_relative_coordinates: torch.Tensor, real_relative_coordinates: torch.Tensor, @@ -296,6 +440,15 @@ def training_step(self, batch, batch_idx): on_step=False, on_epoch=True, ) + + for axl_field, axl_name in AXL_NAME_DICT.items(): + self.log( + f"train_epoch_{axl_name}_loss", + getattr(output["unreduced_loss"], axl_field).mean(), + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) return output def validation_step(self, batch, batch_idx): @@ -314,6 +467,15 @@ def validation_step(self, batch, batch_idx): prog_bar=True, ) + for axl_field, axl_name in AXL_NAME_DICT.items(): + self.log( + f"validation_epoch_{axl_name}_loss", + getattr(output["unreduced_loss"], axl_field).mean(), + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) + if not self.draw_samples: return output @@ -322,9 +484,9 @@ def validation_step(self, batch, batch_idx): self.energy_ks_metric.register_reference_samples(reference_energies.cpu()) if self.draw_samples and self.metrics_parameters.compute_structure_factor: - basis_vectors = torch.diag_embed(batch["box"]) + basis_vectors = torch.diag_embed(batch["box"]) # TODO replace with AXL L cartesian_positions = get_positions_from_coordinates( - relative_coordinates=batch[RELATIVE_COORDINATES], + relative_coordinates=output[AXL_COMPOSITION].X, basis_vectors=basis_vectors, ) @@ -350,10 +512,20 @@ def test_step(self, batch, batch_idx): "test_epoch_loss", loss, batch_size=batch_size, on_step=False, on_epoch=True ) + for axl_field, axl_name in AXL_NAME_DICT.items(): + self.log( + f"test_epoch_{axl_name}_loss", + getattr(output["unreduced_loss"], axl_field).mean(), + batch_size=batch_size, + on_step=False, + on_epoch=True, + ) + return output def generate_samples(self): """Generate a batch of samples.""" + # TODO add atom types generation assert ( self.hyper_params.diffusion_sampling_parameters is not None ), "sampling parameters must be provided to create a generator." @@ -362,7 +534,7 @@ def generate_samples(self): self.generator = instantiate_generator( sampling_parameters=self.hyper_params.diffusion_sampling_parameters.sampling_parameters, noise_parameters=self.hyper_params.diffusion_sampling_parameters.noise_parameters, - sigma_normalized_score_network=self.sigma_normalized_score_network, + axl_network=self.axl_network, # TODO use A and L too ) logger.info(f"Generator type : {type(self.generator)}") @@ -382,11 +554,11 @@ def on_validation_epoch_end(self) -> None: return logger.info(" - Drawing samples at the end of the validation epoch.") - samples_batch = self.generate_samples() + samples_batch = self.generate_samples() # TODO generate atom types too if self.draw_samples and self.metrics_parameters.compute_energies: logger.info(" * Computing sample energies") - sample_energies = compute_oracle_energies(samples_batch) + sample_energies = self.oracle.compute_oracle_energies(samples_batch) logger.info(" * Registering sample energies") self.energy_ks_metric.register_predicted_samples(sample_energies.cpu()) @@ -409,7 +581,9 @@ def on_validation_epoch_end(self) -> None: if self.draw_samples and self.metrics_parameters.compute_structure_factor: logger.info(" * Computing sample distances") sample_distances = compute_distances_in_batch( - cartesian_positions=samples_batch[CARTESIAN_POSITIONS], + cartesian_positions=samples_batch[ + CARTESIAN_POSITIONS + ], # TODO replace with AXL unit_cell=samples_batch[UNIT_CELL], max_distance=self.metrics_parameters.structure_factor_max_distance, ) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/diffusion_mace.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/diffusion_mace.py index 2af13346..3c4814d7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/diffusion_mace.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/diffusion_mace.py @@ -11,7 +11,10 @@ from diffusion_for_multi_scale_molecular_dynamics.models.mace_utils import ( get_adj_matrix, reshape_from_e3nn_to_mace, reshape_from_mace_to_e3nn) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_CARTESIAN_POSITIONS, UNIT_CELL) + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, + NOISY_CARTESIAN_POSITIONS, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot class LinearVectorReadoutBlock(torch.nn.Module): @@ -27,14 +30,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) +class LinearClassificationReadoutBlock(torch.nn.Module): + """Linear readout for scalar representation.""" + + def __init__(self, irreps_in: o3.Irreps, num_classes: int): + """Init method.""" + super().__init__() + self.linear = o3.Linear( + irreps_in=irreps_in, irreps_out=o3.Irreps(f"{num_classes}x0e") + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward.""" + return self.linear(x) + + def input_to_diffusion_mace( - batch: Dict[AnyStr, torch.Tensor], radial_cutoff: float + batch: Dict[AnyStr, torch.Tensor], + radial_cutoff: float, + num_classes: int, ) -> Data: """Convert score network input to Diffusion MACE input. Args: batch: score network input dictionary radial_cutoff : largest distance between neighbors. + num_classes: number of atomic species, including the MASK class Returns: pytorch-geometric graph data compatible with MACE forward @@ -43,6 +64,7 @@ def input_to_diffusion_mace( batch_size, n_atom_per_graph, spatial_dimension = cartesian_positions.shape device = cartesian_positions.device + # TODO replace with AXL L basis_vectors = batch[UNIT_CELL] # batch, spatial_dimension, spatial_dimension adj_matrix, shift_matrix, batch_tensor, num_edges = get_adj_matrix( @@ -52,11 +74,10 @@ def input_to_diffusion_mace( ) # node features are int corresponding to atom type - # TODO handle different atom types - atom_types = torch.zeros(batch_size * n_atom_per_graph) - node_attrs = torch.nn.functional.one_hot(atom_types.long(), num_classes=1).to( - atom_types - ) + atom_types = batch[NOISY_AXL_COMPOSITION].A + node_attrs = class_index_to_onehot(atom_types, num_classes=num_classes) + node_attrs = node_attrs.view(-1, num_classes) + # atom type as 1-hot - should be (batch_size * n_atom, num_classes) # The node diffusion scalars will be the diffusion noise sigma, which is constant for each structure in the batch. # We broadcast to each node to avoid complex broadcasting logic within the model itself. # TODO: it might be better to define the noise as a 'global' graph attribute, and find 'the right way' of @@ -122,12 +143,11 @@ def __init__( interaction_cls: Type[InteractionBlock], interaction_cls_first: Type[InteractionBlock], num_interactions: int, - num_elements: int, + num_classes: int, hidden_irreps: o3.Irreps, mlp_irreps: o3.Irreps, number_of_mlp_layers: int, avg_num_neighbors: float, - atomic_numbers: List[int], correlation: Union[int, List[int]], gate: Optional[Callable], radial_MLP: List[int], @@ -137,16 +157,7 @@ def __init__( tanh_after_interaction: bool = True, ): """Init method.""" - assert ( - num_elements == 1 - ), "only a single element can be used at this time. Set 'num_elements' to 1." - assert ( - len(atomic_numbers) == 1 - ), "only a single element can be used at this time. Set 'atomic_numbers' to length 1." super().__init__() - self.register_buffer( - "atomic_numbers", torch.tensor(atomic_numbers, dtype=torch.int64) - ) self.register_buffer( "r_max", torch.tensor(r_max, dtype=torch.get_default_dtype()) ) @@ -166,7 +177,8 @@ def __init__( # define the "0e" representation as a constant to avoid "magic numbers" below. scalar_irrep = o3.Irrep(0, 1) - # Apply an MLP with a bias on the scalar diffusion time-like input. + # An MLP will be used to mix the diffusion time-like input (the 'diffusion scalar', a global quantity) and + # the 1-hot atom type (the 'node scalars') number_of_node_scalar_dimensions = 1 number_of_hidden_diffusion_scalar_dimensions = mlp_irreps.count(scalar_irrep) @@ -196,7 +208,7 @@ def __init__( self.diffusion_scalar_embedding.append(linear) # The node_attr is the one-hot version of the atom types. - node_attr_irreps = o3.Irreps([(num_elements, scalar_irrep)]) + node_attr_irreps = o3.Irreps([(num_classes, scalar_irrep)]) # Perform a tensor product to mix the diffusion scalar and node attributes self.attribute_mixing = o3.FullyConnectedTensorProduct( @@ -298,7 +310,7 @@ def __init__( node_feats_irreps=node_feats_irreps_out, target_irreps=hidden_irreps, correlation=correlation[0], - num_elements=num_elements, + num_elements=num_classes, use_sc=use_sc_first, ) self.products = torch.nn.ModuleList([prod]) @@ -333,7 +345,7 @@ def __init__( node_feats_irreps=interaction_irreps, target_irreps=hidden_irreps_out, correlation=correlation[i + 1], - num_elements=num_elements, + num_elements=num_classes, use_sc=True, ) self.products.append(prod) @@ -345,6 +357,11 @@ def __init__( # the output is a single vector. self.vector_readout = LinearVectorReadoutBlock(irreps_in=hidden_irreps_out) + # and an output for atom classification + self.classification_readout = LinearClassificationReadoutBlock( + irreps_in=hidden_irreps_out, num_classes=num_classes + ) + # Apply a MLP with a bias on the forces as a conditional feature. This would be a 1o irrep forces_irreps_in = o3.Irreps("1x1o") forces_irreps_embedding = o3.Irreps(f"{condition_embedding_size}x1o") @@ -362,12 +379,9 @@ def __init__( ) self.conditional_layers.append(cond_layer) - def forward( - self, data: Dict[str, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: + def forward(self, data: Dict[str, torch.Tensor], conditional: bool = False) -> AXL: """Forward method.""" # Setup - # Augment the node attributes with information from the diffusion scalar. diffusion_scalar_embeddings = self.diffusion_scalar_embedding( data["node_diffusion_scalars"] @@ -438,4 +452,10 @@ def forward( # Outputs vectors_output = self.vector_readout(node_feats) - return vectors_output + classification_output = self.classification_readout(node_feats) + axl_output = AXL( + A=classification_output, + X=vectors_output, + L=torch.zeros_like(classification_output), + ) + return axl_output diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn.py index 86befaa9..ffb77df2 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/egnn.py @@ -15,6 +15,7 @@ from diffusion_for_multi_scale_molecular_dynamics.models.egnn_utils import ( unsorted_segment_mean, unsorted_segment_sum) +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL class E_GCL(nn.Module): @@ -269,6 +270,7 @@ class EGNN(nn.Module): def __init__( self, input_size: int, + num_classes: int, message_n_hidden_dimensions: int, message_hidden_dimensions_size: int, node_n_hidden_dimensions: int, @@ -288,6 +290,7 @@ def __init__( Args: input_size: number of node features in the input + num_classes: number of atom types uses for the final node embedding - including the MASK class. message_n_hidden_dimensions: number of hidden layers of the message (edge) MLP message_hidden_dimensions_size: size of the hidden layers of the message (edge) MLP node_n_hidden_dimensions: number of hidden layers of the node update MLP @@ -308,6 +311,9 @@ def __init__( self.n_layers = n_layers self.embedding_in = nn.Linear(input_size, node_hidden_dimensions_size) self.graph_layers = nn.ModuleList([]) + self.node_classification_layer = nn.Linear( + node_hidden_dimensions_size, num_classes + ) for _ in range(0, n_layers): self.graph_layers.append( E_GCL( @@ -329,9 +335,7 @@ def __init__( ) ) - def forward( - self, h: torch.Tensor, edges: torch.Tensor, x: torch.Tensor - ) -> torch.Tensor: + def forward(self, h: torch.Tensor, edges: torch.Tensor, x: torch.Tensor) -> AXL: """Forward instructions for the model. Args: @@ -340,9 +344,18 @@ def forward( x: node coordinates. size is number of nodes, spatial dimension Returns: - estimated score. size is number of nodes, spatial dimension + estimated score in an AXL namedtuple. + coordinates: size is number of nodes, spatial dimension + atom types: number of nodes, number of atomic species + 1 (for MASK) + lattice: number of nodes, spatial dimension * (spatial dimension - 1) TODO """ h = self.embedding_in(h) for graph_layer in self.graph_layers: h, x = graph_layer(h, edges, x) - return x + node_classification_logits = self.node_classification_layer(h) + model_outputs = AXL( + A=node_classification_logits, + X=x, + L=torch.zeros_like(x), + ) + return model_outputs diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/instantiate_diffusion_model.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/instantiate_diffusion_model.py index d6ff405f..695468ce 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/instantiate_diffusion_model.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/instantiate_diffusion_model.py @@ -3,27 +3,27 @@ import logging from typing import Any, AnyStr, Dict -from diffusion_for_multi_scale_molecular_dynamics.models.loss import \ +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ create_loss_parameters +from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import ( + AXLDiffusionLightningModel, AXLDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import \ create_optimizer_parameters -from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import ( - PositionDiffusionLightningModel, PositionDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.scheduler import \ create_scheduler_parameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network_factory import \ create_score_network_parameters -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.samples.diffusion_sampling_parameters import \ +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle_factory import \ + create_energy_oracle_parameters +from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling_parameters import \ load_diffusion_sampling_parameters logger = logging.getLogger(__name__) -def load_diffusion_model( - hyper_params: Dict[AnyStr, Any] -) -> PositionDiffusionLightningModel: +def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> AXLDiffusionLightningModel: """Load a position diffusion model from the hyperparameters. Args: @@ -32,9 +32,11 @@ def load_diffusion_model( Returns: Diffusion model randomly initialized """ + elements = hyper_params["elements"] globals_dict = dict( max_atom=hyper_params["data"]["max_atom"], spatial_dimension=hyper_params.get("spatial_dimension", 3), + elements=elements ) score_network_dict = hyper_params["model"]["score_network"] @@ -55,16 +57,21 @@ def load_diffusion_model( diffusion_sampling_parameters = load_diffusion_sampling_parameters(hyper_params) - diffusion_params = PositionDiffusionParameters( + oracle_parameters = None + if "oracle" in hyper_params: + oracle_parameters = create_energy_oracle_parameters(hyper_params["oracle"], elements) + + diffusion_params = AXLDiffusionParameters( score_network_parameters=score_network_parameters, loss_parameters=loss_parameters, optimizer_parameters=optimizer_parameters, scheduler_parameters=scheduler_parameters, noise_parameters=noise_parameters, diffusion_sampling_parameters=diffusion_sampling_parameters, + oracle_parameters=oracle_parameters ) - model = PositionDiffusionLightningModel(diffusion_params) + model = AXLDiffusionLightningModel(diffusion_params) logger.info("model info:\n" + str(model) + "\n") return model diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/mace_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/mace_utils.py index 1a7e1595..1b0f59d9 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/mace_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/mace_utils.py @@ -36,7 +36,7 @@ def input_to_mace(x: Dict[AnyStr, torch.Tensor], radial_cutoff: float) -> Data: # TODO handle different atom types node_attrs = torch.nn.functional.one_hot( (torch.ones(batch_size * n_atom_per_graph) * 14).long(), num_classes=89 - ).float() + ).to(noisy_cartesian_positions) flat_positions = noisy_cartesian_positions.view( -1, spatial_dimension ) # [batchsize * natoms, spatial dimension] diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/normalized_score_fokker_planck_error.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/normalized_score_fokker_planck_error.py deleted file mode 100644 index a3f2949b..00000000 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/normalized_score_fokker_planck_error.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import Callable - -import einops -import torch -from torch.func import jacrev - -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ - ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.exploding_variance import \ - ExplodingVariance -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters - - -class NormalizedScoreFokkerPlanckError(torch.nn.Module): - """Class to calculate the Normalized Score Fokker Planck Error. - - This concept is defined in the paper: - "FP-Diffusion: Improving Score-based Diffusion Models by Enforcing the Underlying Score Fokker-Planck Equation" - - The Fokker-Planck equation, which is applicable to the time-dependent probability distribution, is generalized - to an ODE that the score should satisfy. The departure from satisfying this equation thus defines the FP error. - - The score Fokker-Planck equation is defined as: - - d S(x, t) / dt = 1/2 g(t)^2 nabla [ nabla.S(x,t) + |S(x,t)|^2] - - where S(x, t) is the score. Define the Normalized Score as N(x, t) == sigma(t) S(x, t), the equation above - becomes - - d N(x, t) / dt = sigma_dot(t) / sigma(t) N(x, t) + sigma_dot(t) nabla [ sigma(t) nabla. N(x,t) + |N(x,t)|^2] - - where is it assumed that g(t)^2 == 2 sigma(t) sigma_dot(t). - - The great advantage of this approach is that it only requires knowledge of the normalized score - (and its derivative), which is the quantity we seek to learn. - """ - - def __init__( - self, - sigma_normalized_score_network: ScoreNetwork, - noise_parameters: NoiseParameters, - ): - """Init method.""" - super().__init__() - - self.exploding_variance = ExplodingVariance(noise_parameters) - self.sigma_normalized_score_network = sigma_normalized_score_network - - def _normalized_scores_function( - self, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Normalized Scores Function. - - This method computes the normalized score, as defined by the sigma_normalized_score_network. - - Args: - relative_coordinates : relative coordinates. Dimensions : [batch_size, number_of_atoms, spatial_dimension]. - times : diffusion times. Dimensions : [batch_size, 1]. - unit_cells : unit cells. Dimensions : [batch_size, spatial_dimension, spatial_dimension]. - - Returns: - normalized scores: the scores for given input. - Dimensions : [batch_size, number_of_atoms, spatial_dimension]. - """ - forces = torch.zeros_like(relative_coordinates) - sigmas = self.exploding_variance.get_sigma(times) - - augmented_batch = { - NOISY_RELATIVE_COORDINATES: relative_coordinates, - TIME: times, - NOISE: sigmas, - UNIT_CELL: unit_cells, - CARTESIAN_FORCES: forces, - } - - sigma_normalized_scores = self.sigma_normalized_score_network( - augmented_batch, conditional=False - ) - - return sigma_normalized_scores - - def _normalized_scores_square_norm_function( - self, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Normalized Scores Square Norm Function. - - This method computes the square norm of the normalized score, as defined - by the sigma_normalized_score_network. - - Args: - relative_coordinates : relative coordinates. Dimensions : [batch_size, number_of_atoms, spatial_dimension]. - times : diffusion times. Dimensions : [batch_size, 1]. - unit_cells : unit cells. Dimensions : [batch_size, spatial_dimension, spatial_dimension]. - - Returns: - normalized_scores_square_norm: |normalized scores|^2. Dimension: [batch_size]. - """ - normalized_scores = self._normalized_scores_function( - relative_coordinates, times, unit_cells - ) - - flat_scores = einops.rearrange( - normalized_scores, - "batch natoms spatial_dimension -> batch (natoms spatial_dimension)", - ) - square_norms = (flat_scores**2).sum(dim=1) - return square_norms - - def _get_dn_dt( - self, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Compute the time derivative of the normalized score.""" - # "_normalized_scores_function" is a Callable, with time as its second argument (index = 1) - time_jacobian_function = jacrev(self._normalized_scores_function, argnums=1) - - # Computing the Jacobian returns an array of dimension [batch_size, natoms, space, batch_size, 1] - time_jacobian = time_jacobian_function(relative_coordinates, times, unit_cells) - - # Only the "diagonal" along the batch dimensions is meaningful. - # Also, squeeze out the needless last 'time' dimension. - batch_diagonal = torch.diagonal(time_jacobian.squeeze(-1), dim1=0, dim2=3) - - # torch.diagonal puts the diagonal dimension (here, the batch index) at the end. Bring it back to the front. - dn_dt = einops.rearrange( - batch_diagonal, "natoms space batch -> batch natoms space" - ) - - return dn_dt - - def _get_gradient( - self, - scalar_function: Callable, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Compute the gradient of the provided scalar function.""" - # We cannot use the "grad" function because our "scalar" function actually returns one value per batch entry. - grad_function = jacrev(scalar_function, argnums=0) - - # Gradients have dimension [batch_size, batch_size, natoms, spatial_dimension] - overbatched_gradients = grad_function(relative_coordinates, times, unit_cells) - - batch_diagonal = torch.diagonal(overbatched_gradients, dim1=0, dim2=1) - - # torch.diagonal puts the diagonal dimension (here, the batch index) at the end. Bring it back to the front. - gradients = einops.rearrange( - batch_diagonal, "natoms space batch -> batch natoms space" - ) - return gradients - - def _divergence_function( - self, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Compute the divergence of the normalized score.""" - # "_normalized_scores_function" is a Callable, with space as its zeroth argument - space_jacobian_function = jacrev(self._normalized_scores_function, argnums=0) - - # Computing the Jacobian returns an array of dimension [batch_size, natoms, space, batch_size, natoms, space] - space_jacobian = space_jacobian_function( - relative_coordinates, times, unit_cells - ) - - # Take only the diagonal batch term. "torch.diagonal" puts the batch index at the end... - batch_diagonal = torch.diagonal(space_jacobian, dim1=0, dim2=3) - - flat_jacobian = einops.rearrange( - batch_diagonal, - "natoms1 space1 natoms2 space2 batch " - "-> batch (natoms1 space1) (natoms2 space2)", - ) - - # take the trace of the Jacobian to get the divergence. - divergence = torch.vmap(torch.trace)(flat_jacobian) - return divergence - - def get_normalized_score_fokker_planck_error( - self, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Get Normalized Score Fokker-Planck Error. - - Args: - relative_coordinates : relative coordinates. Dimensions : [batch_size, number_of_atoms, spatial_dimension]. - times : diffusion times. Dimensions : [batch_size, 1]. - unit_cells : unit cells. Dimensions : [batch_size, spatial_dimension, spatial_dimension]. - - Returns: - FP_error: how much the normalized score Fokker-Planck equation is violated. - Dimensions : [batch_size, spatial_dimension, spatial_dimension]. - """ - batch_size, natoms, spatial_dimension = relative_coordinates.shape - - sigmas = einops.repeat( - self.exploding_variance.get_sigma(times), - "batch 1 -> batch natoms space", - natoms=natoms, - space=spatial_dimension, - ) - - dot_sigmas = einops.repeat( - self.exploding_variance.get_sigma_time_derivative(times), - "batch 1 -> batch natoms space", - natoms=natoms, - space=spatial_dimension, - ) - - n = self._normalized_scores_function(relative_coordinates, times, unit_cells) - - dn_dt = self._get_dn_dt(relative_coordinates, times, unit_cells) - - grad_n2 = self._get_gradient( - self._normalized_scores_square_norm_function, - relative_coordinates, - times, - unit_cells, - ) - - grad_div_n = self._get_gradient( - self._divergence_function, relative_coordinates, times, unit_cells - ) - - fp_errors = ( - dn_dt - - dot_sigmas / sigmas * n - - sigmas * dot_sigmas * grad_div_n - - dot_sigmas * grad_n2 - ) - - return fp_errors - - def get_normalized_score_fokker_planck_error_by_iterating_over_batch( - self, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - ) -> torch.Tensor: - """Get the error by iterating over the elements of the batch.""" - list_errors = [] - for x, t, c in zip(relative_coordinates, times, unit_cells): - # Iterate over the elements of the batch. In effect, compute over "batch_size = 1" tensors. - errors = self.get_normalized_score_fokker_planck_error( - x.unsqueeze(0), t.unsqueeze(0), c.unsqueeze(0) - ).squeeze(0) - list_errors.append(errors) - - return torch.stack(list_errors) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py index fa7540c6..67a65814 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/analytical_score_network.py @@ -21,7 +21,7 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import ( ScoreNetwork, ScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISE, NOISY_RELATIVE_COORDINATES) + AXL, NOISE, NOISY_AXL_COMPOSITION) from diffusion_for_multi_scale_molecular_dynamics.score.wrapped_gaussian_score import \ get_sigma_normalized_score from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ @@ -63,6 +63,10 @@ def __init__(self, hyper_params: AnalyticalScoreNetworkParameters): """ super(AnalyticalScoreNetwork, self).__init__(hyper_params) + assert hyper_params.num_atom_types == 1, \ + "The analytical score network is only appropriate for a single atom type." + + self.number_of_atomic_classes = hyper_params.num_atom_types + 1 # account for the MASK class. self.natoms = hyper_params.number_of_atoms self.spatial_dimension = hyper_params.spatial_dimension self.nd = self.natoms * self.spatial_dimension @@ -123,7 +127,7 @@ def _get_all_equilibrium_permutations( def _forward_unchecked( self, batch: Dict[AnyStr, Any], conditional: bool = False - ) -> torch.Tensor: + ) -> AXL: """Forward unchecked. This method assumes that the input data has already been checked with respect to expectations @@ -134,10 +138,13 @@ def _forward_unchecked( conditional (optional): CURRENTLY DOES NOTHING. Returns: - output : the scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor. + output : an AXL namedtuple with the coordinates scores computed by the model as a + [batch_size, n_atom, spatial_dimension] tensor. Empty tensors are returned for the atom types and + lattice. """ sigmas = batch[NOISE] # dimension: [batch_size, 1] - xt = batch[NOISY_RELATIVE_COORDINATES] + xt = batch[NOISY_AXL_COMPOSITION].X + batch_size = xt.shape[0] xt.requires_grad_(True) list_unnormalized_log_prob = [] @@ -162,7 +169,17 @@ def _forward_unchecked( ) sigma_normalized_scores = broadcast_sigmas * scores - return sigma_normalized_scores + # Mimic perfect predictions of single possible atomic type. + atomic_logits = torch.zeros(batch_size, self.natoms, self.number_of_atomic_classes) + atomic_logits[..., -1] = -torch.inf + + axl_scores = AXL( + A=atomic_logits, + X=sigma_normalized_scores, + L=torch.zeros_like(sigma_normalized_scores), + ) + + return axl_scores def _compute_unnormalized_log_probability( self, sigmas: torch.Tensor, xt: torch.Tensor, x_eq: torch.Tensor @@ -246,7 +263,8 @@ def _forward_unchecked( output : the scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor. """ sigmas = batch[NOISE] # dimension: [batch_size, 1] - xt = batch[NOISY_RELATIVE_COORDINATES] + xt = batch[NOISY_AXL_COMPOSITION].X + batch_size = xt.shape[0] broadcast_sigmas = einops.repeat( sigmas, @@ -266,4 +284,14 @@ def _forward_unchecked( broadcast_sigmas / broadcast_effective_sigmas * misnormalized_scores ) - return sigma_normalized_scores + # Mimic perfect predictions of single possible atomic type. + atomic_logits = torch.zeros(batch_size, self.natoms, self.number_of_atomic_classes) + atomic_logits[..., -1] = -torch.inf + + axl_scores = AXL( + A=atomic_logits, + X=sigma_normalized_scores, + L=torch.zeros_like(sigma_normalized_scores), + ) + + return axl_scores diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/diffusion_mace_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/diffusion_mace_score_network.py index 3120e58a..6012e614 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/diffusion_mace_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/diffusion_mace_score_network.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import AnyStr, Dict, List +import einops import torch from e3nn import o3 from mace.modules import gate_dict, interaction_classes @@ -11,9 +12,9 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import ( ScoreNetwork, ScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISY_CARTESIAN_POSITIONS, NOISY_RELATIVE_COORDINATES, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( - get_positions_from_coordinates, get_reciprocal_basis_vectors) + AXL, NOISY_AXL_COMPOSITION, NOISY_CARTESIAN_POSITIONS, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ + get_positions_from_coordinates @dataclass(kw_only=True) @@ -22,7 +23,6 @@ class DiffusionMACEScoreNetworkParameters(ScoreNetworkParameters): architecture: str = "diffusion_mace" number_of_atoms: int # the number of atoms in a configuration. - number_of_elements: int = 1 # The number of distinct elements present r_max: float = 5.0 num_bessel: int = 8 num_polynomial_cutoff: int = 5 @@ -76,6 +76,8 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters): self.r_max = hyper_params.r_max self.collate_fn = Collater(follow_batch=[None], exclude_keys=[None]) + # we removed atomic_numbers from the mace_config which breaks the compatibility with pre-trained MACE + # this is necessary for the diffusion with masked atoms diffusion_mace_config = dict( r_max=hyper_params.r_max, num_bessel=hyper_params.num_bessel, @@ -88,12 +90,12 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters): hyper_params.interaction_cls_first ], num_interactions=hyper_params.num_interactions, - num_elements=hyper_params.number_of_elements, + num_classes=hyper_params.num_atom_types + + 1, # we need the model to work with the MASK token as well hidden_irreps=o3.Irreps(hyper_params.hidden_irreps), mlp_irreps=o3.Irreps(hyper_params.mlp_irreps), number_of_mlp_layers=hyper_params.number_of_mlp_layers, avg_num_neighbors=hyper_params.avg_num_neighbors, - atomic_numbers=[14], # TODO: revisit this when we have multi-atom types correlation=hyper_params.correlation, gate=gate_dict[hyper_params.gate], radial_MLP=hyper_params.radial_MLP, @@ -104,20 +106,19 @@ def __init__(self, hyper_params: DiffusionMACEScoreNetworkParameters): ) self._natoms = hyper_params.number_of_atoms - self._number_of_elements = hyper_params.number_of_elements self.diffusion_mace_network = DiffusionMACE(**diffusion_mace_config) def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): super(DiffusionMACEScoreNetwork, self)._check_batch(batch) - number_of_atoms = batch[NOISY_RELATIVE_COORDINATES].shape[1] + number_of_atoms = batch[NOISY_AXL_COMPOSITION].X.shape[1] assert ( number_of_atoms == self._natoms ), "The dimension corresponding to the number of atoms is not consistent with the configuration." def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: + ) -> AXL: """Forward unchecked. This method assumes that the input data has already been checked with respect to expectations @@ -129,25 +130,43 @@ def _forward_unchecked( Defaults to False. Returns: - output : the scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor. + output : the scores computed by the model as a AXL + coordinates: [batch_size, n_atom, spatial_dimension] tensor. + atom types: [batch_size, n_atom, num_atom_types + 1] tensor. + lattice: [batch_size, n_atom, spatial_dimension * (spatial_dimension -1)] tensor. """ - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape - basis_vectors = batch[UNIT_CELL] + basis_vectors = batch[UNIT_CELL] # TODO replace with AXL L batch[NOISY_CARTESIAN_POSITIONS] = get_positions_from_coordinates( relative_coordinates, basis_vectors ) - graph_input = input_to_diffusion_mace(batch, radial_cutoff=self.r_max) + graph_input = input_to_diffusion_mace( + batch, radial_cutoff=self.r_max, num_classes=self.num_atom_types + 1 + ) - flat_cartesian_scores = self.diffusion_mace_network(graph_input, conditional) + mace_axl_scores = self.diffusion_mace_network(graph_input, conditional) + flat_cartesian_scores = mace_axl_scores.X cartesian_scores = flat_cartesian_scores.reshape( batch_size, number_of_atoms, spatial_dimension ) - reciprocal_basis_vectors_as_columns = get_reciprocal_basis_vectors( - basis_vectors + # basis_vectors is composed of ROWS of basis vectors + coordinates_scores = einops.einsum( + basis_vectors, + cartesian_scores, + "batch i alpha, batch natoms alpha -> batch natoms i", + ) + + atom_types_scores = mace_axl_scores.A.reshape( + batch_size, number_of_atoms, self.num_atom_types + 1 + ) + + axl_scores = AXL( + A=atom_types_scores, + X=coordinates_scores, + L=torch.zeros_like(atom_types_scores), ) - scores = torch.bmm(cartesian_scores, reciprocal_basis_vectors_as_columns) - return scores + return axl_scores diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py index dd49c531..1068ba5c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/egnn_score_network.py @@ -12,7 +12,9 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import \ ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISE, NOISY_RELATIVE_COORDINATES, UNIT_CELL) + AXL, NOISE, NOISY_AXL_COMPOSITION, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot @dataclass(kw_only=True) @@ -53,8 +55,11 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters): """ super(EGNNScoreNetwork, self).__init__(hyper_params) - self.number_of_features_per_node = 1 self.spatial_dimension = hyper_params.spatial_dimension + self.num_atom_types = hyper_params.num_atom_types + self.number_of_features_per_node = ( + self.num_atom_types + 2 + ) # +1 for MASK class, + 1 for sigma projection_matrices = self._create_block_diagonal_projection_matrices( self.spatial_dimension @@ -98,6 +103,7 @@ def __init__(self, hyper_params: EGNNScoreNetworkParameters): coords_agg=hyper_params.coords_agg, message_agg=hyper_params.message_agg, n_layers=hyper_params.n_layers, + num_classes=self.num_atom_types + 1, ) @staticmethod @@ -137,24 +143,36 @@ def _create_block_diagonal_projection_matrices( return torch.stack(projection_matrices) @staticmethod - def _get_node_attributes(batch: Dict[AnyStr, torch.Tensor]) -> torch.Tensor: + def _get_node_attributes( + batch: Dict[AnyStr, torch.Tensor], num_atom_types: int + ) -> torch.Tensor: """Get node attributes. - This method extracts the node atttributes, "h", to be fed as input to the EGNN network. + This method extracts the node attributes, "h", to be fed as input to the EGNN network. Args: batch : the batch dictionary + num_atom_types: number of atom types excluding the MASK token Returns: - node_attributes: a tensor of dimension [number_of_nodes, number_for_features_per_node] + node_attributes: a tensor of dimension [batch, natoms, num_atom_types + 2] """ - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape sigmas = batch[NOISE].to(relative_coordinates.device) repeated_sigmas = einops.repeat( sigmas, "batch 1 -> (batch natoms) 1", natoms=number_of_atoms ) - return repeated_sigmas + + atom_types = batch[NOISY_AXL_COMPOSITION].A + atom_types_one_hot = class_index_to_onehot( + atom_types, num_classes=num_atom_types + 1 + ) + + node_attributes = torch.concatenate( + (repeated_sigmas, atom_types_one_hot.view(-1, num_atom_types + 1)), dim=1 + ) + return node_attributes @staticmethod def _get_euclidean_positions( @@ -184,8 +202,8 @@ def _get_euclidean_positions( def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + ) -> AXL: + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X batch_size, number_of_atoms, spatial_dimension = relative_coordinates.shape if self.edges == "fully_connected": @@ -209,7 +227,9 @@ def _forward_unchecked( # Dimensions [number_of_nodes, 2 x spatial_dimension] euclidean_positions = self._get_euclidean_positions(flat_relative_coordinates) - node_attributes_h = self._get_node_attributes(batch) + node_attributes_h = self._get_node_attributes( + batch, num_atom_types=self.num_atom_types + ) # The raw normalized score has dimensions [number_of_nodes, 2 x spatial_dimension] # CAREFUL! It is important to pass a clone of the euclidian positions because EGNN will modify its input! raw_normalized_score = self.egnn( @@ -226,7 +246,7 @@ def _forward_unchecked( flat_normalized_scores = einops.einsum( euclidean_positions, self.projection_matrices, - raw_normalized_score, + raw_normalized_score.X, "nodes i, alpha i j, nodes j-> nodes alpha", ) @@ -236,4 +256,18 @@ def _forward_unchecked( batch=batch_size, natoms=number_of_atoms, ) - return normalized_scores + + atom_reshaped_scores = einops.rearrange( + raw_normalized_score.A, + "(batch natoms) num_classes -> batch natoms num_classes", + batch=batch_size, + natoms=number_of_atoms, + ) + + axl_scores = AXL( + A=atom_reshaped_scores, + X=normalized_scores, + L=raw_normalized_score.L, + ) + + return axl_scores diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/force_field_augmented_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/force_field_augmented_score_network.py index 4008e4d9..382b10a0 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/force_field_augmented_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/force_field_augmented_score_network.py @@ -7,7 +7,7 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ ScoreNetwork from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISY_RELATIVE_COORDINATES, UNIT_CELL) + AXL, NOISY_AXL_COMPOSITION, UNIT_CELL) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( get_positions_from_coordinates, get_reciprocal_basis_vectors, get_relative_coordinates_from_cartesian_positions) @@ -57,7 +57,7 @@ def __init__( def forward( self, batch: Dict[AnyStr, torch.Tensor], conditional: Optional[bool] = None - ) -> torch.Tensor: + ) -> AXL: """Model forward. Args: @@ -70,7 +70,8 @@ def forward( """ raw_scores = self._score_network(batch, conditional) forces = self.get_relative_coordinates_pseudo_force(batch) - return raw_scores + forces + updated_scores = AXL(A=raw_scores.A, X=raw_scores.X + forces, L=raw_scores.L) + return updated_scores def _get_cartesian_pseudo_forces_contributions( self, cartesian_displacements: torch.Tensor @@ -109,7 +110,7 @@ def _get_adjacency_information( self, batch: Dict[AnyStr, torch.Tensor] ) -> AdjacencyInfo: basis_vectors = batch[UNIT_CELL] - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X cartesian_positions = get_positions_from_coordinates( relative_coordinates, basis_vectors ) @@ -132,8 +133,8 @@ def _get_cartesian_displacements( bch = adj_info.edge_batch_indices src, dst = adj_info.adjacency_matrix - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] - basis_vectors = batch[UNIT_CELL] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X + basis_vectors = batch[UNIT_CELL] # TODO replace with AXL L cartesian_positions = get_positions_from_coordinates( relative_coordinates, basis_vectors ) @@ -159,7 +160,7 @@ def _get_cartesian_pseudo_forces( bch = adj_info.edge_batch_indices src, dst = adj_info.adjacency_matrix - batch_size, natoms, spatial_dimension = batch[NOISY_RELATIVE_COORDINATES].shape + batch_size, natoms, spatial_dimension = batch[NOISY_AXL_COMPOSITION].X.shape # Combine the bch and src index into a single global index node_idx = natoms * bch + src @@ -207,7 +208,7 @@ def get_relative_coordinates_pseudo_force( cartesian_pseudo_force_contributions, adj_info, batch ) - basis_vectors = batch[UNIT_CELL] + basis_vectors = batch[UNIT_CELL] # TODO replace with AXL L reciprocal_basis_vectors = get_reciprocal_basis_vectors(basis_vectors) relative_pseudo_forces = get_relative_coordinates_from_cartesian_positions( cartesian_pseudo_forces, reciprocal_basis_vectors diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mace_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mace_score_network.py index 998e428e..73c38298 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mace_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mace_score_network.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, field from typing import AnyStr, Dict, List, Optional +import einops import numpy as np import torch from e3nn import o3 @@ -14,9 +15,10 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import ( ScoreNetwork, ScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_prediction_head import ( - MaceScorePredictionHeadParameters, instantiate_mace_prediction_head) + MaceMLPScorePredictionHeadParameters, MaceScorePredictionHeadParameters, + instantiate_mace_prediction_head) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISY_CARTESIAN_POSITIONS, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) + AXL, NOISY_AXL_COMPOSITION, NOISY_CARTESIAN_POSITIONS, TIME, UNIT_CELL) @dataclass(kw_only=True) @@ -50,6 +52,8 @@ class MACEScoreNetworkParameters(ScoreNetworkParameters): radial_type: str = ( "bessel" # type of radial basis functions - choices=["bessel", "gaussian", "chebyshev"] ) + atom_type_head_hidden_size: int = 64 + atom_type_head_n_hidden_layers: int = 2 prediction_head_parameters: MaceScorePredictionHeadParameters @@ -119,20 +123,31 @@ def __init__(self, hyper_params: MACEScoreNetworkParameters): ), "Something is wrong with pretrained dimensions." self.mace_output_size = output_node_features_irreps.dim - self.prediction_head = instantiate_mace_prediction_head( + self.coordinates_prediction_head = instantiate_mace_prediction_head( output_node_features_irreps, hyper_params.prediction_head_parameters ) + atom_type_prediction_head_parameters = MaceMLPScorePredictionHeadParameters( + name="mlp", + hidden_dimensions_size=hyper_params.atom_type_head_hidden_size, + n_hidden_dimensions=hyper_params.atom_type_head_n_hidden_layers, + spatial_dimension=self.num_atom_types + + 1, # spatial_dimension acts as the output size + # TODO will not work because MASK is not a valid atom type + ) + self.atom_types_prediction_head = instantiate_mace_prediction_head( + output_node_features_irreps, atom_type_prediction_head_parameters + ) def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): super(MACEScoreNetwork, self)._check_batch(batch) - number_of_atoms = batch[NOISY_RELATIVE_COORDINATES].shape[1] + number_of_atoms = batch[NOISY_AXL_COMPOSITION].X.shape[1] assert ( number_of_atoms == self._natoms ), "The dimension corresponding to the number of atoms is not consistent with the configuration." def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: + ) -> AXL: """Forward unchecked. This method assumes that the input data has already been checked with respect to expectations @@ -147,7 +162,7 @@ def _forward_unchecked( output : the scores computed by the model as a [batch_size, n_atom, spatial_dimension] tensor. """ del conditional # TODO implement conditional - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X batch[NOISY_CARTESIAN_POSITIONS] = torch.bmm( relative_coordinates, batch[UNIT_CELL] ) # positions in Angstrom @@ -164,11 +179,38 @@ def _forward_unchecked( # with this value the same for all atoms belonging to the same graph. times = batch[TIME].to(relative_coordinates.device) # shape [batch_size, 1] flat_times = times[graph_input.batch] # shape [batch_size * natoms, 1] - flat_scores = self.prediction_head( + + # The output of the prediction head is a 'cartesian score'; ie it is similar to nabla_r ln P. + flat_cartesian_scores = self.coordinates_prediction_head( flat_node_features, flat_times ) # shape [batch_size * natoms, spatial_dim] - # Reshape the scores to have an explicit batch dimension - scores = flat_scores.reshape(-1, self._natoms, self.spatial_dimension) + # Reshape the cartesian scores to have an explicit batch dimension + cartesian_scores = flat_cartesian_scores.reshape( + -1, self._natoms, self.spatial_dimension + ) + + # The expected output of the score network is a COORDINATE SCORE, i.e. something like nabla_x ln P. + # Note that the basis_vectors is composed of ROWS of basis vectors + basis_vectors = batch[UNIT_CELL] + coordinates_scores = einops.einsum( + basis_vectors, + cartesian_scores, + "batch i alpha, batch natoms alpha -> batch natoms i", + ) + + flat_atom_type_scores = self.atom_types_prediction_head( + flat_node_features, flat_times + ) # shape [batch_size * natoms, num_atom_types] + + atom_type_scores = flat_atom_type_scores.reshape( + -1, self._natoms, self.num_atom_types + 1 + ) + + scores = AXL( + A=atom_type_scores, + X=coordinates_scores, + L=torch.zeros_like(atom_type_scores), # TODO replace with real output + ) return scores diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mlp_score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mlp_score_network.py index 5606b04f..0b622dd7 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mlp_score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/mlp_score_network.py @@ -7,7 +7,9 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import ( ScoreNetwork, ScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES) + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME) +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot @dataclass(kw_only=True) @@ -18,9 +20,15 @@ class MLPScoreNetworkParameters(ScoreNetworkParameters): number_of_atoms: int # the number of atoms in a configuration. n_hidden_dimensions: int # the number of hidden layers. hidden_dimensions_size: int # the dimensions of the hidden layers. - embedding_dimensions_size: ( + noise_embedding_dimensions_size: ( int # the dimension of the embedding of the noise parameter. ) + time_embedding_dimensions_size: ( + int # the dimension of the embedding of the time parameter. + ) + atom_type_embedding_dimensions_size: ( + int # the dimension of the embedding of the atom types + ) condition_embedding_size: int = ( 64 # dimension of the conditional variable embedding ) @@ -39,27 +47,44 @@ def __init__(self, hyper_params: MLPScoreNetworkParameters): hyper_params : hyper parameters from the config file. """ super(MLPScoreNetwork, self).__init__(hyper_params) - hidden_dimensions = [ - hyper_params.hidden_dimensions_size - ] * hyper_params.n_hidden_dimensions + hidden_dimensions = [hyper_params.hidden_dimensions_size] * ( + hyper_params.n_hidden_dimensions + ) self._natoms = hyper_params.number_of_atoms + self.num_atom_types = hyper_params.num_atom_types + self.num_classes = self.num_atom_types + 1 # add 1 for the MASK class + + coordinate_output_dimension = self.spatial_dimension * self._natoms + atom_type_output_dimension = self._natoms * self.num_classes - output_dimension = self.spatial_dimension * self._natoms - input_dimension = output_dimension + hyper_params.embedding_dimensions_size + input_dimension = ( + coordinate_output_dimension + + hyper_params.noise_embedding_dimensions_size + + hyper_params.time_embedding_dimensions_size + + self._natoms * hyper_params.atom_type_embedding_dimensions_size + ) self.noise_embedding_layer = nn.Linear( - 1, hyper_params.embedding_dimensions_size + 1, hyper_params.noise_embedding_dimensions_size + ) + + self.time_embedding_layer = nn.Linear( + 1, hyper_params.time_embedding_dimensions_size + ) + + self.atom_type_embedding_layer = nn.Linear( + self.num_classes, hyper_params.atom_type_embedding_dimensions_size ) self.condition_embedding_layer = nn.Linear( - output_dimension, hyper_params.condition_embedding_size + coordinate_output_dimension, hyper_params.condition_embedding_size ) self.flatten = nn.Flatten() self.mlp_layers = nn.ModuleList() self.conditional_layers = nn.ModuleList() - input_dimensions = [input_dimension] + hidden_dimensions - output_dimensions = hidden_dimensions + [output_dimension] + input_dimensions = [input_dimension] + hidden_dimensions[:-1] + output_dimensions = hidden_dimensions for input_dimension, output_dimension in zip( input_dimensions, output_dimensions @@ -70,16 +95,24 @@ def __init__(self, hyper_params: MLPScoreNetworkParameters): ) self.non_linearity = nn.ReLU() + # Create a self nn object to be discoverable to be placed on the correct device + self.output_A_layer = nn.Linear(hyper_params.hidden_dimensions_size, atom_type_output_dimension) + self.output_X_layer = nn.Linear(hyper_params.hidden_dimensions_size, coordinate_output_dimension) + self.output_L_layer = nn.Identity() + self.output_layers = AXL(A=self.output_A_layer, + X=self.output_X_layer, + L=self.output_L_layer) # TODO placeholder + def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): super(MLPScoreNetwork, self)._check_batch(batch) - number_of_atoms = batch[NOISY_RELATIVE_COORDINATES].shape[1] + number_of_atoms = batch[NOISY_AXL_COMPOSITION].X.shape[1] assert ( number_of_atoms == self._natoms ), "The dimension corresponding to the number of atoms is not consistent with the configuration." def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: + ) -> AXL: """Forward unchecked. This method assumes that the input data has already been checked with respect to expectations @@ -91,17 +124,38 @@ def _forward_unchecked( Defaults to False. Returns: - computed_scores : the scores computed by the model. + computed_scores : the scores computed by the model in an AXL namedtuple. """ - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X # shape [batch_size, number_of_atoms, spatial_dimension] sigmas = batch[NOISE].to(relative_coordinates.device) # shape [batch_size, 1] noise_embedding = self.noise_embedding_layer( sigmas - ) # shape [batch_size, embedding_dimension] + ) # shape [batch_size, noise_embedding_dimension] - input = torch.cat([self.flatten(relative_coordinates), noise_embedding], dim=1) + times = batch[TIME].to(relative_coordinates.device) # shape [batch_size, 1] + time_embedding = self.time_embedding_layer( + times + ) # shape [batch_size, time_embedding_dimension] + + atom_types = batch[NOISY_AXL_COMPOSITION].A + atom_types_one_hot = class_index_to_onehot( + atom_types, num_classes=self.num_classes + ) + atom_type_embedding = self.atom_type_embedding_layer( + atom_types_one_hot + ) # shape [batch_size, atom_type_embedding_dimension] + + input = torch.cat( + [ + self.flatten(relative_coordinates), + noise_embedding, + time_embedding, + self.flatten(atom_type_embedding), + ], + dim=1, + ) forces_input = self.condition_embedding_layer( self.flatten(batch[CARTESIAN_FORCES]) @@ -117,5 +171,13 @@ def _forward_unchecked( if conditional: output += condition_layer(forces_input) - output = output.reshape(relative_coordinates.shape) - return output + coordinates_output = self.output_layers.X(output).reshape( + relative_coordinates.shape + ) + atom_types_output = self.output_layers.A(output).reshape( + atom_types_one_hot.shape + ) + lattice_output = torch.zeros_like(atom_types_output) # TODO placeholder + + axl_output = AXL(A=atom_types_output, X=coordinates_output, L=lattice_output) + return axl_output diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py index ad9b6722..20b067c0 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network.py @@ -1,24 +1,24 @@ -"""Score Network. +r"""Score Network. This module implements score networks for positions in relative coordinates. Relative coordinates are with respect to lattice vectors which define the periodic unit cell. + +The coordinates part of the output aims to calculate + +.. math:: + output.X \propto nabla_X \ln P(x,t) + +where X is relative coordinates. """ -import os from dataclasses import dataclass from typing import AnyStr, Dict, Optional import torch from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) - -# mac fun time -# for mace, conflict with mac -# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already- \ -# initial -os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) @dataclass(kw_only=True) @@ -27,9 +27,8 @@ class ScoreNetworkParameters: architecture: str spatial_dimension: int = 3 # the dimension of Euclidean space where atoms live. - conditional_prob: float = ( - 0.0 # probability of making a conditional forward - else, do a unconditional forward - ) + num_atom_types: int # number of possible atomic species - not counting the MASK class used in the diffusion + conditional_prob: float = 0.0 # probability of making a conditional forward - else, do an unconditional forward conditional_gamma: float = ( 2.0 # conditional score weighting - see eq. B45 in MatterGen ) @@ -52,6 +51,7 @@ def __init__(self, hyper_params: ScoreNetworkParameters): super(ScoreNetwork, self).__init__() self._hyper_params = hyper_params self.spatial_dimension = hyper_params.spatial_dimension + self.num_atom_types = hyper_params.num_atom_types self.conditional_prob = hyper_params.conditional_prob self.conditional_gamma = hyper_params.conditional_gamma @@ -62,11 +62,16 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): those inputs have the expected dimensions. It is expected that: - - the relative coordinates are present and of shape [batch_size, number of atoms, spatial_dimension] + - an AXL namedtuple is present with + - the relative coordinates of shape [batch_size, number of atoms, spatial_dimension] + - the atom types of shape [batch_size, number of atoms] + - the unit cell vectors TODO shape - all the components of relative coordinates will be in [0, 1) + - all the components of atom types are integers between [0, number of atomic species + 1) + the + 1 accounts for the MASK class - the time steps are present and of shape [batch_size, 1] - the time steps are in range [0, 1]. - - the 'noise' parameter is present and has the same shape as time. + - the 'noise' parameter sigma is present and has the same shape as time. An assert will fail if the batch does not conform with expectation. @@ -76,12 +81,12 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): Returns: None. """ - assert NOISY_RELATIVE_COORDINATES in batch, ( - f"The relative coordinates should be present in " - f"the batch dictionary with key '{NOISY_RELATIVE_COORDINATES}'" + assert NOISY_AXL_COMPOSITION in batch, ( + f"The noisy coordinates, atomic types and lattice vectors should be present in " + f"the batch dictionary with key '{NOISY_AXL_COMPOSITION}'" ) - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X relative_coordinates_shape = relative_coordinates.shape batch_size = relative_coordinates_shape[0] assert ( @@ -119,6 +124,7 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): batch[NOISE].shape == times.shape ), "the 'noise' parameter should have the same shape as the 'time'." + # TODO replace UNIT_CELL with AXL unit cell assert ( UNIT_CELL in batch ), f"The unit cell should be present in the batch dictionary with key '{UNIT_CELL}'" @@ -132,7 +138,22 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): len(unit_cell_shape) == 3 and unit_cell_shape[1] == self.spatial_dimension and unit_cell_shape[2] == self.spatial_dimension - ), "The unit cell is expected to be in a tensor of shape [batch_size, spatial_dimension, spatial_dimension]." + ), "The unit cell is expected to be in a tensor of shape [batch_size, spatial_dimension, spatial_dimension].}" + + atom_types = batch[NOISY_AXL_COMPOSITION].A + atom_types_shape = atom_types.shape + assert ( + atom_types_shape[0] == batch_size + ), "the batch size dimension is inconsistent between positions and atom types." + assert ( + len(atom_types_shape) == 2 + ), "The atoms type are expected to be in a tensor of shape [batch_size, number of atoms]." + + assert torch.logical_and( + atom_types >= 0, + atom_types + < self.num_atom_types + 1, # MASK is a possible type in a noised sample + ).all(), f"All atom types are expected to be in [0, {self.num_atom_types}]." if self.conditional_prob > 0: assert CARTESIAN_FORCES in batch, ( @@ -150,9 +171,13 @@ def _check_batch(self, batch: Dict[AnyStr, torch.Tensor]): f"{self.spatial_dimension}]" ) + def _impose_non_mask_atomic_type_prediction(self, output: AXL): + # Force the last logit to be -infinity, making it impossible for the model to predict MASK. + output.A[..., self.num_atom_types] = -torch.inf + def forward( self, batch: Dict[AnyStr, torch.Tensor], conditional: Optional[bool] = None - ) -> torch.Tensor: + ) -> AXL: """Model forward. Args: @@ -161,7 +186,7 @@ def forward( randomly with probability conditional_prob Returns: - computed_scores : the scores computed by the model. + computed_scores : the scores computed by the model in an AXL namedtuple. """ self._check_batch(batch) if conditional is None: @@ -171,10 +196,12 @@ def forward( ) < self.conditional_prob ) + if not conditional: - return self._forward_unchecked(batch, conditional=False) + output = self._forward_unchecked(batch, conditional=False) else: - return self._forward_unchecked( + # TODO this is not going to work + output = self._forward_unchecked( batch, conditional=True ) * self.conditional_gamma + self._forward_unchecked( batch, conditional=False @@ -182,9 +209,13 @@ def forward( 1 - self.conditional_gamma ) + self._impose_non_mask_atomic_type_prediction(output) + + return output + def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: + ) -> AXL: """Forward unchecked. This method assumes that the input data has already been checked with respect to expectations diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py index f161236b..02adb3bb 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/models/score_networks/score_network_factory.py @@ -65,6 +65,9 @@ def create_score_network_parameters( Returns: score_network_parameters: the dataclass configuration object describing the score network. """ + assert len(global_parameters_dictionary["elements"]) == score_network_dictionary["num_atom_types"], \ + "There should be 'num_atom_types' entries in the 'elements' list." + assert ( "architecture" in score_network_dictionary ), "The architecture of the score network must be specified." diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/namespace.py b/src/diffusion_for_multi_scale_molecular_dynamics/namespace.py index 2fb6ebeb..fbb50a1e 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/namespace.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/namespace.py @@ -5,6 +5,8 @@ represent these concepts. """ +from collections import namedtuple + # r^alpha <- cartesian position, alpha \in (x,y,z) # x_i <- relative coordinates i \in (1,2,3) # @@ -23,3 +25,12 @@ TIME = "time" # diffusion time NOISE = "noise_parameter" # the exploding variance sigma parameter UNIT_CELL = "unit_cell" # unit cell definition + +ATOM_TYPES = "atom_types" +NOISY_ATOM_TYPES = "noisy_atom_types" + +AXL = namedtuple("AXL", ["A", "X", "L"]) +AXL_NAME_DICT = {"A": ATOM_TYPES, "X": RELATIVE_COORDINATES, "L": UNIT_CELL} + +NOISY_AXL_COMPOSITION = "noisy_axl" +AXL_COMPOSITION = "original_axl" diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/__init__.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/__init__.py similarity index 100% rename from src/diffusion_for_multi_scale_molecular_dynamics/samplers/__init__.py rename to src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/__init__.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/exploding_variance.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/exploding_variance.py similarity index 91% rename from src/diffusion_for_multi_scale_molecular_dynamics/samplers/exploding_variance.py rename to src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/exploding_variance.py index d37a5ae0..3b1ac2fc 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/exploding_variance.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/exploding_variance.py @@ -1,13 +1,13 @@ import torch -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -class ExplodingVariance(torch.nn.Module): +class VarianceScheduler(torch.nn.Module): """Exploding Variance. - This class is responsible for calculating the various quantities related to the diffusion variance. + This class is responsible for calculating the various quantities related to the diffusion variance. This implementation will use "exploding variance" scheme. """ diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py new file mode 100644 index 00000000..ae34bb85 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_parameters.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass + + +@dataclass +class NoiseParameters: + """Noise schedule parameters.""" + + total_time_steps: int + time_delta: float = 1e-5 # the time schedule will cover the range [time_delta, 1] + # As discussed in Appendix C of "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS", + # the time t = 0 is problematic. + + # Default values come from the paper: + # "Torsional Diffusion for Molecular Conformer Generation", + # The original values in the paper are + # sigma_min = 0.01 pi , sigma_σmax = pi + # However, they consider angles from 0 to 2pi as their coordinates: + # here we divide by 2pi because our space is in the range [0, 1). + sigma_min: float = 0.005 + sigma_max: float = 0.5 + + # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution" + corrector_step_epsilon: float = 2e-5 diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/variance_sampler.py b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_scheduler.py similarity index 52% rename from src/diffusion_for_multi_scale_molecular_dynamics/samplers/variance_sampler.py rename to src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_scheduler.py index a5ed687e..fb58935c 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/variance_sampler.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noise_schedulers/noise_scheduler.py @@ -1,46 +1,45 @@ from collections import namedtuple -from dataclasses import dataclass from typing import Tuple import torch -Noise = namedtuple("Noise", ["time", "sigma", "sigma_squared", "g", "g_squared"]) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + VarianceScheduler +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters + +Noise = namedtuple( + "Noise", + [ + "time", + "sigma", + "sigma_squared", + "g", + "g_squared", + "beta", + "alpha_bar", + "q_matrix", + "q_bar_matrix", + "q_bar_tm1_matrix", + "indices", + ], +) LangevinDynamics = namedtuple("LangevinDynamics", ["epsilon", "sqrt_2_epsilon"]) -@dataclass -class NoiseParameters: - """Noise schedule parameters.""" - - total_time_steps: int - time_delta: float = 1e-5 # the time schedule will cover the range [time_delta, 1] - # As discussed in Appendix C of "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS", - # the time t = 0 is problematic. - - # Default values come from the paper: - # "Torsional Diffusion for Molecular Conformer Generation", - # The original values in the paper are - # sigma_min = 0.01 pi , sigma_σmax = pi - # However, they consider angles from 0 to 2pi as their coordinates: - # here we divide by 2pi because our space is in the range [0, 1). - sigma_min: float = 0.005 - sigma_max: float = 0.5 - - # Default value comes from "Generative Modeling by Estimating Gradients of the Data Distribution" - corrector_step_epsilon: float = 2e-5 - - -class ExplodingVarianceSampler(torch.nn.Module): - """Exploding Variance Sampler. +class NoiseScheduler(torch.nn.Module): + r"""Noise Scheduler. This class is responsible for creating all the quantities needed for noise generation for training and sampling. - This implementation will use "exponential diffusion" as discussed in + This implementation will use "exponential diffusion" and a "variance-preserving" diffusion as discussed in the following papers (no one paper presents everything clearly) - [1] "Torsional Diffusion for Molecular Conformer Generation". - [2] "SCORE-BASED GENERATIVE MODELING THROUGH STOCHASTIC DIFFERENTIAL EQUATIONS" - [3] "Generative Modeling by Estimating Gradients of the Data Distribution" + - [4] "Denoising diffusion probabilistic models" + - [5] "Deep unsupervised learning using nonequilibrium thermodynamics" The following quantities are defined: - total number of times steps, N @@ -65,26 +64,49 @@ class ExplodingVarianceSampler(torch.nn.Module): - eps and sqrt_2_eps: This is for Langevin dynamics within a corrector step. Following [3], we define - eps_i = 0.5 epsilon_step * sigma^2_i / sigma^2_1 for i = 0, ..., N-1. + .. math:: + eps_i = 0.5 epsilon_step * sigma^2_i / sigma^2_1 for i = 0, ..., N-1. + + --> Careful! eps_0 is needed for the corrector steps. - --> Careful! eps_0 is needed for the corrector steps. + - beta and alpha_bar: + noise schedule following the "variance-preserving scheme", + + .. math:: + beta(t) = 1 / (t_{max} - t + 1) + + .. math:: + \bar{\alpha}(t) = \prod_{i=t}^t (1 - beta(i)) + + - q_matrix, q_bar_matrix: + transition matrix for D3PM - Q_t - and cumulative transition matrix :math:`\bar{Q}_t` + + .. math:: + Q_t = (1 - beta(t)) I + beta(t) 1 e^T_m + + .. math:: + \bar{Q}_t = \prod_{i=i}^t Q_t """ - def __init__(self, noise_parameters: NoiseParameters): + def __init__(self, noise_parameters: NoiseParameters, num_classes: int): """Init method. Args: noise_parameters: parameters that define the noise schedule. + num_classes: number of discrete classes for the discrete diffusion """ super().__init__() self.noise_parameters = noise_parameters + self.num_classes = num_classes - self._time_array = torch.nn.Parameter( - self._get_time_array(noise_parameters), requires_grad=False - ) + self._exploding_variance = VarianceScheduler(noise_parameters) + + times = self._get_time_array(noise_parameters) + + self._time_array = torch.nn.Parameter(times, requires_grad=False) self._sigma_array = torch.nn.Parameter( - self._create_sigma_array(noise_parameters, self._time_array), + self._exploding_variance.get_sigma(times), requires_grad=False, ) self._sigma_squared_array = torch.nn.Parameter( @@ -92,7 +114,9 @@ def __init__(self, noise_parameters: NoiseParameters): ) self._g_squared_array = torch.nn.Parameter( - self._create_g_squared_array(noise_parameters, self._sigma_squared_array), + self._create_discretized_g_squared_array( + self._sigma_squared_array, noise_parameters.sigma_min + ), requires_grad=False, ) self._g_array = torch.nn.Parameter( @@ -114,6 +138,29 @@ def __init__(self, noise_parameters: NoiseParameters): torch.tensor(0), requires_grad=False ) + self._beta_array = torch.nn.Parameter( + self._create_beta_array(noise_parameters.total_time_steps), + requires_grad=False, + ) + + self._alpha_bar_array = torch.nn.Parameter( + self._create_alpha_bar_array(self._beta_array), requires_grad=False + ) + + self._q_matrix_array = torch.nn.Parameter( + self._create_q_matrix_array(self._beta_array, num_classes), + requires_grad=False, + ) + + self._q_bar_matrix_array = torch.nn.Parameter( + self._create_q_bar_matrix_array(self._q_matrix_array), requires_grad=False + ) + + self._q_bar_tm1_matrix_array = torch.nn.Parameter( + self._create_q_bar_tm1_matrix_array(self._q_bar_matrix_array), + requires_grad=False, + ) + @staticmethod def _get_time_array(noise_parameters: NoiseParameters) -> torch.Tensor: return torch.linspace( @@ -121,21 +168,10 @@ def _get_time_array(noise_parameters: NoiseParameters) -> torch.Tensor: ) @staticmethod - def _create_sigma_array( - noise_parameters: NoiseParameters, time_array: torch.Tensor - ) -> torch.Tensor: - sigma_min = noise_parameters.sigma_min - sigma_max = noise_parameters.sigma_max - - sigma = sigma_min ** (1.0 - time_array) * sigma_max**time_array - return sigma - - @staticmethod - def _create_g_squared_array( - noise_parameters: NoiseParameters, sigma_squared_array: torch.Tensor + def _create_discretized_g_squared_array( + sigma_squared_array: torch.Tensor, sigma_min: float ) -> torch.Tensor: # g^2_{i} = sigma^2_{i} - sigma^2_{i-1}. For the first element (i=1), we set sigma_{0} = sigma_min. - sigma_min = noise_parameters.sigma_min zeroth_value_tensor = torch.tensor([sigma_squared_array[0] - sigma_min**2]) return torch.cat( [zeroth_value_tensor, sigma_squared_array[1:] - sigma_squared_array[:-1]] @@ -160,6 +196,55 @@ def _create_epsilon_array( ] ) + @staticmethod + def _create_beta_array(num_time_steps: int) -> torch.Tensor: + return 1.0 / (num_time_steps - torch.arange(1, num_time_steps + 1) + 1) + + @staticmethod + def _create_alpha_bar_array(beta_array: torch.Tensor) -> torch.Tensor: + return torch.cumprod(1 - beta_array, 0) + + @staticmethod + def _create_q_matrix_array( + beta_array: torch.Tensor, num_classes: torch.Tensor + ) -> torch.Tensor: + beta_array_ = beta_array.unsqueeze(-1).unsqueeze(-1) + qt = (1 - beta_array_) * torch.eye( + num_classes + ) # time step, num_classes, num_classes + qt += beta_array_ * torch.outer( + torch.ones(num_classes), + torch.nn.functional.one_hot( + torch.LongTensor([num_classes - 1]), num_classes=num_classes + ).squeeze(0), + ) + return qt + + @staticmethod + def _create_q_bar_matrix_array(q_matrix_array: torch.Tensor) -> torch.Tensor: + q_bar_matrix_array = torch.empty_like(q_matrix_array) + q_bar_matrix_array[0] = q_matrix_array[0] + for i in range(1, q_matrix_array.size(0)): + q_bar_matrix_array[i] = torch.matmul( + q_bar_matrix_array[i - 1], q_matrix_array[i] + ) + return q_bar_matrix_array + + @staticmethod + def _create_q_bar_tm1_matrix_array( + q_bar_matrix_array: torch.Tensor, + ) -> torch.Tensor: + # we need the q_bar matrices for the previous time index (t-1) to compute the loss. We will use Q_{t-1}=1 + # for the case t=1 (special case in the loss or the last step of the sampling process + q_bar_tm1_matrices = torch.cat( + ( + torch.eye(q_bar_matrix_array.size(-1)).unsqueeze(0), + q_bar_matrix_array[:-1], + ), + dim=0, + ) + return q_bar_tm1_matrices + def _get_random_time_step_indices(self, shape: Tuple[int]) -> torch.Tensor: """Random time step indices. @@ -181,20 +266,22 @@ def _get_random_time_step_indices(self, shape: Tuple[int]) -> torch.Tensor: return random_indices def get_random_noise_sample(self, batch_size: int) -> Noise: - """Get random noise sample. + r"""Get random noise sample. It is assumed that a batch is of the form [batch_size, (dimensions of a configuration)]. - In order to train a diffusion model, a configuration must be "noised" to a time t with a parameter sigma(t). + In order to train a diffusion model, a configuration must be "noised" to a time t with a parameter sigma(t) for + the relative coordinates, beta(t) and associated transition matrices Q(t), \bar{Q}(t), \bar{Q}(t-1) for the atom + types. Different values can be used for different configurations: correspondingly, this method returns one random time per element in the batch. - Args: batch_size : number of configurations in a batch, Returns: - noise_sample: a collection of all the noise parameters (t, sigma, sigma^2, g, g^2) - for some random indices. All the arrays are of dimension [batch_size]. + noise_sample: a collection of all the noise parameters (t, sigma, sigma^2, g, g^2, beta, alpha_bar, + Q, Qbar, Qbar at time t-1 and indices) for some random indices. All the arrays are of dimension + [batch_size] expect Q, Qbar, Qbar t-1 which are [batch_size, num_classes, num_classes]. """ indices = self._get_random_time_step_indices((batch_size,)) times = self._time_array.take(indices) @@ -202,6 +289,13 @@ def get_random_noise_sample(self, batch_size: int) -> Noise: sigmas_squared = self._sigma_squared_array.take(indices) gs = self._g_array.take(indices) gs_squared = self._g_squared_array.take(indices) + betas = self._beta_array.take(indices) + alpha_bars = self._alpha_bar_array.take(indices) + q_matrices = self._q_matrix_array.index_select(dim=0, index=indices) + q_bar_matrices = self._q_bar_matrix_array.index_select(dim=0, index=indices) + q_bar_tm1_matrices = self._q_bar_tm1_matrix_array.index_select( + dim=0, index=indices + ) return Noise( time=times, @@ -209,6 +303,12 @@ def get_random_noise_sample(self, batch_size: int) -> Noise: sigma_squared=sigmas_squared, g=gs, g_squared=gs_squared, + beta=betas, + alpha_bar=alpha_bars, + q_matrix=q_matrices, + q_bar_matrix=q_bar_matrices, + q_bar_tm1_matrix=q_bar_tm1_matrices, + indices=indices, ) def get_all_sampling_parameters(self) -> Tuple[Noise, LangevinDynamics]: @@ -228,6 +328,14 @@ def get_all_sampling_parameters(self) -> Tuple[Noise, LangevinDynamics]: sigma_squared=self._sigma_squared_array, g=self._g_array, g_squared=self._g_squared_array, + beta=self._beta_array, + alpha_bar=self._alpha_bar_array, + q_matrix=self._q_matrix_array, + q_bar_matrix=self._q_bar_matrix_array, + q_bar_tm1_matrix=self._q_bar_tm1_matrix_array, + indices=torch.arange( + self._minimum_random_index, self._maximum_random_index + 1 + ), ) langevin_dynamics = LangevinDynamics( epsilon=self._epsilon_array, sqrt_2_epsilon=self._sqrt_two_epsilon_array diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samples/__init__.py b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/__init__.py similarity index 100% rename from src/diffusion_for_multi_scale_molecular_dynamics/samples/__init__.py rename to src/diffusion_for_multi_scale_molecular_dynamics/noisers/__init__.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/noisers/atom_types_noiser.py b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/atom_types_noiser.py new file mode 100644 index 00000000..368be9d0 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/atom_types_noiser.py @@ -0,0 +1,60 @@ +from typing import Tuple + +import torch + +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + compute_q_at_given_a0 + + +class AtomTypesNoiser: + """Atom types noiser. + + This class provides methods to generate noisy atom types. + """ + + @staticmethod + def _get_uniform_noise(shape: Tuple[int]) -> torch.Tensor: + """Get uniform noise. + + Get a sample from U(0, 1) of dimensions shape. + + Args: + shape : the shape of the sample. + + Returns: + gaussian_noise: a sample from U(0, 1) of dimensions shape. + """ + return torch.rand(shape) + + @staticmethod + def get_noisy_atom_types_sample( + real_onehot_atom_types: torch.Tensor, q_bar: torch.Tensor + ) -> torch.Tensor: + r"""Get noisy atom types sample. + + This method generates a sample using the transition probabilities defined by the q_bar matrices. + + Args: + real_onehot_atom_types : atom types of the real sample. Assumed to be a one-hot vector. The size is assumed + to be (..., num_classes + 1) where num_classes is the number of atoms. + q_bar : cumulative transition matrices i.e. the q_bar in q(a_t | a_0) = a_0 \bar{Q}_t. Assumed to be of size + (..., num_classes + 1, num_classes + 1) + + Returns: + noisy_atom_types: a sample of noised atom types as classes, not 1-hot, of the same shape as + real_onehot_atom_types except for the last dimension that is removed. + """ + assert ( + real_onehot_atom_types.shape == q_bar.shape[:-1] + ), "q_bar array first dimensions should match real_atom_types array" + + u = AtomTypesNoiser._get_uniform_noise(real_onehot_atom_types.shape).to(q_bar) + # we need to sample from q(x_t | x_0) + posterior_at_probabilities = compute_q_at_given_a0( + real_onehot_atom_types, q_bar + ) + # gumbel trick to sample from a distribution + noise = -torch.log(-torch.log(u)).to(real_onehot_atom_types.device) + noisy_atom_types = torch.log(posterior_at_probabilities) + noise + noisy_atom_types = torch.argmax(noisy_atom_types, dim=-1) + return noisy_atom_types diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/noisers/lattice_noiser.py b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/lattice_noiser.py new file mode 100644 index 00000000..93809ac1 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/lattice_noiser.py @@ -0,0 +1,23 @@ +import torch + + +class LatticeNoiser: + """Lattice noiser. + + This class provides methods to generate noisy lattices. + TODO this is a placeholder + """ + + @staticmethod + def get_noisy_lattice_vectors(real_lattice_vectors: torch.Tensor) -> torch.Tensor: + """Get noisy lattice vectors. + + TODO this is a placeholder + + Args: + real_lattice_vectors: lattice vectors from the sampled data + + Returns: + real_lattice_vectors: a sample of noised lattice vectors. Placeholder for now. + """ + return real_lattice_vectors diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/noisy_relative_coordinates_sampler.py b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/relative_coordinates_noiser.py similarity index 90% rename from src/diffusion_for_multi_scale_molecular_dynamics/samplers/noisy_relative_coordinates_sampler.py rename to src/diffusion_for_multi_scale_molecular_dynamics/noisers/relative_coordinates_noiser.py index c57e4fb5..d821b8d5 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/samplers/noisy_relative_coordinates_sampler.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/noisers/relative_coordinates_noiser.py @@ -1,8 +1,3 @@ -"""Noisy Position Sampler. - -This module is responsible for sampling relative positions from the perturbation kernel. -""" - from typing import Tuple import torch @@ -11,8 +6,8 @@ map_relative_coordinates_to_unit_cell -class NoisyRelativeCoordinatesSampler: - """Noisy Relative Coordinates Sampler. +class RelativeCoordinatesNoiser: + """Relative Coordinates Noiser. This class provides methods to generate noisy relative coordinates, given real relative coordinates and a sigma parameter. @@ -62,7 +57,7 @@ def get_noisy_relative_coordinates_sample( real_relative_coordinates.shape == sigmas.shape ), "sigmas array is expected to be of the same shape as the real_relative_coordinates array" - z_scores = NoisyRelativeCoordinatesSampler._get_gaussian_noise( + z_scores = RelativeCoordinatesNoiser._get_gaussian_noise( real_relative_coordinates.shape ).to(sigmas) noise = (sigmas * z_scores).to(real_relative_coordinates) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/__init__.py b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/__init__.py index e69de29b..cee6a5d5 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/__init__.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/__init__.py @@ -0,0 +1,3 @@ +from diffusion_for_multi_scale_molecular_dynamics import DATA_DIR + +SW_COEFFICIENTS_DIR = DATA_DIR / "stillinger_weber_coefficients" diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energies.py b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energies.py deleted file mode 100644 index b49a25c5..00000000 --- a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energies.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -import tempfile -from typing import AnyStr, Dict - -import numpy as np -import torch - -from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_POSITIONS, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ - get_energy_and_forces_from_lammps - -logger = logging.getLogger(__name__) - - -def compute_oracle_energies(samples: Dict[AnyStr, torch.Tensor]) -> torch.Tensor: - """Compute oracle energies. - - Method to call the oracle for samples expressed in a standardized format. - - Args: - samples: a dictionary assumed to contain the fields - - CARTESIAN_POSITIONS - - UNIT_CELL - - Returns: - energies: a numpy array with the computed energies. - """ - assert ( - CARTESIAN_POSITIONS in samples - ), f"the field '{CARTESIAN_POSITIONS}' must be present in the sample dictionary" - - assert ( - UNIT_CELL in samples - ), f"the field '{UNIT_CELL}' must be present in the sample dictionary" - - # Dimension [batch_size, space_dimension, space_dimension] - basis_vectors = samples[UNIT_CELL].detach().cpu().numpy() - - # Dimension [batch_size, number_of_atoms, space_dimension] - cartesian_positions = samples[CARTESIAN_POSITIONS].detach().cpu().numpy() - - number_of_atoms = cartesian_positions.shape[1] - atom_types = np.ones(number_of_atoms, dtype=int) - - logger.info("Compute energy from Oracle") - - list_energy = [] - with tempfile.TemporaryDirectory() as tmp_work_dir: - for positions, box in zip(cartesian_positions, basis_vectors): - energy, forces = get_energy_and_forces_from_lammps( - positions, box, atom_types, tmp_work_dir=tmp_work_dir - ) - list_energy.append(energy) - logger.info("Done computing energies from Oracle") - return torch.tensor(list_energy) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energy_oracle.py b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energy_oracle.py new file mode 100644 index 00000000..6bd4c5bb --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energy_oracle.py @@ -0,0 +1,84 @@ +import logging +from dataclasses import dataclass +from typing import AnyStr, Dict, List + +import numpy as np +import torch + +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL_COMPOSITION, CARTESIAN_POSITIONS, UNIT_CELL) + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True) +class OracleParameters: + """Lammps Oracle Parameters.""" + name: str # what kind of Oracle + elements: List[str] # unique elements + + +class EnergyOracle: + """Energy oracle base class.""" + def __init__( + self, oracle_parameters: OracleParameters, **kwargs + ): + """Init method.""" + self._oracle_parameters = oracle_parameters + self._element_types = ElementTypes(oracle_parameters.elements) + + def _compute_one_configuration_energy( + self, + cartesian_positions: np.ndarray, + basis_vectors: np.ndarray, + atom_types: np.ndarray, + ) -> float: + raise NotImplementedError("This method must be implemented") + + def compute_oracle_energies( + self, samples: Dict[AnyStr, torch.Tensor] + ) -> torch.Tensor: + """Compute oracle energies. + + Method to call the oracle for samples expressed in a standardized format. + + Args: + samples: a dictionary assumed to contain the fields + - CARTESIAN_POSITIONS + - UNIT_CELL + + Returns: + energies: a numpy array with the computed energies. + """ + assert ( + CARTESIAN_POSITIONS in samples + ), f"the field '{CARTESIAN_POSITIONS}' must be present in the sample dictionary" + + assert ( + UNIT_CELL in samples + ), f"the field '{UNIT_CELL}' must be present in the sample dictionary" + + # Dimension [batch_size, space_dimension, space_dimension] + batched_basis_vectors = samples[UNIT_CELL].detach().cpu().numpy() # TODO: use the AXL_COMPOSITION + + # Dimension [batch_size, number_of_atoms, space_dimension] + batched_cartesian_positions = ( + samples[CARTESIAN_POSITIONS].detach().cpu().numpy() + ) + + # Dimension [batch_size, number_of_atoms] + batched_atom_types = samples[AXL_COMPOSITION].A.detach().cpu().numpy() + + logger.info("Compute energy from Oracle") + list_energy = [] + for cartesian_positions, basis_vectors, atom_types in zip( + batched_cartesian_positions, batched_basis_vectors, batched_atom_types + ): + energy = self._compute_one_configuration_energy( + cartesian_positions, basis_vectors, atom_types + ) + list_energy.append(energy) + logger.info("Done computing energies from Oracle") + return torch.tensor(list_energy) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energy_oracle_factory.py b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energy_oracle_factory.py new file mode 100644 index 00000000..ac376ab9 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/energy_oracle_factory.py @@ -0,0 +1,48 @@ +from typing import Any, AnyStr, Dict, List + +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle import ( + EnergyOracle, OracleParameters) +from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps_energy_oracle import ( + LammpsEnergyOracle, LammpsOracleParameters) + +ORACLE_PARAMETERS_BY_NAME = dict(lammps=LammpsOracleParameters) +ENERGY_ORACLE_BY_NAME = dict(lammps=LammpsEnergyOracle) + + +def create_energy_oracle_parameters( + energy_oracle_dictionary: Dict[AnyStr, Any], elements: List[str] +) -> OracleParameters: + """Create energy oracle parameters. + + Args: + energy_oracle_dictionary : parsed configuration for the energy oracle. + elements : list of unique elements. + + Returns: + oracle_parameters: a configuration object for an energy oracle object. + """ + name = energy_oracle_dictionary["name"] + + assert ( + name in ORACLE_PARAMETERS_BY_NAME.keys() + ), f"Energy Oracle {name} is not implemented. Possible choices are {ORACLE_PARAMETERS_BY_NAME.keys()}" + + oracle_parameters = ORACLE_PARAMETERS_BY_NAME[name]( + **energy_oracle_dictionary, elements=elements + ) + return oracle_parameters + + +def create_energy_oracle(oracle_parameters: OracleParameters) -> EnergyOracle: + """Create an energy oracle. + + This is a factory method responsible for instantiating the energy oracle. + """ + name = oracle_parameters.name + assert ( + name in ENERGY_ORACLE_BY_NAME.keys() + ), f"Energy Oracle {name} is not implemented. Possible choices are {ENERGY_ORACLE_BY_NAME.keys()}" + + oracle = ENERGY_ORACLE_BY_NAME[name](oracle_parameters) + + return oracle diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps.py b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps.py index 6cd96fbf..df3c9650 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps.py @@ -1,6 +1,7 @@ """Call LAMMPS to get the forces and energy in a given configuration.""" import os +import warnings from pathlib import Path from typing import Dict, Tuple @@ -10,16 +11,18 @@ import yaml from pymatgen.core import Element -from diffusion_for_multi_scale_molecular_dynamics import DATA_DIR +from diffusion_for_multi_scale_molecular_dynamics.oracle import \ + SW_COEFFICIENTS_DIR +@warnings.deprecated("DO NOT USE THIS METHOD. It will be refactored away and replaced by LammpsEnergyOracle.") def get_energy_and_forces_from_lammps( cartesian_positions: np.ndarray, box: np.ndarray, atom_types: np.ndarray, atom_type_map: Dict[int, str] = {1: "Si"}, tmp_work_dir: str = "./", - pair_coeff_dir: Path = DATA_DIR, + pair_coeff_dir: Path = SW_COEFFICIENTS_DIR, ) -> Tuple[float, pd.DataFrame]: """Call LAMMPS to compute the forces on all atoms in a configuration. diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps_energy_oracle.py b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps_energy_oracle.py new file mode 100644 index 00000000..6da21b3c --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/oracle/lammps_energy_oracle.py @@ -0,0 +1,153 @@ +"""Call LAMMPS to get the forces and energy in a given configuration.""" + +import os +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import List + +import lammps +import numpy as np +import pandas as pd +import yaml +from pymatgen.core import Element + +from diffusion_for_multi_scale_molecular_dynamics.oracle import \ + SW_COEFFICIENTS_DIR +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle import ( + EnergyOracle, OracleParameters) + + +@dataclass(kw_only=True) +class LammpsOracleParameters(OracleParameters): + """Lammps Oracle Parameters.""" + name: str = 'lammps' + sw_coeff_filename: str # Stillinger-Weber potential filename + + +class LammpsEnergyOracle(EnergyOracle): + """Lammps energy oracle. + + This class invokes LAMMPS to get the forces and energy in a given configuration. + """ + def __init__( + self, + lammps_oracle_parameters: LammpsOracleParameters, + sw_coefficients_dir: Path = SW_COEFFICIENTS_DIR, + ): + """Init method. + + Args: + lammps_oracle_parameters : parameters for the LAMMPS Oracle. + sw_coefficients_dir : the directory where the sw cofficient files can be found. + """ + super().__init__(lammps_oracle_parameters) + self.sw_coefficients_file_path = str( + sw_coefficients_dir / lammps_oracle_parameters.sw_coeff_filename + ) + + assert os.path.isfile( + self.sw_coefficients_file_path + ), f"The SW file '{self.sw_coefficients_file_path}' does not exist." + + def _create_lammps_commands( + self, + cartesian_positions: np.ndarray, + box: np.ndarray, + atom_types: np.ndarray, + dump_file_path: Path, + ) -> List[str]: + commands = [] + commands.append("units metal") + commands.append("atom_style atomic") + commands.append( + f"region simbox block 0 {box[0, 0]} 0 {box[1, 1]} 0 {box[2, 2]}" + ) + commands.append(f"create_box {self._element_types.number_of_atom_types} simbox") + commands.append("pair_style sw") + + elements_string = "" + for element_id in self._element_types.element_ids: + group_id = element_id + 1 # don't start the groups at zero + element_name = self._element_types.get_element(element_id) + elements_string += f" {element_name}" + element_mass = Element(element_name).atomic_mass.real + commands.append(f"group {element_name} type {group_id}") + commands.append(f"mass {group_id} {element_mass}") + + commands.append( + f"pair_coeff * * {self.sw_coefficients_file_path}{elements_string}" + ) + + for idx, cartesian_position in enumerate(cartesian_positions): + element_id = atom_types[idx] + group_id = element_id + 1 # don't start the groups at zero + positions_string = " ".join(map(str, cartesian_position)) + commands.append(f"create_atoms {group_id} single {positions_string}") + + commands.append( + "fix 1 all nvt temp 300 300 0.01" + ) # selections here do not matter because we only do 1 step + commands.append(f"dump 1 all yaml 1 {dump_file_path} id element x y z fx fy fz") + commands.append(f"dump_modify 1 element {elements_string}") + commands.append( + "run 0" + ) # 0 is the last step index - so run 0 means no MD update - just get the initial forces + return commands + + def _compute_energy_and_forces( + self, cartesian_positions: np.ndarray, box: np.ndarray, atom_types: np.ndarray, dump_file_path: Path + ): + """Call LAMMPS to compute the energy and forces on all atoms in a configuration. + + Args: + cartesian_positions: atomic positions in Euclidean space as a n_atom x spatial dimension array + box: spatial dimension x spatial dimension array representing the periodic box. Assumed to be orthogonal. + atom_types: n_atom array with an index representing the type of each atom + dump_file_path: a temporary file where lammps will dump results. + + Returns: + energy: energy of configuration + forces: forces on each atom in the configuration + """ + assert np.allclose(box, np.diag(np.diag(box))), "only orthogonal LAMMPS box are valid" + + # create a lammps run, turning off logging + lmp = lammps.lammps(cmdargs=["-log", "none", "-echo", "none", "-screen", "none"]) + + commands = self._create_lammps_commands(cartesian_positions, box, atom_types, dump_file_path) + for command in commands: + lmp.command(command) + + # read information from lammps output + with open(dump_file_path, "r") as f: + dump_yaml = yaml.safe_load_all(f) + doc = next(iter(dump_yaml)) + + forces = pd.DataFrame(doc["data"], columns=doc["keywords"]).sort_values( + "id" + ) # organize in a dataframe + + # get the energy + ke = lmp.get_thermo( + "ke" + ) # kinetic energy - should be 0 as atoms are created with 0 velocity + pe = lmp.get_thermo("pe") # potential energy + energy = ke + pe + + return energy, forces + + def _compute_one_configuration_energy(self, cartesian_positions: np.ndarray, + basis_vectors: np.ndarray, + atom_types: np.ndarray) -> float: + + with tempfile.TemporaryDirectory() as tmp_work_dir: + dump_file_path = Path(tmp_work_dir) / "dump.yaml" + energy, _ = self._compute_energy_and_forces(cartesian_positions, + basis_vectors, + atom_types, + dump_file_path) + # clean up! + dump_file_path.unlink() + + return energy diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py index d6fe7fea..26260536 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sample_diffusion.py @@ -12,26 +12,30 @@ import torch +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ + SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.instantiate_generator import \ instantiate_generator from diffusion_for_multi_scale_molecular_dynamics.generators.load_sampling_parameters import \ load_sampling_parameters -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import \ - SamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.main_utils import \ - load_and_backup_hyperparameters -from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import \ - PositionDiffusionLightningModel +from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import \ + AXLDiffusionLightningModel from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import \ ScoreNetwork -from diffusion_for_multi_scale_molecular_dynamics.oracle.energies import \ - compute_oracle_energies -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters -from diffusion_for_multi_scale_molecular_dynamics.samples.sampling import \ +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle import \ + OracleParameters +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle_factory import ( + create_energy_oracle, create_energy_oracle_parameters) +from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling import \ create_batch_of_samples from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import ( get_git_hash, setup_console_logger) +from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import \ + load_and_backup_hyperparameters logger = logging.getLogger(__name__) @@ -87,9 +91,20 @@ def main(args: Optional[Any] = None): hyper_params ) + if "elements" in hyper_params: + ElementTypes.validate_elements(hyper_params["elements"]) + + oracle_parameters = None + if "oracle" in hyper_params: + assert "elements" in hyper_params, \ + "elements are needed to define the energy oracle." + elements = hyper_params["elements"] + oracle_parameters = create_energy_oracle_parameters(hyper_params["oracle"], elements) + create_samples_and_write_to_disk( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, + oracle_parameters=oracle_parameters, device=device, checkpoint_path=args.checkpoint, output_path=args.output, @@ -119,28 +134,27 @@ def extract_and_validate_parameters(hyper_params: Dict[AnyStr, Any]): return noise_parameters, sampling_parameters -def get_sigma_normalized_score_network( - checkpoint_path: Union[str, Path] -) -> ScoreNetwork: - """Get sigma-normalized score network. +def get_axl_network(checkpoint_path: Union[str, Path]) -> ScoreNetwork: + """Get AXL network. Args: checkpoint_path : path where the checkpoint is written. Returns: - sigma_normalized score network: read from the checkpoint. + axl network network: read from the checkpoint. """ logger.info("Loading checkpoint...") - pl_model = PositionDiffusionLightningModel.load_from_checkpoint(checkpoint_path) + pl_model = AXLDiffusionLightningModel.load_from_checkpoint(checkpoint_path) pl_model.eval() - sigma_normalized_score_network = pl_model.sigma_normalized_score_network - return sigma_normalized_score_network + axl_network = pl_model.axl_network + return axl_network def create_samples_and_write_to_disk( noise_parameters: NoiseParameters, sampling_parameters: SamplingParameters, + oracle_parameters: Union[OracleParameters, None], device: torch.device, checkpoint_path: Union[str, Path], output_path: Union[str, Path], @@ -159,13 +173,13 @@ def create_samples_and_write_to_disk( Returns: None """ - sigma_normalized_score_network = get_sigma_normalized_score_network(checkpoint_path) + axl_network = get_axl_network(checkpoint_path) logger.info("Instantiate generator...") position_generator = instantiate_generator( sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) logger.info("Generating samples...") @@ -182,12 +196,14 @@ def create_samples_and_write_to_disk( with open(output_directory / "samples.pt", "wb") as fd: torch.save(samples_batch, fd) - logger.info("Compute energy from Oracle...") - sample_energies = compute_oracle_energies(samples_batch) + if oracle_parameters: + logger.info("Compute energy from Oracle...") + oracle = create_energy_oracle(oracle_parameters) + sample_energies = oracle.compute_oracle_energies(samples_batch) - logger.info("Writing energies to disk...") - with open(output_directory / "energies.pt", "wb") as fd: - torch.save(sample_energies, fd) + logger.info("Writing energies to disk...") + with open(output_directory / "energies.pt", "wb") as fd: + torch.save(sample_energies, fd) if sampling_parameters.record_samples: logger.info("Writing sampling trajectories to disk...") diff --git a/tests/samplers/__init__.py b/src/diffusion_for_multi_scale_molecular_dynamics/sampling/__init__.py similarity index 100% rename from tests/samplers/__init__.py rename to src/diffusion_for_multi_scale_molecular_dynamics/sampling/__init__.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samples/sampling.py b/src/diffusion_for_multi_scale_molecular_dynamics/sampling/diffusion_sampling.py similarity index 68% rename from src/diffusion_for_multi_scale_molecular_dynamics/samples/sampling.py rename to src/diffusion_for_multi_scale_molecular_dynamics/sampling/diffusion_sampling.py index 34d2b3db..1e196157 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/samples/sampling.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sampling/diffusion_sampling.py @@ -2,10 +2,10 @@ import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import ( - PositionGenerator, SamplingParameters) +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import ( + AXLGenerator, SamplingParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_POSITIONS, RELATIVE_COORDINATES, UNIT_CELL) + AXL, AXL_COMPOSITION, CARTESIAN_POSITIONS, UNIT_CELL) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ get_positions_from_coordinates from diffusion_for_multi_scale_molecular_dynamics.utils.structure_utils import \ @@ -15,7 +15,7 @@ def create_batch_of_samples( - generator: PositionGenerator, + generator: AXLGenerator, sampling_parameters: SamplingParameters, device: torch.device, ): @@ -24,7 +24,7 @@ def create_batch_of_samples( Utility function to drive the generation of samples. Args: - generator : position generator. + generator : AXL generator. sampling_parameters : parameters defining how to sample. device: device where the generator is located. @@ -44,24 +44,36 @@ def create_batch_of_samples( sample_batch_size = sampling_parameters.sample_batchsize list_sampled_relative_coordinates = [] + list_sampled_atom_types = [] + list_sampled_lattice_vectors = [] for sampling_batch_indices in torch.split( torch.arange(number_of_samples), sample_batch_size ): basis_vectors_ = basis_vectors[sampling_batch_indices] - sampled_relative_coordinates = generator.sample( + sampled_axl = generator.sample( len(sampling_batch_indices), unit_cell=basis_vectors_, device=device ) - list_sampled_relative_coordinates.append(sampled_relative_coordinates) + list_sampled_atom_types.append(sampled_axl.A) + list_sampled_relative_coordinates.append(sampled_axl.X) + list_sampled_lattice_vectors.append(sampled_axl.L) + atom_types = torch.concat(list_sampled_atom_types) relative_coordinates = torch.concat(list_sampled_relative_coordinates) + lattice_vectors = torch.concat(list_sampled_lattice_vectors) + axl_composition = AXL( + A=atom_types, + X=relative_coordinates, + L=lattice_vectors, + ) + cartesian_positions = get_positions_from_coordinates( relative_coordinates, basis_vectors ) batch = { CARTESIAN_POSITIONS: cartesian_positions, - RELATIVE_COORDINATES: relative_coordinates, - UNIT_CELL: basis_vectors, + AXL_COMPOSITION: axl_composition, + UNIT_CELL: basis_vectors, # TODO remove } return batch diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/samples/diffusion_sampling_parameters.py b/src/diffusion_for_multi_scale_molecular_dynamics/sampling/diffusion_sampling_parameters.py similarity index 92% rename from src/diffusion_for_multi_scale_molecular_dynamics/samples/diffusion_sampling_parameters.py rename to src/diffusion_for_multi_scale_molecular_dynamics/sampling/diffusion_sampling_parameters.py index 5a4f5fdb..541d4b27 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/samples/diffusion_sampling_parameters.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/sampling/diffusion_sampling_parameters.py @@ -1,13 +1,13 @@ from dataclasses import dataclass from typing import Any, AnyStr, Dict, Union +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import \ + SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.generators.load_sampling_parameters import \ load_sampling_parameters -from diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import \ - SamplingParameters from diffusion_for_multi_scale_molecular_dynamics.metrics.sampling_metrics_parameters import \ SamplingMetricsParameters -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py index fb5c34b3..e95987c6 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/train_diffusion.py @@ -14,17 +14,19 @@ create_all_callbacks from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes from diffusion_for_multi_scale_molecular_dynamics.loggers.logger_loader import \ create_all_loggers -from diffusion_for_multi_scale_molecular_dynamics.main_utils import ( - MetricResult, get_crash_metric_result, get_optimized_metric_name_and_mode, - load_and_backup_hyperparameters, report_to_orion_if_on) from diffusion_for_multi_scale_molecular_dynamics.models.instantiate_diffusion_model import \ load_diffusion_model from diffusion_for_multi_scale_molecular_dynamics.utils.hp_utils import \ check_and_log_hp from diffusion_for_multi_scale_molecular_dynamics.utils.logging_utils import ( log_exp_details, setup_console_logger) +from diffusion_for_multi_scale_molecular_dynamics.utils.main_utils import ( + MetricResult, get_crash_metric_result, get_optimized_metric_name_and_mode, + load_and_backup_hyperparameters, report_to_orion_if_on) logger = logging.getLogger(__name__) @@ -119,7 +121,9 @@ def run(args, output_dir, hyper_params): if hyper_params["seed"] is not None: pytorch_lightning.seed_everything(hyper_params["seed"]) - data_params = LammpsLoaderParameters(**hyper_params["data"]) + ElementTypes.validate_elements(hyper_params["elements"]) + + data_params = LammpsLoaderParameters(**hyper_params["data"], elements=hyper_params["elements"]) datamodule = LammpsForDiffusionDataModule( lammps_run_dir=args.data, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/basis_transformations.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/basis_transformations.py index 3eaeabf3..9507549a 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/basis_transformations.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/basis_transformations.py @@ -1,5 +1,7 @@ import torch +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL + def get_reciprocal_basis_vectors(basis_vectors: torch.Tensor) -> torch.Tensor: """Get reciprocal basis vectors. @@ -112,3 +114,22 @@ def map_relative_coordinates_to_unit_cell( normalized_relative_coordinates = torch.remainder(relative_coordinates, 1.0) normalized_relative_coordinates[normalized_relative_coordinates == 1.0] = 0.0 return normalized_relative_coordinates + + +def map_axl_composition_to_unit_cell(composition: AXL, device: torch.device) -> AXL: + """Map relative coordinates in an AXL namedtuple back to unit cell and update the namedtuple. + + Args: + composition: AXL namedtuple with atom types, relative coordinates and lattice as tensors of arbitrary shapes. + device: device where to map the updated relative coordinates tensor + + Returns: + normalized_composition: AXL namedtuple with relative coordinates in the unit cell i.e. in the range [0, 1). + """ + normalized_relative_coordinates = map_relative_coordinates_to_unit_cell( + composition.X + ).to(device) + normalized_composition = AXL( + A=composition.A, X=normalized_relative_coordinates, L=composition.L + ) + return normalized_composition diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py new file mode 100644 index 00000000..8e36755b --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/d3pm_utils.py @@ -0,0 +1,149 @@ +"""Common operations used for Discrete Diffusion.""" + +import einops +import torch + + +def class_index_to_onehot(index: torch.Tensor, num_classes: int) -> torch.Tensor: + """Convert a tensor of class indices to a one-hot representation. + + Args: + index: index tensor to encode + num_classes: total number of classes + + Returns: + float tensor of 0s and 1s. The size is x.size() + (num_classes) + """ + # the last .to() acts on the tensor type to avoid longs + return torch.nn.functional.one_hot(index.long(), num_classes=num_classes).to( + device=index.device, dtype=torch.float + ) + + +def compute_q_at_given_a0( + one_hot_a0: torch.Tensor, q_bar_t: torch.Tensor +) -> torch.Tensor: + r"""Compute :math:`q(a_t | a_0)`. + + This is done by the vector-matrix product: :math:`a_0 \bar{Q}_t` assuming a_0 is a one-hot vector or a distribution + over different classes. + + Args: + one_hot_x0: initial state (:math:`a_0`). The last dimension should be the number of classes. + q_bar_t: cumulative Markov transition matrix (:math:`\bar{Q}_t`). The last 2 dimensions should be the number of + classes. + + Returns: + matrix-vector product between one_hot_x0 and q_bar_t that defines :math:`q(a_t | a_0)` + """ + return einops.einsum(one_hot_a0.to(q_bar_t), q_bar_t, "... j, ... j i -> ... i") + + +def compute_q_at_given_atm1( + one_hot_atm1: torch.Tensor, q_tm1: torch.Tensor +) -> torch.Tensor: + r"""Compute :math:`q(a_t | a_{t-1})`. + + This is done by the vector-matrix product: :math:`a_{t-1} Q_{t-1}^T` assuming :math:`a_{t-1}` is a one-hot vector or + a distribution over different classes. The transition matrix Q is a 1-step transition matrix. + + Args: + one_hot_atm1: state (:math:`a_{t-1}`). The last dimension should be the number of classes. + q_tm1: Markov transition matrix (:math:`Q_{t-1}`). The last 2 dimensions should be the number of classes. + + Returns: + matrix-vector product between one_hot_atm1 and :math:`Q_{t-1}^T` that defines :math:`q(a_t | a_{t-1})` + """ + return einops.einsum( + one_hot_atm1.to(q_tm1), + torch.transpose(q_tm1, -2, -1), + "... j, ... i j -> ... i", + ) + + +def get_probability_at_previous_time_step( + probability_at_zeroth_timestep: torch.Tensor, + one_hot_probability_at_current_timestep: torch.Tensor, + q_matrices: torch.Tensor, + q_bar_matrices: torch.Tensor, + q_bar_tm1_matrices: torch.Tensor, + small_epsilon: float, + probability_at_zeroth_timestep_are_logits: bool = False, +) -> torch.Tensor: + r"""Compute :math:`P(a_{t-1} | a_t, \gamma_0)`. + + For given probability distribution :math:`\gamma_0` and a one-hot distribution :math:`a_t`. + + .. math:: + P(a_{t-1} | a_t, \gamma_0) = (\gamma_0^T \cdot \bar{Q}_{t-1} \cdot a_{t-1}) (a_{t-1}^T \cdot Q_t \cdot a_t) / + (\gamma_0^T \cdot \bar{Q}_{t} \cdot a_t) + + Args: + probability_at_zeroth_timestep: :math:`\gamma_0` a probability representation of a class type (one-hot + distribution or normalized distribution), as a tensor with dimension + [batch_size, number_of_atoms, num_classes] + one_hot_probability_at_current_timestep: :math:`a_t` a one-hot representation of a class type at current time + step, as a tensor with dimension [batch_size, number_of_atoms, num_classes] + q_matrices: :math:`{Q}_{t}` transition matrices at current time step of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + q_bar_matrices: :math:`\bar{Q}_{t}` one-shot transition matrices at current time step of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + q_bar_tm1_matrices: :math:`\bar{Q}_{t-1}` one-shot transition matrices at previous time step of dimension + [batch_size, number_of_atoms, num_classes, num_classes]. + small_epsilon: minimum value for the denominator, to avoid division by zero. + probability_at_zeroth_timestep_are_logits: if True, assume the probability_at_zeroth_timestep do not sum to 1 + and use a softmax on the last dimension to normalize. If False, assume the probabilities are normalized. + Defaults to False. + + Returns: + one-step transition normalized probabilities of dimension [batch_size, number_of_atoms, num_type_atoms] + """ + if probability_at_zeroth_timestep_are_logits: + probability_at_zeroth_timestep = get_probability_from_logits(probability_at_zeroth_timestep, + lowest_probability_value=small_epsilon) + + numerator1 = einops.einsum( + probability_at_zeroth_timestep, q_bar_tm1_matrices, "... j, ... j i -> ... i" + ) + numerator2 = einops.einsum( + q_matrices, one_hot_probability_at_current_timestep, "... i j, ... j -> ... i" + ) + numerator = numerator1 * numerator2 + + den1 = einops.einsum( + q_bar_matrices, + one_hot_probability_at_current_timestep, + "... i j, ... j -> ... i", + ) + den2 = einops.einsum(probability_at_zeroth_timestep, den1, "... j, ... j -> ...") + + denominator = einops.repeat( + den2, "... -> ... num_classes", num_classes=numerator.shape[-1] + ) + + return numerator / denominator + + +def get_probability_from_logits(logits: torch.Tensor, lowest_probability_value: float) -> torch.Tensor: + """Get probability from logits. + + Compute the probabilities from the logit, imposing that no class probablility can be lower than + lowest_probability_value. + + Args: + logits: Unormalized values that can be turned into probabilities. Dimensions [..., num_classes] + lowest_probability_value: imposed lowest probability value for any class. + + Returns: + probabilities: derived from the logits, with minimal clipped values. Dimensions [..., num_classes]. + + """ + raw_probabilities = torch.nn.functional.softmax(logits, dim=-1) + probability_sum = raw_probabilities.sum(dim=-1) + torch.testing.assert_close(probability_sum, torch.ones_like(probability_sum), + msg="Logits are pathological: the probabilities do not sum to one.") + + clipped_probabilities = raw_probabilities.clip(min=lowest_probability_value) + + probabilities = clipped_probabilities / clipped_probabilities.sum(dim=-1).unsqueeze(-1) + return probabilities diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/geometric_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/geometric_utils.py new file mode 100644 index 00000000..08297e29 --- /dev/null +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/geometric_utils.py @@ -0,0 +1,20 @@ +import itertools + +import torch + + +def get_cubic_point_group_symmetries(): + """Get cubic point group symmetries.""" + permutations = [ + torch.diag(torch.ones(3))[[idx]] for idx in itertools.permutations([0, 1, 2]) + ] + sign_changes = [ + torch.diag(torch.tensor(diag)) + for diag in itertools.product([-1.0, 1.0], repeat=3) + ] + symmetries = [] + for permutation in permutations: + for sign_change in sign_changes: + symmetries.append(permutation @ sign_change) + + return torch.stack(symmetries) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/main_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/main_utils.py similarity index 100% rename from src/diffusion_for_multi_scale_molecular_dynamics/main_utils.py rename to src/diffusion_for_multi_scale_molecular_dynamics/utils/main_utils.py diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py index df72a18e..b58722c6 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/ovito_utils.py @@ -10,50 +10,64 @@ import numpy as np import ovito -import torch from ovito.io import import_file from ovito.modifiers import (AffineTransformationModifier, CombineDatasetsModifier, CreateBondsModifier) from pymatgen.core import Lattice, Structure from tqdm import tqdm +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL + +UNKNOWN_ATOM_TYPE = "X" + _cif_directory_template = "cif_files_trajectory_{trajectory_index}" _cif_file_name_template = "diffusion_positions_step_{time_index}.cif" def create_cif_files( + elements: list[str], visualization_artifacts_path: Path, trajectory_index: int, - ode_trajectory_pickle: Path, + trajectory_axl_compositions: AXL, ): """Create cif files. Args: + elements: list of unique elements present in the samples visualization_artifacts_path : where the various visualization artifacts should be written to disk. trajectory_index : the index of the trajectory to be loaded. - ode_trajectory_pickle : Path to the data pickle written by ODESampleTrajectory. + trajectory_axl_compositions: AXL that contains the trajectories, where each field + has dimension [samples, time, ...] Returns: None """ - data = torch.load(ode_trajectory_pickle, map_location=torch.device("cpu")) + element_types = ElementTypes(elements) + atom_type_map = dict() + for element in elements: + id = element_types.get_element_id(element) + atom_type_map[id] = element + + mask_id = np.max(element_types.element_ids) + 1 + atom_type_map[mask_id] = UNKNOWN_ATOM_TYPE cif_directory = visualization_artifacts_path / _cif_directory_template.format( trajectory_index=trajectory_index ) cif_directory.mkdir(exist_ok=True, parents=True) - basis_vectors = data["unit_cell"][trajectory_index].numpy() - lattice = Lattice(matrix=basis_vectors, pbc=(True, True, True)) - trajectory_relative_coordinates = data["relative_coordinates"][ - trajectory_index - ].numpy() + trajectory_atom_types = trajectory_axl_compositions.A[trajectory_index].numpy() + trajectory_relative_coordinates = trajectory_axl_compositions.X[trajectory_index].numpy() + trajectory_lattices = trajectory_axl_compositions.L[trajectory_index].numpy() - for time_idx, relative_coordinates in tqdm( - enumerate(trajectory_relative_coordinates), "Write CIFs" + for time_idx, (atom_types, relative_coordinates, basis_vectors) in tqdm( + enumerate(zip(trajectory_atom_types, trajectory_relative_coordinates, trajectory_lattices)), "Write CIFs" ): - number_of_atoms = relative_coordinates.shape[0] - species = number_of_atoms * ["Si"] + + lattice = Lattice(matrix=basis_vectors, pbc=(True, True, True)) + species = list(map(atom_type_map.get, atom_types)) structure = Structure( lattice=lattice, diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py index 82407554..089b9ef2 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/sample_trajectory.py @@ -1,238 +1,44 @@ from collections import defaultdict -from typing import Any, AnyStr, Dict +from typing import Any, Dict, NamedTuple, Union -import einops import torch class SampleTrajectory: """Sample Trajectory. - This class aims to record all details of the diffusion sampling process. The goal is to produce + This class aims to record the diffusion sampling process. The goal is to produce an artifact that can then be analyzed off-line. """ def __init__(self): """Init method.""" - self.data = defaultdict(list) + self._internal_data = defaultdict(list) def reset(self): """Reset data structure.""" - self.data = defaultdict(list) + self._internal_data = defaultdict(list) - def record_unit_cell(self, unit_cell: torch.Tensor): - """Record unit cell.""" - self.data["unit_cell"] = unit_cell.detach().cpu() + def record(self, key: str, entry: Union[Dict[str, Any], NamedTuple]): + """Record. - def standardize_data(self, data: Dict[AnyStr, Any]) -> Dict[AnyStr, Any]: - """Method to transform the recorded data to a standard form.""" - raise NotImplementedError("Must be implemented in child class.") + Record data from a trajectory. - def write_to_pickle(self, path_to_pickle: str): - """Write standardized data to pickle file.""" - standard_data = self.standardize_data(self.data) - with open(path_to_pickle, "wb") as fd: - torch.save(standard_data, fd) - - -class ODESampleTrajectory(SampleTrajectory): - """ODE Sample Trajectory. - - This class aims to record all details of the ODE diffusion sampling process. The goal is to produce - an artifact that can then be analyzed off-line. - """ - - def record_ode_solution( - self, - times: torch.Tensor, - sigmas: torch.Tensor, - relative_coordinates: torch.Tensor, - normalized_scores: torch.Tensor, - stats: Dict, - status: torch.Tensor, - ): - """Record ODE solution information.""" - self.data["time"].append(times) - self.data["sigma"].append(sigmas) - self.data["relative_coordinates"].append(relative_coordinates) - self.data["normalized_scores"].append(normalized_scores) - self.data["stats"].append(stats) - self.data["status"].append(status) - - def standardize_data(self, data: Dict[AnyStr, Any]) -> Dict[AnyStr, Any]: - """Method to transform the recorded data to a standard form.""" - extra_fields = ["stats", "status"] - standardized_data = dict( - unit_cell=data["unit_cell"], - time=data["time"][0], - sigma=data["sigma"][0], - relative_coordinates=data["relative_coordinates"][0], - normalized_scores=data["normalized_scores"][0], - extra={key: data[key][0] for key in extra_fields}, - ) - return standardized_data - - -class SDESampleTrajectory(SampleTrajectory): - """SDE Sample Trajectory. - - This class aims to record all details of the SDE diffusion sampling process. The goal is to produce - an artifact that can then be analyzed off-line. - """ - - def record_sde_solution( - self, - times: torch.Tensor, - sigmas: torch.Tensor, - relative_coordinates: torch.Tensor, - normalized_scores: torch.Tensor, - ): - """Record ODE solution information.""" - self.data["time"].append(times) - self.data["sigma"].append(sigmas) - self.data["relative_coordinates"].append(relative_coordinates) - self.data["normalized_scores"].append(normalized_scores) - - def standardize_data(self, data: Dict[AnyStr, Any]) -> Dict[AnyStr, Any]: - """Method to transform the recorded data to a standard form.""" - standardized_data = dict( - unit_cell=data["unit_cell"], - time=data["time"][0], - sigma=data["sigma"][0], - relative_coordinates=data["relative_coordinates"][0], - normalized_scores=data["normalized_scores"][0], - ) - return standardized_data - - -class NoOpODESampleTrajectory(ODESampleTrajectory): - """A sample trajectory object that performs no operation.""" + Args: + key: name of internal list to which the entry will be added. + entry: dictionary-like data to be recorded. - def record_unit_cell(self, unit_cell: torch.Tensor): - """No Op.""" - return - - def record_ode_solution( - self, - times: torch.Tensor, - sigmas: torch.Tensor, - relative_coordinates: torch.Tensor, - normalized_scores: torch.Tensor, - stats: Dict, - status: torch.Tensor, - ): - """No Op.""" - return + Returns: + None. + """ + self._internal_data[key].append(entry) def write_to_pickle(self, path_to_pickle: str): - """No Op.""" - return - - -class PredictorCorrectorSampleTrajectory(SampleTrajectory): - """Predictor Corrector Sample Trajectory. - - This class aims to record all details of the predictor-corrector diffusion sampling process. The goal is to produce - an artifact that can then be analyzed off-line. - """ - - def record_predictor_step( - self, - i_index: int, - time: float, - sigma: float, - x_i: torch.Tensor, - x_im1: torch.Tensor, - scores: torch.Tensor, - ): - """Record predictor step.""" - self.data["predictor_i_index"].append(i_index) - self.data["predictor_time"].append(time) - self.data["predictor_sigma"].append(sigma) - self.data["predictor_x_i"].append(x_i.detach().cpu()) - self.data["predictor_x_im1"].append(x_im1.detach().cpu()) - self.data["predictor_scores"].append(scores.detach().cpu()) - - def record_corrector_step( - self, - i_index: int, - time: float, - sigma: float, - x_i: torch.Tensor, - corrected_x_i: torch.Tensor, - scores: torch.Tensor, - ): - """Record corrector step.""" - self.data["corrector_i_index"].append(i_index) - self.data["corrector_time"].append(time) - self.data["corrector_sigma"].append(sigma) - self.data["corrector_x_i"].append(x_i.detach().cpu()) - self.data["corrector_corrected_x_i"].append(corrected_x_i.detach().cpu()) - self.data["corrector_scores"].append(scores.detach().cpu()) - - def standardize_data(self, data: Dict[AnyStr, Any]) -> Dict[AnyStr, Any]: - """Method to transform the recorded data to a standard form.""" - predictor_relative_coordinates = einops.rearrange( - torch.stack(data["predictor_x_i"]), "t b n d -> b t n d" - ) - predictor_normalized_scores = einops.rearrange( - torch.stack(data["predictor_scores"]), "t b n d -> b t n d" - ) + """Write data to pickle file.""" + self._internal_data = dict(self._internal_data) + for key, value in self._internal_data.items(): + if len(value) == 1: + self._internal_data[key] = value[0] - extra_fields = [ - "predictor_i_index", - "predictor_x_i", - "predictor_x_im1", - "corrector_i_index", - "corrector_time", - "corrector_sigma", - "corrector_x_i", - "corrector_corrected_x_i", - "corrector_scores", - ] - - standardized_data = dict( - unit_cell=data["unit_cell"], - time=torch.tensor(data["predictor_time"]), - sigma=torch.tensor(data["predictor_sigma"]), - relative_coordinates=predictor_relative_coordinates, - normalized_scores=predictor_normalized_scores, - extra={key: data[key] for key in extra_fields}, - ) - return standardized_data - - -class NoOpPredictorCorrectorSampleTrajectory(PredictorCorrectorSampleTrajectory): - """A sample trajectory object that performs no operation.""" - - def record_unit_cell(self, unit_cell: torch.Tensor): - """No Op.""" - return - - def record_predictor_step( - self, - i_index: int, - time: float, - sigma: float, - x_i: torch.Tensor, - x_im1: torch.Tensor, - scores: torch.Tensor, - ): - """No Op.""" - return - - def record_corrector_step( - self, - i_index: int, - time: float, - sigma: float, - x_i: torch.Tensor, - corrected_x_i: torch.Tensor, - scores: torch.Tensor, - ): - """No Op.""" - return - - def write_to_pickle(self, path_to_pickle: str): - """No Op.""" - return + with open(path_to_pickle, "wb") as fd: + torch.save(self._internal_data, fd) diff --git a/src/diffusion_for_multi_scale_molecular_dynamics/utils/tensor_utils.py b/src/diffusion_for_multi_scale_molecular_dynamics/utils/tensor_utils.py index 92b89e67..bb10f2c2 100644 --- a/src/diffusion_for_multi_scale_molecular_dynamics/utils/tensor_utils.py +++ b/src/diffusion_for_multi_scale_molecular_dynamics/utils/tensor_utils.py @@ -16,7 +16,7 @@ def broadcast_batch_tensor_to_all_dimensions( This is useful when we want to multiply every value in the data example by the same number. Args: - batch_values : values to be braodcasted, of shape [batch_size] + batch_values : values to be broadcasted, of shape [batch_size] final_shape : shape of the final tensor, [batch_size, n1, n2, ...] Returns: @@ -38,3 +38,46 @@ def broadcast_batch_tensor_to_all_dimensions( reshape_dimension = [-1] + (number_of_dimensions - 1) * [1] broadcast_values = batch_values.reshape(reshape_dimension).expand(final_shape) return broadcast_values + + +def broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values: torch.Tensor, final_shape: Tuple[int, ...] +) -> torch.Tensor: + """Broadcast batch tensor to all dimensions. + + A data matrix batch is typically a tensor of shape [batch_size, n1, n2, ..., m1, m2] where n1, n2, etc constitute + one example of the data and m1 and m2 are the matrix dimensions. This method broadcasts a tensor of shape + [batch_size, m1, m2] to a tensor of shape + [batch_size, n1, n2, ..., m1, m2] where all the values for the non-batch and matrix dimensions are equal to the + value for the given batch index and matrix element. + + This is useful when we want to multiply every value in the data example by the same matrix. + + Args: + batch_values : values to be broadcasted, of shape [batch_size, m1, m2] + final_shape : shape of the final tensor, excluding the matrix dimensions [batch_size, n1, n2, ...,] + + Returns: + broadcast_values : tensor of shape [batch_size, n1, n2, ..., m1, m2], where all entries are identical + along non-batch and non-matrix dimensions. + """ + assert ( + len(batch_values.shape) == 3 + ), "The batch values should be a three-dimensional tensor." + batch_size = batch_values.shape[0] + matrix_size = batch_values.shape[-2:] + + assert ( + final_shape[0] == batch_size + ), "The final shape should have the batch_size as its first dimension." + + # reshape the batch_values array to have the same dimension as final_shape, with all values identical + # for a given batch index. + number_of_dimensions = len(final_shape) + reshape_dimension = ( + torch.Size([batch_size] + (number_of_dimensions - 1) * [1]) + matrix_size + ) + broadcast_values = batch_values.reshape(reshape_dimension).expand( + torch.Size(final_shape) + matrix_size + ) + return broadcast_values diff --git a/tests/conftest.py b/tests/conftest.py index 5fb60e8b..115ae5d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from tests.fake_data_utils import (create_dump_yaml_documents, create_thermo_yaml_documents, + generate_random_string, get_configuration_runs, write_to_yaml) @@ -36,9 +37,17 @@ def pytest_collection_modifyitems(config, items): _available_devices = [torch.device("cpu")] + if torch.cuda.is_available(): _available_devices.append(torch.device("cuda")) +if torch.backends.mps.is_available(): + # MPS is an Apple-specific device. Its connections to pytorch are still incomplete at this time. + # The environment variable + # PYTORCH_ENABLE_MPS_FALLBACK=1 + # should be set to use this device so that a cpu fallback can be used for missing operations. + _available_devices.append(torch.device("mps")) + @pytest.fixture(params=_available_devices) def device(request): @@ -51,6 +60,8 @@ def accelerator(device): return "cpu" elif str(device) == "cuda": return "gpu" + elif str(device) == "mps": + return "mps" else: raise ValueError("Wrong device") @@ -89,6 +100,15 @@ def number_of_atoms(self): """Number of atoms in fake data.""" return 8 + @pytest.fixture() + def num_atom_types(self): + """Number of types of atoms in fake data.""" + return 5 + + @pytest.fixture + def unique_elements(self, num_atom_types): + return [generate_random_string(size=3) for _ in range(num_atom_types)] + @pytest.fixture() def spatial_dimension(self): """Spatial dimension of fake data.""" @@ -96,11 +116,11 @@ def spatial_dimension(self): @pytest.fixture def train_configuration_runs( - self, number_of_train_runs, spatial_dimension, number_of_atoms + self, number_of_train_runs, spatial_dimension, number_of_atoms, unique_elements ): """Generate multiple fake 'data' runs and return their configurations.""" return get_configuration_runs( - number_of_train_runs, spatial_dimension, number_of_atoms + number_of_train_runs, spatial_dimension, number_of_atoms, unique_elements ) @pytest.fixture @@ -113,11 +133,11 @@ def all_train_configurations(self, train_configuration_runs): @pytest.fixture def valid_configuration_runs( - self, number_of_valid_runs, spatial_dimension, number_of_atoms + self, number_of_valid_runs, spatial_dimension, number_of_atoms, unique_elements ): """Generate multiple fake 'data' runs and return their configurations.""" return get_configuration_runs( - number_of_valid_runs, spatial_dimension, number_of_atoms + number_of_valid_runs, spatial_dimension, number_of_atoms, unique_elements ) @pytest.fixture diff --git a/tests/data/diffusion/test_data_loader.py b/tests/data/diffusion/test_data_loader.py index 7b01a3c9..a01297e0 100644 --- a/tests/data/diffusion/test_data_loader.py +++ b/tests/data/diffusion/test_data_loader.py @@ -1,19 +1,24 @@ from collections import defaultdict from typing import Dict, List +import numpy as np import pytest import torch from diffusion_for_multi_scale_molecular_dynamics.data.diffusion.data_loader import ( LammpsForDiffusionDataModule, LammpsLoaderParameters) +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( + NULL_ELEMENT, ElementTypes) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES) + ATOM_TYPES, CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES) from tests.conftest import TestDiffusionDataBase -from tests.fake_data_utils import Configuration, find_aligning_permutation +from tests.fake_data_utils import (Configuration, find_aligning_permutation, + generate_fake_configuration) def convert_configurations_to_dataset( configurations: List[Configuration], + element_types: ElementTypes, ) -> Dict[str, torch.Tensor]: """Convert the input configuration into a dict of torch tensors comparable to a pytorch dataset.""" # The expected dataset keys are {'natom', 'box', 'cartesian_positions', 'relative_positions', 'type', @@ -25,7 +30,7 @@ def convert_configurations_to_dataset( data[CARTESIAN_FORCES].append(configuration.cartesian_forces) data[CARTESIAN_POSITIONS].append(configuration.cartesian_positions) data[RELATIVE_COORDINATES].append(configuration.relative_coordinates) - data["type"].append(configuration.types) + data[ATOM_TYPES].append([element_types.get_element_id(element) for element in configuration.elements]) data["potential_energy"].append(configuration.potential_energy) configuration_dataset = dict() @@ -36,51 +41,73 @@ def convert_configurations_to_dataset( class TestDiffusionDataLoader(TestDiffusionDataBase): + @pytest.fixture - def input_data_to_transform(self): - return { - "natom": [2], # batch size of 1 - "box": [[1.0, 1.0, 1.0]], - CARTESIAN_POSITIONS: [ - [1.0, 2.0, 3, 4.0, 5, 6] - ], # for one batch, two atoms, 3D positions - CARTESIAN_FORCES: [ - [11.0, 12.0, 13, 14.0, 15, 16] - ], # for one batch, two atoms, 3D forces - RELATIVE_COORDINATES: [[1.0, 2.0, 3, 4.0, 5, 6]], - "type": [[1, 2]], - "potential_energy": [23.233], - } + def element_types(self, unique_elements): + return ElementTypes(unique_elements) + + @pytest.fixture() + def batch_size(self): + return 4 + + @pytest.fixture + def batch_of_configurations(self, spatial_dimension, number_of_atoms, unique_elements, batch_size): + return [generate_fake_configuration(spatial_dimension, number_of_atoms, unique_elements) + for _ in range(batch_size)] + + @pytest.fixture + def batched_input_data(self, batch_of_configurations): + data = defaultdict(list) + for configuration in batch_of_configurations: + data["natom"].append(len(configuration.ids)) + data["box"].append(configuration.cell_dimensions.astype(np.float32)) + data[CARTESIAN_FORCES].append(configuration.cartesian_forces.flatten().astype(np.float32)) + data[CARTESIAN_POSITIONS].append(configuration.cartesian_positions.flatten().astype(np.float32)) + data[RELATIVE_COORDINATES].append(configuration.relative_coordinates.flatten().astype(np.float32)) + data['element'].append(configuration.elements) + data["potential_energy"].append(configuration.potential_energy) + + return data - def test_dataset_transform(self, input_data_to_transform): - result = LammpsForDiffusionDataModule.dataset_transform(input_data_to_transform) + @pytest.fixture + def input_data_for_padding(self, batched_input_data): + row = dict() + for key, list_of_values in batched_input_data.items(): + row[key] = list_of_values[0] + return row + + def test_dataset_transform(self, batched_input_data, element_types, batch_size, number_of_atoms, spatial_dimension): + result = LammpsForDiffusionDataModule.dataset_transform(batched_input_data, element_types) # Check keys in result assert set(result.keys()) == { "natom", + ATOM_TYPES, CARTESIAN_FORCES, CARTESIAN_POSITIONS, RELATIVE_COORDINATES, "box", - "type", "potential_energy", } # Check tensor types and shapes assert torch.equal( - result["natom"], torch.tensor(input_data_to_transform["natom"]).long() + result["natom"], torch.tensor(batched_input_data["natom"]).long() ) assert result[CARTESIAN_POSITIONS].shape == ( - 1, - 2, - 3, - ) # (batchsize, natom, 3 [since it's 3D]) - assert result["box"].shape == (1, 3) - assert torch.equal( - result["type"], torch.tensor(input_data_to_transform["type"]).long() + batch_size, + number_of_atoms, + spatial_dimension, ) + assert result["box"].shape == (batch_size, spatial_dimension) + + element_ids = list(result[ATOM_TYPES].flatten().numpy()) + computed_element_names = [element_types.get_element(id) for id in element_ids] + expected_element_names = list(np.array(batched_input_data['element']).flatten()) + assert computed_element_names == expected_element_names + assert torch.equal( result["potential_energy"], - torch.tensor(input_data_to_transform["potential_energy"]), + torch.tensor(batched_input_data["potential_energy"]), ) # Check tensor types explicitly @@ -89,55 +116,37 @@ def test_dataset_transform(self, input_data_to_transform): result[CARTESIAN_POSITIONS].dtype == torch.float32 ) # default dtype for torch.as_tensor with float inputs assert result["box"].dtype == torch.float32 - assert result["type"].dtype == torch.long + assert result[ATOM_TYPES].dtype == torch.long assert result["potential_energy"].dtype == torch.float32 - @pytest.fixture - def input_data_to_pad(self): - return { - "natom": 2, # batch size of 1 - "box": [1.0, 1.0, 1.0], - CARTESIAN_POSITIONS: [ - 1.0, - 2.0, - 3, - 4.0, - 5, - 6, - ], # for one batch, two atoms, 3D positions - CARTESIAN_FORCES: [11.0, 12.0, 13, 14.0, 15, 16], - RELATIVE_COORDINATES: [1.0, 2.0, 3, 4.0, 5, 6], - "type": [1, 2], - "potential_energy": 23.233, - } + @pytest.fixture() + def max_atom_for_padding(self, number_of_atoms): + return number_of_atoms + 4 - def test_pad_dataset(self, input_data_to_pad): - max_atom = 5 # Assume we want to pad to a max of 5 atoms - padded_sample = LammpsForDiffusionDataModule.pad_samples( - input_data_to_pad, max_atom - ) + def test_pad_dataset(self, input_data_for_padding, number_of_atoms, max_atom_for_padding): + padded_sample = LammpsForDiffusionDataModule.pad_samples(input_data_for_padding, max_atom_for_padding) # Check if the type and position have been padded correctly - assert len(padded_sample["type"]) == max_atom - assert padded_sample[CARTESIAN_POSITIONS].shape == torch.Size([max_atom * 3]) + assert len(padded_sample["element"]) == max_atom_for_padding + assert padded_sample[CARTESIAN_POSITIONS].shape == torch.Size([max_atom_for_padding * 3]) - # Check that the padding uses -1 for type - # 2 atoms in the input_data - last 3 atoms should be type -1 - for k in range(max_atom - 2): - assert padded_sample["type"].tolist()[-(k + 1)] == -1 + # Check that the padding is correct + for k in range(number_of_atoms, max_atom_for_padding): + assert padded_sample["element"][k] == NULL_ELEMENT # Check that the padding uses nan for position assert torch.isnan( - padded_sample[CARTESIAN_POSITIONS][-(max_atom - 2) * 3:] + padded_sample[CARTESIAN_POSITIONS][3 * number_of_atoms:] ).all() @pytest.fixture - def data_module_hyperparameters(self, number_of_atoms, spatial_dimension): + def data_module_hyperparameters(self, number_of_atoms, spatial_dimension, unique_elements): return LammpsLoaderParameters( batch_size=2, num_workers=0, max_atom=number_of_atoms, spatial_dimension=spatial_dimension, + elements=unique_elements ) @pytest.fixture() @@ -155,19 +164,19 @@ def data_module(self, paths, data_module_hyperparameters, tmpdir): @pytest.fixture() def real_and_test_datasets( - self, mode, data_module, all_train_configurations, all_valid_configurations + self, mode, data_module, all_train_configurations, all_valid_configurations, element_types ): match mode: case "train": data_module_dataset = data_module.train_dataset[:] configuration_dataset = convert_configurations_to_dataset( - all_train_configurations + all_train_configurations, element_types ) case "valid": data_module_dataset = data_module.valid_dataset[:] configuration_dataset = convert_configurations_to_dataset( - all_valid_configurations + all_valid_configurations, element_types ) case _: raise ValueError(f"Unknown mode {mode}") @@ -178,7 +187,7 @@ def test_dataset_feature_names(self, data_module): expected_feature_names = { "natom", "box", - "type", + 'element', "potential_energy", CARTESIAN_FORCES, CARTESIAN_POSITIONS, diff --git a/tests/data/diffusion/test_data_preprocess.py b/tests/data/diffusion/test_data_preprocess.py index 3d3e262f..6684448f 100644 --- a/tests/data/diffusion/test_data_preprocess.py +++ b/tests/data/diffusion/test_data_preprocess.py @@ -56,7 +56,7 @@ def test_parse_lammps_run( expected_columns = [ "natom", "box", - "type", + "element", CARTESIAN_POSITIONS, CARTESIAN_FORCES, RELATIVE_COORDINATES, diff --git a/tests/data/diffusion/test_element_types.py b/tests/data/diffusion/test_element_types.py new file mode 100644 index 00000000..14dbfbeb --- /dev/null +++ b/tests/data/diffusion/test_element_types.py @@ -0,0 +1,57 @@ +import numpy as np +import pytest + +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import ( + NULL_ELEMENT, NULL_ELEMENT_ID, ElementTypes) +from tests.fake_data_utils import generate_random_string + + +class TestElementTypes: + + @pytest.fixture() + def num_atom_types(self): + return 4 + + @pytest.fixture + def unique_elements(self, num_atom_types): + return [generate_random_string(size=3) for _ in range(num_atom_types)] + + @pytest.fixture + def bad_element(self): + return "this_is_a_bad_element" + + @pytest.fixture + def bad_element_id(self): + return 9999 + + @pytest.fixture + def element_types(self, unique_elements): + return ElementTypes(unique_elements) + + def test_number_of_atom_types(self, element_types, num_atom_types): + assert element_types.number_of_atom_types == num_atom_types + + def test_get_element_id(self, element_types, unique_elements): + assert element_types.get_element_id(NULL_ELEMENT) == NULL_ELEMENT_ID + + computed_element_ids = [element_types.get_element_id(element) for element in unique_elements] + assert len(np.unique(computed_element_ids)) == len(unique_elements) + + def test_get_element_id_bad_element(self, element_types, bad_element): + with pytest.raises(KeyError): + element_types.get_element_id(bad_element) + + def test_get_element(self, element_types, unique_elements): + assert element_types.get_element(NULL_ELEMENT_ID) == NULL_ELEMENT + + for element in unique_elements: + computed_element_id = element_types.get_element_id(element) + assert element == element_types.get_element(computed_element_id) + + def test_get_element_bad_element_id(self, element_types, bad_element_id): + with pytest.raises(KeyError): + element_types.get_element(bad_element_id) + + def test_validate_elements(self): + with pytest.raises(AssertionError): + ElementTypes.validate_elements(["A", "A", "B"]) diff --git a/tests/data/test_parse_lammps_output.py b/tests/data/test_parse_lammps_output.py index 8e337a4f..fbce855f 100644 --- a/tests/data/test_parse_lammps_output.py +++ b/tests/data/test_parse_lammps_output.py @@ -9,6 +9,7 @@ parse_lammps_dump, parse_lammps_output, parse_lammps_thermo_log) from tests.fake_data_utils import (create_dump_yaml_documents, generate_fake_configuration, + generate_fake_unique_elements, generate_parse_dump_output_dataframe, write_to_yaml) @@ -27,17 +28,17 @@ def fake_yaml_content(): # fake LAMMPS output file with 4 MD steps in 1D for 3 atoms np.random.seed(23423) box = [[0, 0.6], [0, 1.6], [0, 2.6]] - keywords = ["id", "type", "x", "y", "z", "fx", "fy", "fz"] + keywords = ["id", "element", "x", "y", "z", "fx", "fy", "fz"] number_of_documents = 4 - list_atom_types = [1, 2, 1] + list_elements = ['Ab', 'Cd', 'Ab'] yaml_content = [] for doc_idx in range(number_of_documents): data = [] - for id, atom_type in enumerate(list_atom_types): - row = [id, atom_type] + list(np.random.rand(6)) + for id, element in enumerate(list_elements): + row = [id, element] + list(np.random.rand(6)) data.append(row) doc = dict(keywords=keywords, box=box, data=data) @@ -143,13 +144,21 @@ def number_of_configurations(): return 16 +@pytest.fixture() +def num_unique_elements(): + return 5 + + @pytest.fixture -def configurations(number_of_configurations, spatial_dimension, number_of_atoms): +def configurations(number_of_configurations, spatial_dimension, number_of_atoms, num_unique_elements): """Generate multiple fake configurations.""" np.random.seed(23423423) + + unique_elements = generate_fake_unique_elements(num_unique_elements) + configurations = [ generate_fake_configuration( - spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms + spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, unique_elements=unique_elements ) for _ in range(number_of_configurations) ] diff --git a/tests/fake_data_utils.py b/tests/fake_data_utils.py index 442fa5f1..909c19ca 100644 --- a/tests/fake_data_utils.py +++ b/tests/fake_data_utils.py @@ -1,3 +1,5 @@ +import random +import string from collections import namedtuple from typing import Any, Dict, List @@ -16,7 +18,7 @@ CARTESIAN_POSITIONS, CARTESIAN_FORCES, RELATIVE_COORDINATES, - "types", + "elements", "ids", "cell_dimensions", "potential_energy", @@ -26,12 +28,17 @@ ) -def generate_fake_configuration(spatial_dimension: int, number_of_atoms: int): +def generate_fake_unique_elements(num_elements: int): + return [generate_random_string(size=4) for _ in range(num_elements)] + + +def generate_fake_configuration(spatial_dimension: int, number_of_atoms: int, unique_elements: List[str]): """Generate fake configuration. Args: spatial_dimension : dimension of space. Should be 1, 2 or 3. number_of_atoms : how many atoms to generate. + unique_elements: distinct element types Returns: configuration: a configuration object with all the data describing a configuration. @@ -53,7 +60,7 @@ def generate_fake_configuration(spatial_dimension: int, number_of_atoms: int): relative_coordinates=relative_coordinates, cartesian_positions=positions, cartesian_forces=np.random.rand(number_of_atoms, spatial_dimension), - types=np.random.randint(1, 10, number_of_atoms), + elements=np.random.choice(unique_elements, number_of_atoms), ids=np.arange(1, number_of_atoms + 1), cell_dimensions=cell_dimensions, potential_energy=potential_energy, @@ -62,14 +69,14 @@ def generate_fake_configuration(spatial_dimension: int, number_of_atoms: int): ) -def get_configuration_runs(number_of_runs, spatial_dimension, number_of_atoms): +def get_configuration_runs(number_of_runs, spatial_dimension, number_of_atoms, unique_elements): """Generate multiple random configuration runs, each composed of many different configurations.""" list_configurations = [] for _ in range(number_of_runs): number_of_configs = np.random.randint(1, 16) configurations = [ generate_fake_configuration( - spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms + spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, unique_elements=unique_elements ) for _ in range(number_of_configs) ] @@ -94,7 +101,7 @@ def generate_parse_dump_output_dataframe( row = dict( box=configuration.cell_dimensions, id=list(configuration.ids), - type=list(configuration.types), + element=list(configuration.elements), ) for coordinates, name in zip( configuration.cartesian_positions.transpose(), ["x", "y", "z"] @@ -118,8 +125,8 @@ def create_dump_single_record( box = [[0, float(dimension)] for dimension in configuration.cell_dimensions] - # keywords should be of the form : [id, type, x, y, z, fx, fy, fz, ] - keywords = ["id", "type"] + # keywords should be of the form : [id, element, x, y, z, fx, fy, fz, ] + keywords = ["id", "element"] for direction, _ in zip(["x", "y", "z"], range(spatial_dimension)): keywords.append(direction) @@ -130,14 +137,14 @@ def create_dump_single_record( # Each row of data should be a list in the same order as the keywords data = [] - for id, type, position, force in zip( + for id, element, position, force in zip( configuration.ids, - configuration.types, + configuration.elements, configuration.cartesian_positions, configuration.cartesian_forces, ): row = ( - [int(id), int(type)] + [int(id), element] + [float(p) for p in position] + [float(f) for f in force] ) @@ -227,7 +234,7 @@ def generate_parquet_dataframe(configurations: List[Configuration]) -> pd.DataFr row = dict( natom=number_of_atoms, box=box, - type=configuration.types, + element=configuration.elements, potential_energy=configuration.potential_energy, cartesian_positions=positions, relative_coordinates=relative_positions, @@ -264,3 +271,8 @@ def find_aligning_permutation( permutation_indices = matching_indices[:, 1] return permutation_indices + + +def generate_random_string(size: int): + chars = string.ascii_uppercase + string.ascii_lowercase + return ''.join(random.choice(chars) for _ in range(size)) diff --git a/tests/generators/conftest.py b/tests/generators/conftest.py index 7699d60d..572e27c4 100644 --- a/tests/generators/conftest.py +++ b/tests/generators/conftest.py @@ -5,17 +5,25 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import ( ScoreNetwork, ScoreNetworkParameters) -from diffusion_for_multi_scale_molecular_dynamics.namespace import \ - NOISY_RELATIVE_COORDINATES +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, NOISY_AXL_COMPOSITION) +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot -class FakeScoreNetwork(ScoreNetwork): - """A fake, smooth score network for the ODE solver.""" +class FakeAXLNetwork(ScoreNetwork): + """A fake score network for tests.""" def _forward_unchecked( self, batch: Dict[AnyStr, torch.Tensor], conditional: bool = False - ) -> torch.Tensor: - return batch[NOISY_RELATIVE_COORDINATES] + ) -> AXL: + return AXL( + A=class_index_to_onehot( + batch[NOISY_AXL_COMPOSITION].A, num_classes=self.num_atom_types + 1 + ), + X=batch[NOISY_AXL_COMPOSITION].X, + L=None, + ) class BaseTestGenerator: @@ -38,19 +46,25 @@ def spatial_dimension(self, request): return request.param @pytest.fixture() - def unit_cell_sample(self, unit_cell_size, spatial_dimension, number_of_samples): + def num_atom_types(self): + return 6 + + @pytest.fixture() + def unit_cell_sample(self, unit_cell_size, spatial_dimension, number_of_samples, device): return torch.diag(torch.Tensor([unit_cell_size] * spatial_dimension)).repeat( number_of_samples, 1, 1 - ) + ).to(device) @pytest.fixture() def cell_dimensions(self, unit_cell_size, spatial_dimension): return spatial_dimension * [unit_cell_size] @pytest.fixture() - def sigma_normalized_score_network(self, spatial_dimension): - return FakeScoreNetwork( + def axl_network(self, spatial_dimension, num_atom_types): + return FakeAXLNetwork( ScoreNetworkParameters( - architecture="dummy", spatial_dimension=spatial_dimension + architecture="dummy", + spatial_dimension=spatial_dimension, + num_atom_types=num_atom_types, ) ) diff --git a/tests/generators/test_constrained_langevin_generator.py b/tests/generators/test_constrained_langevin_generator.py index d1aa431a..59f2bb6d 100644 --- a/tests/generators/test_constrained_langevin_generator.py +++ b/tests/generators/test_constrained_langevin_generator.py @@ -4,6 +4,7 @@ from diffusion_for_multi_scale_molecular_dynamics.generators.constrained_langevin_generator import ( ConstrainedLangevinGenerator, ConstrainedLangevinGeneratorParameters) +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL from tests.generators.test_langevin_generator import TestLangevinGenerator @@ -24,6 +25,7 @@ def sampling_parameters( number_of_corrector_steps, unit_cell_size, constrained_relative_coordinates, + num_atom_types, ): sampling_parameters = ConstrainedLangevinGeneratorParameters( number_of_corrector_steps=number_of_corrector_steps, @@ -32,34 +34,50 @@ def sampling_parameters( cell_dimensions=cell_dimensions, spatial_dimension=spatial_dimension, constrained_relative_coordinates=constrained_relative_coordinates, + num_atom_types=num_atom_types, ) return sampling_parameters @pytest.fixture() - def pc_generator( - self, noise_parameters, sampling_parameters, sigma_normalized_score_network - ): + def pc_generator(self, noise_parameters, sampling_parameters, axl_network): generator = ConstrainedLangevinGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) return generator @pytest.fixture() - def x(self, number_of_samples, number_of_atoms, spatial_dimension, device): - return torch.rand(number_of_samples, number_of_atoms, spatial_dimension).to( - device + def axl( + self, + number_of_samples, + number_of_atoms, + spatial_dimension, + num_atom_types, + device, + ): + return AXL( + A=torch.randint( + 0, num_atom_types + 1, (number_of_samples, number_of_atoms) + ).to(device), + X=torch.rand(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ), + L=torch.rand( + number_of_samples, spatial_dimension * (spatial_dimension - 1) + ).to( + device + ), # TODO placeholder ) def test_apply_constraint( - self, pc_generator, x, constrained_relative_coordinates, device + self, pc_generator, axl, constrained_relative_coordinates, device ): - batch_size = x.shape[0] - original_x = torch.clone(x) - pc_generator._apply_constraint(x, device) + batch_size = axl.X.shape[0] + original_x = torch.clone(axl.X) + pc_generator._apply_constraint(axl, device) number_of_constraints = len(constrained_relative_coordinates) @@ -69,7 +87,7 @@ def test_apply_constraint( b=batch_size, ) - torch.testing.assert_close(x[:, :number_of_constraints], constrained_x) + torch.testing.assert_close(axl.X[:, :number_of_constraints], constrained_x) torch.testing.assert_close( - x[:, number_of_constraints:], original_x[:, number_of_constraints:] + axl.X[:, number_of_constraints:], original_x[:, number_of_constraints:] ) diff --git a/tests/generators/test_langevin_generator.py b/tests/generators/test_langevin_generator.py index c1a4c3d0..0f66cb3c 100644 --- a/tests/generators/test_langevin_generator.py +++ b/tests/generators/test_langevin_generator.py @@ -1,37 +1,69 @@ +import einops import pytest import torch from diffusion_for_multi_scale_molecular_dynamics.generators.langevin_generator import \ LangevinGenerator -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ PredictorCorrectorSamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ map_relative_coordinates_to_unit_cell -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) +from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler from tests.generators.conftest import BaseTestGenerator class TestLangevinGenerator(BaseTestGenerator): - @pytest.fixture(params=[0, 1, 2]) + @pytest.fixture() + def num_atom_types(self): + return 4 + + @pytest.fixture() + def num_atomic_classes(self, num_atom_types): + return num_atom_types + 1 + + @pytest.fixture(params=[0, 2]) def number_of_corrector_steps(self, request): return request.param - @pytest.fixture(params=[1, 5, 10]) + @pytest.fixture(params=[2, 5, 10]) def total_time_steps(self, request): return request.param @pytest.fixture() - def noise_parameters(self, total_time_steps): + def sigma_min(self): + return 0.15 + + @pytest.fixture() + def noise_parameters(self, total_time_steps, sigma_min): noise_parameters = NoiseParameters( total_time_steps=total_time_steps, time_delta=0.1, - sigma_min=0.15, + sigma_min=sigma_min, corrector_step_epsilon=0.25, ) return noise_parameters + @pytest.fixture() + def small_epsilon(self): + return 1e-6 + + @pytest.fixture(params=[True, False]) + def one_atom_type_transition_per_step(self, request): + return request.param + + @pytest.fixture(params=[True, False]) + def atom_type_greedy_sampling(self, request): + return request.param + + @pytest.fixture(params=[True, False]) + def atom_type_transition_in_corrector(self, request): + return request.param + @pytest.fixture() def sampling_parameters( self, @@ -41,6 +73,11 @@ def sampling_parameters( number_of_samples, number_of_corrector_steps, unit_cell_size, + num_atom_types, + one_atom_type_transition_per_step, + atom_type_greedy_sampling, + atom_type_transition_in_corrector, + small_epsilon, ): sampling_parameters = PredictorCorrectorSamplingParameters( number_of_corrector_steps=number_of_corrector_steps, @@ -48,18 +85,29 @@ def sampling_parameters( number_of_samples=number_of_samples, cell_dimensions=cell_dimensions, spatial_dimension=spatial_dimension, + num_atom_types=num_atom_types, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, + atom_type_greedy_sampling=atom_type_greedy_sampling, + atom_type_transition_in_corrector=atom_type_transition_in_corrector, + small_epsilon=small_epsilon, ) return sampling_parameters @pytest.fixture() - def pc_generator( - self, noise_parameters, sampling_parameters, sigma_normalized_score_network - ): + def noise(self, noise_parameters, num_atomic_classes, device): + sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes).to( + device + ) + noise, _ = sampler.get_all_sampling_parameters() + return noise + + @pytest.fixture() + def pc_generator(self, noise_parameters, sampling_parameters, axl_network): generator = LangevinGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) return generator @@ -71,35 +119,49 @@ def test_smoke_sample( pc_generator.sample(number_of_samples, device, unit_cell_sample) @pytest.fixture() - def x_i(self, number_of_samples, number_of_atoms, spatial_dimension, device): - return map_relative_coordinates_to_unit_cell( - torch.rand(number_of_samples, number_of_atoms, spatial_dimension) - ).to(device) + def axl_i( + self, + number_of_samples, + number_of_atoms, + spatial_dimension, + num_atomic_classes, + device, + ): + return AXL( + A=torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device), + X=map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device), + L=torch.zeros( + number_of_samples, spatial_dimension * (spatial_dimension - 1) + ).to( + device + ), # TODO placeholder + ) - def test_predictor_step( + def test_predictor_step_relative_coordinates( self, mocker, pc_generator, - noise_parameters, - x_i, + noise, + sigma_min, + axl_i, total_time_steps, number_of_samples, unit_cell_sample, ): - - sampler = ExplodingVarianceSampler(noise_parameters) - noise, _ = sampler.get_all_sampling_parameters() - sigma_min = noise_parameters.sigma_min list_sigma = noise.sigma list_time = noise.time - forces = torch.zeros_like(x_i) + forces = torch.zeros_like(axl_i.X) - z = pc_generator._draw_gaussian_sample(number_of_samples).to(x_i) + z = pc_generator._draw_gaussian_sample(number_of_samples).to(axl_i.X) mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) for index_i in range(1, total_time_steps + 1): computed_sample = pc_generator.predictor_step( - x_i, index_i, unit_cell_sample, forces + axl_i, index_i, unit_cell_sample, forces ) sigma_i = list_sigma[index_i - 1] @@ -112,42 +174,339 @@ def test_predictor_step( g2 = sigma_i**2 - sigma_im1**2 s_i = ( - pc_generator._get_sigma_normalized_scores( - x_i, t_i, sigma_i, unit_cell_sample, forces - ) + pc_generator._get_model_predictions( + axl_i, t_i, sigma_i, unit_cell_sample, forces + ).X / sigma_i ) - expected_sample = x_i + g2 * s_i + torch.sqrt(g2) * z + expected_coordinates = axl_i.X + g2 * s_i + torch.sqrt(g2) * z + expected_coordinates = map_relative_coordinates_to_unit_cell( + expected_coordinates + ) + + torch.testing.assert_close(computed_sample.X, expected_coordinates) + + def test_adjust_atom_types_probabilities_for_greedy_sampling( + self, pc_generator, number_of_atoms, num_atomic_classes + ): + # Test that all_masked atom types are unaffected. + fully_masked_row = pc_generator.masked_atom_type_index * torch.ones( + number_of_atoms, dtype=torch.int64 + ) + + partially_unmasked_row = fully_masked_row.clone() + partially_unmasked_row[0] = 0 + + atom_types_i = torch.stack([fully_masked_row, partially_unmasked_row]) + + number_of_samples = atom_types_i.shape[0] + u = pc_generator._draw_gumbel_sample(number_of_samples) + + one_step_transition_probs = torch.rand( + number_of_samples, number_of_atoms, num_atomic_classes + ).softmax(dim=-1) + # Use cloned values because the method overrides the inputs. + updated_one_step_transition_probs, updated_u = ( + pc_generator._adjust_atom_types_probabilities_for_greedy_sampling( + one_step_transition_probs.clone(), atom_types_i, u.clone() + ) + ) + + # Test that the fully masked row is unaffected + torch.testing.assert_close( + updated_one_step_transition_probs[0], one_step_transition_probs[0] + ) + torch.testing.assert_close(u[0], updated_u[0]) + + # Test that when an atom is unmasked, the probabilities are set up for greedy sampling: + # - the probabilities for the real atomic classes are unchanged. + # - the probability for the MASK class (last index) is either unchanged or set to zero. + # - the Gumbel sample is set to zero so that the unmasking is greedy. + + torch.testing.assert_close( + updated_one_step_transition_probs[1, :, :-1], + one_step_transition_probs[1, :, :-1], + ) + + m1 = ( + updated_one_step_transition_probs[1, :, -1] + == one_step_transition_probs[1, :, -1] + ) + m2 = updated_one_step_transition_probs[1, :, -1] == 0.0 + assert torch.logical_or(m1, m2).all() + torch.testing.assert_close(updated_u[1], torch.zeros_like(updated_u[1])) + + def test_get_updated_atom_types_for_one_transition_per_step_is_idempotent( + self, + pc_generator, + number_of_samples, + number_of_atoms, + num_atomic_classes, + device, + ): + # Test that the method returns the current atom types if there is no proposed changes. + current_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + sampled_atom_types = current_atom_types.clone() + max_gumbel_values = torch.rand(number_of_samples, number_of_atoms).to(device) + + updated_atom_types = ( + pc_generator._get_updated_atom_types_for_one_transition_per_step( + current_atom_types, max_gumbel_values, sampled_atom_types + ) + ) + + torch.testing.assert_close(updated_atom_types, current_atom_types) + + def test_get_updated_atom_types_for_one_transition_per_step( + self, + pc_generator, + number_of_samples, + number_of_atoms, + num_atomic_classes, + device, + ): + assert ( + num_atomic_classes > 0 + ), "Cannot run this test with a single atomic class." + current_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + sampled_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + # Make sure at least one atom is different in every sample. + while not (current_atom_types != sampled_atom_types).any(dim=-1).all(): + sampled_atom_types = torch.randint( + 0, num_atomic_classes, (number_of_samples, number_of_atoms) + ).to(device) + + proposed_difference_mask = current_atom_types != sampled_atom_types + + max_gumbel_values = torch.rand(number_of_samples, number_of_atoms).to(device) + + updated_atom_types = ( + pc_generator._get_updated_atom_types_for_one_transition_per_step( + current_atom_types, max_gumbel_values, sampled_atom_types + ) + ) + + difference_mask = updated_atom_types != current_atom_types + + # Check that there is a single difference per sample + number_of_changes = difference_mask.sum(dim=-1) + torch.testing.assert_close( + number_of_changes, torch.ones(number_of_samples).to(number_of_changes) + ) + + # Check that the difference is at the location of the maximum value of the Gumbel random variable over the + # possible changes. + computed_changed_atom_indices = torch.where(difference_mask)[1] + + expected_changed_atom_indices = [] + for sample_idx in range(number_of_samples): + sample_gumbel_values = max_gumbel_values[sample_idx].clone() + sample_proposed_difference_mask = proposed_difference_mask[sample_idx] + sample_gumbel_values[~sample_proposed_difference_mask] = -torch.inf + max_index = torch.argmax(sample_gumbel_values) + expected_changed_atom_indices.append(max_index) + expected_changed_atom_indices = torch.tensor(expected_changed_atom_indices).to( + computed_changed_atom_indices + ) + + torch.testing.assert_close( + computed_changed_atom_indices, expected_changed_atom_indices + ) + + def test_atom_types_update( + self, + pc_generator, + noise, + total_time_steps, + num_atomic_classes, + number_of_samples, + number_of_atoms, + device, + ): + + # Initialize to fully masked + a_i = pc_generator.masked_atom_type_index * torch.ones( + number_of_samples, number_of_atoms, dtype=torch.int64 + ).to(device) + + for time_index_i in range(total_time_steps, 0, -1): + this_is_last_time_step = time_index_i == 1 + idx = time_index_i - 1 + q_matrices_i = einops.repeat( + noise.q_matrix[idx], + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + q_bar_matrices_i = einops.repeat( + noise.q_bar_matrix[idx], + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + q_bar_tm1_matrices_i = einops.repeat( + noise.q_bar_tm1_matrix[idx], + "n1 n2 -> nsamples natoms n1 n2", + nsamples=number_of_samples, + natoms=number_of_atoms, + ) + + random_logits = torch.rand( + number_of_samples, number_of_atoms, num_atomic_classes + ).to(device) + random_logits[:, :, -1] = -torch.inf + + one_atom_type_transition_per_step = ( + pc_generator.one_atom_type_transition_per_step + and not this_is_last_time_step + ) + + a_im1 = pc_generator._atom_types_update( + random_logits, + a_i, + q_matrices_i, + q_bar_matrices_i, + q_bar_tm1_matrices_i, + atom_type_greedy_sampling=pc_generator.atom_type_greedy_sampling, + one_atom_type_transition_per_step=one_atom_type_transition_per_step, + ) + + difference_mask = a_im1 != a_i + + # Test that the changes are from MASK to not-MASK + assert (a_i[difference_mask] == pc_generator.masked_atom_type_index).all() + assert (a_im1[difference_mask] != pc_generator.masked_atom_type_index).all() + + if one_atom_type_transition_per_step: + # Test that there is at most one change + assert torch.all(difference_mask.sum(dim=-1) <= 1.0) + + if pc_generator.atom_type_greedy_sampling: + # Test that the changes are the most probable (greedy) + sample_indices, atom_indices = torch.where(difference_mask) + for sample_idx, atom_idx in zip(sample_indices, atom_indices): + # Greedy sampling only applies if at least one atom was already unmasked. + if (a_i[sample_idx] == pc_generator.masked_atom_type_index).all(): + continue + computed_atom_type = a_im1[sample_idx, atom_idx] + expected_atom_type = random_logits[sample_idx, atom_idx].argmax() + assert computed_atom_type == expected_atom_type + + a_i = a_im1 + + # Test that no MASKED states remain + assert not (a_i == pc_generator.masked_atom_type_index).any() + + def test_predictor_step_atom_types( + self, + mocker, + pc_generator, + total_time_steps, + number_of_samples, + number_of_atoms, + num_atomic_classes, + spatial_dimension, + unit_cell_sample, + device, + ): + zeros = torch.zeros(number_of_samples, number_of_atoms, spatial_dimension).to( + device + ) + forces = zeros + + random_x = map_relative_coordinates_to_unit_cell( + torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + ).to(device) + + random_l = torch.zeros( + number_of_samples, spatial_dimension, spatial_dimension + ).to(device) + + # Initialize to fully masked + a_ip1 = pc_generator.masked_atom_type_index * torch.ones( + number_of_samples, number_of_atoms, dtype=torch.int64 + ).to(device) + axl_ip1 = AXL(A=a_ip1, X=random_x, L=random_l) + + for idx in range(total_time_steps - 1, -1, -1): + + # Inject reasonable logits + logits = torch.rand( + number_of_samples, number_of_atoms, num_atomic_classes + ).to(device) + logits[:, :, -1] = -torch.inf + fake_model_predictions = AXL(A=logits, X=zeros, L=zeros) + mocker.patch.object( + pc_generator, + "_get_model_predictions", + return_value=fake_model_predictions, + ) + + axl_i = pc_generator.predictor_step( + axl_ip1, idx + 1, unit_cell_sample, forces + ) + + this_is_last_time_step = idx == 0 + a_i = axl_i.A + a_ip1 = axl_ip1.A + + difference_mask = a_ip1 != a_i + + # Test that the changes are from MASK to not-MASK + assert (a_ip1[difference_mask] == pc_generator.masked_atom_type_index).all() + assert (a_i[difference_mask] != pc_generator.masked_atom_type_index).all() + + one_atom_type_transition_per_step = ( + pc_generator.one_atom_type_transition_per_step + and not this_is_last_time_step + ) + + if one_atom_type_transition_per_step: + # Test that there is at most one change + assert torch.all(difference_mask.sum(dim=-1) <= 1.0) + + axl_ip1 = AXL(A=a_i, X=random_x, L=random_l) - torch.testing.assert_close(computed_sample, expected_sample) + # Test that no MASKED states remain + a_i = axl_i.A + assert not (a_i == pc_generator.masked_atom_type_index).any() def test_corrector_step( self, mocker, pc_generator, noise_parameters, - x_i, + axl_i, total_time_steps, number_of_samples, unit_cell_sample, + num_atomic_classes, ): - sampler = ExplodingVarianceSampler(noise_parameters) + sampler = NoiseScheduler(noise_parameters, num_classes=num_atomic_classes) noise, _ = sampler.get_all_sampling_parameters() sigma_min = noise_parameters.sigma_min epsilon = noise_parameters.corrector_step_epsilon list_sigma = noise.sigma list_time = noise.time sigma_1 = list_sigma[0] - forces = torch.zeros_like(x_i) + forces = torch.zeros_like(axl_i.X) - z = pc_generator._draw_gaussian_sample(number_of_samples).to(x_i) + z = pc_generator._draw_gaussian_sample(number_of_samples).to(axl_i.X) mocker.patch.object(pc_generator, "_draw_gaussian_sample", return_value=z) for index_i in range(0, total_time_steps): computed_sample = pc_generator.corrector_step( - x_i, index_i, unit_cell_sample, forces + axl_i, index_i, unit_cell_sample, forces ) if index_i == 0: @@ -160,12 +519,37 @@ def test_corrector_step( eps_i = 0.5 * epsilon * sigma_i**2 / sigma_1**2 s_i = ( - pc_generator._get_sigma_normalized_scores( - x_i, t_i, sigma_i, unit_cell_sample, forces - ) + pc_generator._get_model_predictions( + axl_i, t_i, sigma_i, unit_cell_sample, forces + ).X / sigma_i ) - expected_sample = x_i + eps_i * s_i + torch.sqrt(2.0 * eps_i) * z + expected_coordinates = axl_i.X + eps_i * s_i + torch.sqrt(2.0 * eps_i) * z + expected_coordinates = map_relative_coordinates_to_unit_cell( + expected_coordinates + ) + + torch.testing.assert_close(computed_sample.X, expected_coordinates) + + if pc_generator.atom_type_transition_in_corrector: + a_i = axl_i.A + corrected_a_i = computed_sample.A - torch.testing.assert_close(computed_sample, expected_sample) + difference_mask = corrected_a_i != a_i + + # Test that the changes are from MASK to not-MASK + assert ( + a_i[difference_mask] == pc_generator.masked_atom_type_index + ).all() + assert ( + corrected_a_i[difference_mask] + != pc_generator.masked_atom_type_index + ).all() + + if pc_generator.one_atom_type_transition_per_step: + # Test that there is at most one change + assert torch.all(difference_mask.sum(dim=-1) <= 1.0) + + else: + assert torch.all(computed_sample.A == axl_i.A) diff --git a/tests/generators/test_ode_position_generator.py b/tests/generators/test_ode_position_generator.py index 711ad09e..49899eae 100644 --- a/tests/generators/test_ode_position_generator.py +++ b/tests/generators/test_ode_position_generator.py @@ -2,9 +2,11 @@ import torch from diffusion_for_multi_scale_molecular_dynamics.generators.ode_position_generator import ( - ExplodingVarianceODEPositionGenerator, ODESamplingParameters) -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) + ExplodingVarianceODEAXLGenerator, ODESamplingParameters) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler from tests.generators.conftest import BaseTestGenerator @@ -12,7 +14,7 @@ @pytest.mark.parametrize("sigma_min", [0.15]) @pytest.mark.parametrize("record_samples", [False, True]) @pytest.mark.parametrize("number_of_samples", [8]) -class TestExplodingVarianceODEPositionGenerator(BaseTestGenerator): +class TestExplodingVarianceODEAXLGenerator(BaseTestGenerator): @pytest.fixture() def noise_parameters(self, total_time_steps, sigma_min): @@ -28,6 +30,7 @@ def sampling_parameters( cell_dimensions, number_of_samples, record_samples, + num_atom_types, ): sampling_parameters = ODESamplingParameters( number_of_atoms=number_of_atoms, @@ -35,38 +38,37 @@ def sampling_parameters( number_of_samples=number_of_samples, cell_dimensions=cell_dimensions, record_samples=record_samples, + num_atom_types=num_atom_types, ) return sampling_parameters @pytest.fixture() def ode_generator( - self, noise_parameters, sampling_parameters, sigma_normalized_score_network + self, + noise_parameters, + sampling_parameters, + axl_network, ): - generator = ExplodingVarianceODEPositionGenerator( + generator = ExplodingVarianceODEAXLGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) return generator - def test_get_exploding_variance_sigma(self, ode_generator, noise_parameters): - times = ExplodingVarianceSampler._get_time_array(noise_parameters) - expected_sigmas = ExplodingVarianceSampler._create_sigma_array( - noise_parameters, times - ) - computed_sigmas = ode_generator._get_exploding_variance_sigma(times) - torch.testing.assert_close(expected_sigmas, computed_sigmas) - def test_get_ode_prefactor(self, ode_generator, noise_parameters): - times = ExplodingVarianceSampler._get_time_array(noise_parameters) - sigmas = ode_generator._get_exploding_variance_sigma(times) + times = NoiseScheduler._get_time_array(noise_parameters) + sigmas = ( + noise_parameters.sigma_min ** (1.0 - times) + * noise_parameters.sigma_max**times + ) sig_ratio = torch.tensor( noise_parameters.sigma_max / noise_parameters.sigma_min ) expected_ode_prefactor = torch.log(sig_ratio) * sigmas - computed_ode_prefactor = ode_generator._get_ode_prefactor(sigmas) + computed_ode_prefactor = ode_generator._get_ode_prefactor(times) torch.testing.assert_close(expected_ode_prefactor, computed_ode_prefactor) def test_smoke_sample( @@ -79,15 +81,13 @@ def test_smoke_sample( unit_cell_sample, ): # Just a smoke test that we can sample without crashing. - relative_coordinates = ode_generator.sample( - number_of_samples, device, unit_cell_sample - ) + sampled_axl = ode_generator.sample(number_of_samples, device, unit_cell_sample) - assert relative_coordinates.shape == ( + assert sampled_axl.X.shape == ( number_of_samples, number_of_atoms, spatial_dimension, ) - assert relative_coordinates.min() >= 0.0 - assert relative_coordinates.max() < 1.0 + assert sampled_axl.X.min() >= 0.0 + assert sampled_axl.X.max() < 1.0 diff --git a/tests/generators/test_predictor_corrector_position_generator.py b/tests/generators/test_predictor_corrector_position_generator.py index 49dd5f0f..8b9a0a34 100644 --- a/tests/generators/test_predictor_corrector_position_generator.py +++ b/tests/generators/test_predictor_corrector_position_generator.py @@ -1,14 +1,15 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ - PredictorCorrectorPositionGenerator -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ + PredictorCorrectorAXLGenerator +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( + map_axl_composition_to_unit_cell, map_relative_coordinates_to_unit_cell) from tests.generators.conftest import BaseTestGenerator -class FakePCGenerator(PredictorCorrectorPositionGenerator): +class FakePCGenerator(PredictorCorrectorAXLGenerator): """A dummy PC generator for the purpose of testing.""" def __init__( @@ -16,32 +17,50 @@ def __init__( number_of_discretization_steps: int, number_of_corrector_steps: int, spatial_dimension: int, + num_atom_types: int, initial_sample: torch.Tensor, ): super().__init__( - number_of_discretization_steps, number_of_corrector_steps, spatial_dimension + number_of_discretization_steps, + number_of_corrector_steps, + spatial_dimension, + num_atom_types, ) self.initial_sample = initial_sample - def initialize(self, number_of_samples: int): + def initialize( + self, number_of_samples: int, device: torch.device = torch.device("cpu") + ): return self.initial_sample def predictor_step( self, - x_ip1: torch.Tensor, + axl_ip1: AXL, ip1: int, unit_cell: torch.Tensor, forces: torch.Tensor, ) -> torch.Tensor: - return 1.2 * x_ip1 + 3.4 + ip1 / 111.0 + updated_axl = AXL( + A=axl_ip1.A, + X=map_relative_coordinates_to_unit_cell( + 1.2 * axl_ip1.X + 3.4 + ip1 / 111.0 + ), + L=axl_ip1.L, + ) + return updated_axl def corrector_step( - self, x_i: torch.Tensor, i: int, unit_cell: torch.Tensor, forces: torch.Tensor + self, axl_i: torch.Tensor, i: int, unit_cell: torch.Tensor, forces: torch.Tensor ) -> torch.Tensor: - return 0.56 * x_i + 7.89 + i / 117.0 + updated_axl = AXL( + A=axl_i.A, + X=map_relative_coordinates_to_unit_cell(0.56 * axl_i.X + 7.89 + i / 117.0), + L=axl_i.L, + ) + return updated_axl -@pytest.mark.parametrize("number_of_discretization_steps", [1, 5, 10]) +@pytest.mark.parametrize("number_of_discretization_steps", [2, 5, 10]) @pytest.mark.parametrize("number_of_corrector_steps", [0, 1, 2]) class TestPredictorCorrectorPositionGenerator(BaseTestGenerator): @pytest.fixture(scope="class", autouse=True) @@ -49,8 +68,18 @@ def set_random_seed(self): torch.manual_seed(1234567) @pytest.fixture - def initial_sample(self, number_of_samples, number_of_atoms, spatial_dimension): - return torch.rand(number_of_samples, number_of_atoms, spatial_dimension) + def initial_sample( + self, number_of_samples, number_of_atoms, spatial_dimension, num_atom_types + ): + return AXL( + A=torch.randint( + 0, num_atom_types + 1, (number_of_samples, number_of_atoms) + ), + X=torch.rand(number_of_samples, number_of_atoms, spatial_dimension), + L=torch.rand( + number_of_samples, spatial_dimension * (spatial_dimension - 1) + ), # TODO placeholder + ) @pytest.fixture def generator( @@ -58,12 +87,14 @@ def generator( number_of_discretization_steps, number_of_corrector_steps, spatial_dimension, + num_atom_types, initial_sample, ): generator = FakePCGenerator( number_of_discretization_steps, number_of_corrector_steps, spatial_dimension, + num_atom_types, initial_sample, ) return generator @@ -81,22 +112,32 @@ def expected_samples( list_i.reverse() list_j = list(range(number_of_corrector_steps)) - noisy_sample = map_relative_coordinates_to_unit_cell(initial_sample) - x_ip1 = noisy_sample + noisy_sample = map_axl_composition_to_unit_cell( + initial_sample, torch.device("cpu") + ) + composition_ip1 = noisy_sample for i in list_i: - xi = map_relative_coordinates_to_unit_cell( + composition_i = map_axl_composition_to_unit_cell( generator.predictor_step( - x_ip1, i + 1, unit_cell_sample, torch.zeros_like(x_ip1) - ) + composition_ip1, + i + 1, + unit_cell_sample, + torch.zeros_like(composition_ip1.X), + ), + torch.device("cpu"), ) for _ in list_j: - xi = map_relative_coordinates_to_unit_cell( + composition_i = map_axl_composition_to_unit_cell( generator.corrector_step( - xi, i, unit_cell_sample, torch.zeros_like(xi) - ) + composition_i, + i, + unit_cell_sample, + torch.zeros_like(composition_i.X), + ), + torch.device("cpu"), ) - x_ip1 = xi - return xi + composition_ip1 = composition_i + return composition_i def test_sample( self, generator, number_of_samples, expected_samples, unit_cell_sample @@ -104,4 +145,5 @@ def test_sample( computed_samples = generator.sample( number_of_samples, torch.device("cpu"), unit_cell_sample ) + torch.testing.assert_close(expected_samples, computed_samples) diff --git a/tests/generators/test_sde_position_generator.py b/tests/generators/test_sde_position_generator.py index 9cb36372..79b28f3e 100644 --- a/tests/generators/test_sde_position_generator.py +++ b/tests/generators/test_sde_position_generator.py @@ -3,8 +3,10 @@ from diffusion_for_multi_scale_molecular_dynamics.generators.sde_position_generator import ( SDE, ExplodingVarianceSDEPositionGenerator, SDESamplingParameters) -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + VarianceScheduler +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters from tests.generators.conftest import BaseTestGenerator @@ -34,6 +36,7 @@ def sampling_parameters( cell_dimensions, number_of_samples, record_samples, + num_atom_types, ): sampling_parameters = SDESamplingParameters( number_of_atoms=number_of_atoms, @@ -41,15 +44,21 @@ def sampling_parameters( number_of_samples=number_of_samples, cell_dimensions=cell_dimensions, record_samples=record_samples, + num_atom_types=num_atom_types, ) return sampling_parameters + @pytest.fixture() + def atom_types(self, number_of_samples, number_of_atoms): + return torch.zeros(number_of_samples, number_of_atoms).long() + @pytest.fixture() def sde( self, noise_parameters, sampling_parameters, - sigma_normalized_score_network, + axl_network, + atom_types, unit_cell_sample, initial_diffusion_time, final_diffusion_time, @@ -57,7 +66,8 @@ def sde( sde = SDE( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, + atom_types=atom_types, unit_cells=unit_cell_sample, initial_diffusion_time=initial_diffusion_time, final_diffusion_time=final_diffusion_time, @@ -87,9 +97,7 @@ def test_sde_g_squared( final_diffusion_time - initial_diffusion_time ) - sigma = ExplodingVarianceSampler._create_sigma_array( - noise_parameters=noise_parameters, time_array=time_array - )[0] + sigma = VarianceScheduler(noise_parameters).get_sigma(time_array)[0] expected_g_squared = ( 2.0 @@ -106,13 +114,11 @@ def test_sde_g_squared( torch.testing.assert_close(computed_g_squared, expected_g_squared) @pytest.fixture() - def sde_generator( - self, noise_parameters, sampling_parameters, sigma_normalized_score_network - ): + def sde_generator(self, noise_parameters, sampling_parameters, axl_network): generator = ExplodingVarianceSDEPositionGenerator( noise_parameters=noise_parameters, sampling_parameters=sampling_parameters, - sigma_normalized_score_network=sigma_normalized_score_network, + axl_network=axl_network, ) return generator diff --git a/tests/samples_and_metrics/__init__.py b/tests/loss/__init__.py similarity index 100% rename from tests/samples_and_metrics/__init__.py rename to tests/loss/__init__.py diff --git a/tests/loss/test_atom_type_loss_calculator.py b/tests/loss/test_atom_type_loss_calculator.py new file mode 100644 index 00000000..33d70702 --- /dev/null +++ b/tests/loss/test_atom_type_loss_calculator.py @@ -0,0 +1,444 @@ +from unittest.mock import patch + +import pytest +import torch +from torch.nn import KLDivLoss + +from diffusion_for_multi_scale_molecular_dynamics.loss import ( + D3PMLossCalculator, LossParameters) +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import \ + class_index_to_onehot +from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ + broadcast_batch_matrix_tensor_to_all_dimensions + + +class TestD3PMLossCalculator: + + @pytest.fixture(scope="class", autouse=True) + def set_seed(self): + """Set the random seed.""" + torch.manual_seed(3423423) + + @pytest.fixture + def batch_size(self): + return 64 + + @pytest.fixture + def number_of_atoms(self): + return 8 + + @pytest.fixture + def num_atom_types(self): + return 5 + + @pytest.fixture + def total_number_of_times_steps(self): + return 8 + + @pytest.fixture + def time_indices(self, batch_size, total_number_of_times_steps): + return torch.randint(0, total_number_of_times_steps, (batch_size,)) + + @pytest.fixture + def num_classes(self, num_atom_types): + return num_atom_types + 1 + + @pytest.fixture + def predicted_logits(self, batch_size, number_of_atoms, num_classes): + logits = 10 * (torch.randn(batch_size, number_of_atoms, num_classes) - 0.5) + logits[:, :, -1] = -torch.inf # force the model to never predict MASK + return logits + + @pytest.fixture + def predicted_p_a0_given_at(self, predicted_logits): + return torch.nn.functional.softmax(predicted_logits, dim=-1) + + @pytest.fixture + def one_hot_a0(self, batch_size, number_of_atoms, num_atom_types, num_classes): + # a0 CANNOT be MASK. + one_hot_indices = torch.randint( + 0, + num_atom_types, + ( + batch_size, + number_of_atoms, + ), + ) + one_hots = class_index_to_onehot(one_hot_indices, num_classes=num_classes) + return one_hots + + @pytest.fixture + def one_hot_at(self, batch_size, number_of_atoms, num_atom_types, num_classes): + # at CAN be MASK. + one_hot_indices = torch.randint( + 0, + num_classes, + ( + batch_size, + number_of_atoms, + ), + ) + one_hots = class_index_to_onehot(one_hot_indices, num_classes=num_classes) + return one_hots + + @pytest.fixture + def one_hot_different_noisy_atom_types( + self, batch_size, number_of_atoms, num_classes + ): + one_hot_noisy_atom_types = torch.zeros(batch_size, number_of_atoms, num_classes) + for i in range(number_of_atoms): + one_hot_noisy_atom_types[:, i, i + 1] = 1 + return one_hot_noisy_atom_types + + @pytest.fixture + def one_hot_similar_noisy_atom_types( + self, batch_size, number_of_atoms, num_classes + ): + one_hot_noisy_atom_types = torch.zeros(batch_size, number_of_atoms, num_classes) + for i in range(1, number_of_atoms): + one_hot_noisy_atom_types[:, i, i + 1] = 1 + one_hot_noisy_atom_types[:, 0, 0] = 1 + return one_hot_noisy_atom_types + + @pytest.fixture + def q_matrices(self, batch_size, number_of_atoms, num_classes): + random_q_matrices = torch.rand(batch_size, num_classes, num_classes) + final_shape = (batch_size, number_of_atoms) + broadcast_q_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + random_q_matrices, final_shape=final_shape + ) + return broadcast_q_matrices + + @pytest.fixture + def q_bar_matrices(self, batch_size, number_of_atoms, num_classes): + random_q_bar_matrices = torch.rand(batch_size, num_classes, num_classes) + final_shape = (batch_size, number_of_atoms) + broadcast_q_bar_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + random_q_bar_matrices, final_shape=final_shape + ) + return broadcast_q_bar_matrices + + @pytest.fixture + def q_bar_tm1_matrices(self, batch_size, number_of_atoms, num_classes): + random_q_bar_tm1_matrices = torch.rand(batch_size, num_classes, num_classes) + final_shape = (batch_size, number_of_atoms) + broadcast_q_bar_tm1_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + random_q_bar_tm1_matrices, final_shape=final_shape + ) + return broadcast_q_bar_tm1_matrices + + @pytest.fixture + def loss_eps(self): + return 1.0e-12 + + @pytest.fixture + def atom_types_ce_weight(self): + return 0.1 + + @pytest.fixture + def loss_parameters(self, loss_eps, atom_types_ce_weight): + return LossParameters( + coordinates_algorithm=None, + atom_types_eps=loss_eps, + atom_types_ce_weight=atom_types_ce_weight, + ) + + @pytest.fixture + def d3pm_calculator(self, loss_parameters): + return D3PMLossCalculator(loss_parameters) + + @pytest.fixture + def expected_p_atm1_given_at( + self, + predicted_p_a0_given_at, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + ): + batch_size, natoms, num_classes = predicted_p_a0_given_at.shape + + denominator = torch.zeros(batch_size, natoms) + numerator1 = torch.zeros(batch_size, natoms, num_classes) + numerator2 = torch.zeros(batch_size, natoms, num_classes) + + for i in range(num_classes): + for j in range(num_classes): + denominator[:, :] += ( + predicted_p_a0_given_at[:, :, i] + * q_bar_matrices[:, :, i, j] + * one_hot_at[:, :, j] + ) + numerator1[:, :, i] += ( + predicted_p_a0_given_at[:, :, j] * q_bar_tm1_matrices[:, :, j, i] + ) + numerator2[:, :, i] += q_matrices[:, :, i, j] * one_hot_at[:, :, j] + + numerator = numerator1 * numerator2 + + expected_p = torch.zeros(batch_size, natoms, num_classes) + for i in range(num_classes): + expected_p[:, :, i] = numerator[:, :, i] / denominator[:, :] + + # Note that the expected_p_atm1_given_at is not really a probability (and thus does not sum to 1) because + # the Q matrices are random. + return expected_p + + @pytest.fixture + def expected_q_atm1_given_at_and_a0( + self, one_hot_a0, one_hot_at, q_matrices, q_bar_matrices, q_bar_tm1_matrices + ): + batch_size, natoms, num_classes = one_hot_a0.shape + + denominator = torch.zeros(batch_size, natoms) + numerator1 = torch.zeros(batch_size, natoms, num_classes) + numerator2 = torch.zeros(batch_size, natoms, num_classes) + + for i in range(num_classes): + for j in range(num_classes): + denominator[:, :] += ( + one_hot_a0[:, :, i] + * q_bar_matrices[:, :, i, j] + * one_hot_at[:, :, j] + ) + numerator1[:, :, i] += ( + one_hot_a0[:, :, j] * q_bar_tm1_matrices[:, :, j, i] + ) + numerator2[:, :, i] += q_matrices[:, :, i, j] * one_hot_at[:, :, j] + + numerator = numerator1 * numerator2 + + expected_q = torch.zeros(batch_size, natoms, num_classes) + for i in range(num_classes): + expected_q[:, :, i] = numerator[:, :, i] / denominator[:, :] + + return expected_q + + @pytest.fixture + def expected_vb_loss( + self, time_indices, one_hot_a0, expected_p_atm1_given_at, expected_q_atm1_given_at_and_a0 + ): + assert ( + 0 in time_indices + ), "For a good test, the index 0 should appear in the time indices!" + + kl_loss = KLDivLoss(reduction="none") + log_p = torch.log(expected_p_atm1_given_at) + vb_loss = kl_loss(input=log_p, target=expected_q_atm1_given_at_and_a0) + + for batch_idx, time_index in enumerate(time_indices): + if time_index == 0: + vb_loss[batch_idx] = -log_p[batch_idx] * one_hot_a0[batch_idx] + + return vb_loss + + def test_get_p_atm1_at( + self, + predicted_logits, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + loss_eps, + d3pm_calculator, + expected_p_atm1_given_at, + ): + computed_p_atm1_given_at = d3pm_calculator.get_p_atm1_given_at( + predicted_logits, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + small_epsilon=loss_eps, + ) + + assert torch.allclose(computed_p_atm1_given_at, expected_p_atm1_given_at) + + def test_get_q_atm1_given_at_and_a0( + self, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + loss_eps, + d3pm_calculator, + expected_q_atm1_given_at_and_a0, + ): + computed_q_atm1_given_at_and_a0 = d3pm_calculator.get_q_atm1_given_at_and_a0( + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + small_epsilon=loss_eps, + ) + + assert torch.allclose( + computed_q_atm1_given_at_and_a0, expected_q_atm1_given_at_and_a0 + ) + + def test_variational_bound_loss( + self, + predicted_logits, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + d3pm_calculator, + loss_eps, + expected_vb_loss, + ): + computed_vb_loss = d3pm_calculator.variational_bound_loss_term( + predicted_logits, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + ) + + torch.testing.assert_close(computed_vb_loss, expected_vb_loss) + + def test_vb_loss_predicting_a0( + self, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + d3pm_calculator, + ): + # The KL should vanish when p_\theta(. | a_t) predicts a0 with probability 1. + + predicted_logits = torch.log(one_hot_a0) + + computed_vb_loss = d3pm_calculator.variational_bound_loss_term( + predicted_logits, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + ) + + non_zero_time_step_mask = time_indices != 0 + computed_kl_loss = computed_vb_loss[non_zero_time_step_mask] + + torch.testing.assert_close(computed_kl_loss, torch.zeros_like(computed_kl_loss)) + + def test_cross_entropy_loss_term(self, predicted_logits, one_hot_a0, d3pm_calculator): + computed_ce_loss = d3pm_calculator.cross_entropy_loss_term(predicted_logits, one_hot_a0) + + p = torch.softmax(predicted_logits, dim=-1) + log_p = torch.log(p) + log_p[..., -1] = 0.0 + expected_ce_loss = -log_p * one_hot_a0 + + torch.testing.assert_close(computed_ce_loss, expected_ce_loss) + + def test_calculate_unreduced_loss( + self, + predicted_logits, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + d3pm_calculator, + atom_types_ce_weight, + ): + vb_loss = d3pm_calculator.variational_bound_loss_term( + predicted_logits, + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + ) + + ce_loss = d3pm_calculator.cross_entropy_loss_term(predicted_logits, one_hot_a0) + expected_losss = vb_loss + atom_types_ce_weight * ce_loss + + computed_loss = d3pm_calculator.calculate_unreduced_loss( + predicted_logits, + one_hot_a0, + one_hot_at, + time_indices, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + ) + + torch.testing.assert_close(computed_loss, expected_losss) + + @pytest.mark.parametrize("time_index_zero", [True, False]) + def test_variational_bound_call( + self, + time_index_zero, + d3pm_calculator, + batch_size, + number_of_atoms, + num_classes, + ): + predicted_logits = torch.randn(batch_size, number_of_atoms, num_classes) + predicted_logits[..., -1] = -torch.inf + + real_atom_types = torch.randint(0, num_classes, (batch_size, number_of_atoms)) + real_atom_types = class_index_to_onehot(real_atom_types, num_classes=num_classes) + + noisy_atom_types = torch.randint(0, num_classes, (batch_size, number_of_atoms)) + noisy_atom_types = class_index_to_onehot(noisy_atom_types, num_classes=num_classes) + + q_matrices = torch.randn(batch_size, number_of_atoms, num_classes, num_classes) + q_bar_matrices = torch.randn( + batch_size, number_of_atoms, num_classes, num_classes + ) + q_bar_tm1_matrices = torch.randn( + batch_size, number_of_atoms, num_classes, num_classes + ) + + # Mock the KL loss term output + mock_vb_loss_output = torch.randn(batch_size, number_of_atoms, num_classes) + + # Define time_indices: 0 for NLL and 1 for KL + NLL (depending on parametrize input) + if time_index_zero: + time_indices = torch.zeros( + batch_size, dtype=torch.long + ) # t == 1 case (index 0) + else: + time_indices = torch.ones(batch_size, dtype=torch.long) # t > 1 case + + # Patch the kl_loss_term method + with patch.object( + d3pm_calculator, + "variational_bound_loss_term", + return_value=mock_vb_loss_output, + ) as mock_vb_loss: + # Call the function under test + _ = d3pm_calculator.calculate_unreduced_loss( + predicted_logits, + real_atom_types, + noisy_atom_types, + time_indices, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + ) + + mock_vb_loss.assert_called_once_with( + predicted_logits, + real_atom_types, + noisy_atom_types, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + time_indices, + ) diff --git a/tests/models/test_loss.py b/tests/loss/test_loss.py similarity index 88% rename from tests/models/test_loss.py rename to tests/loss/test_loss.py index d130e616..0bc962ff 100644 --- a/tests/models/test_loss.py +++ b/tests/loss/test_loss.py @@ -1,9 +1,11 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.models.loss import ( - MSELossParameters, WeightedMSELossParameters, create_loss_calculator) -from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ +from diffusion_for_multi_scale_molecular_dynamics.loss import \ + create_loss_calculator +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import ( + MSELossParameters, WeightedMSELossParameters) +from src.diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ broadcast_batch_tensor_to_all_dimensions @@ -93,7 +95,7 @@ def computed_loss( target_normalized_conditional_scores, sigmas, ): - unreduced_loss = loss_calculator.calculate_unreduced_loss( + unreduced_loss = loss_calculator.X.calculate_unreduced_loss( predicted_normalized_scores, target_normalized_conditional_scores, sigmas ) return torch.mean(unreduced_loss) diff --git a/tests/models/score_network/base_test_score_network.py b/tests/models/score_network/base_test_score_network.py new file mode 100644 index 00000000..3d8cc09c --- /dev/null +++ b/tests/models/score_network/base_test_score_network.py @@ -0,0 +1,36 @@ +import pytest +import torch + + +class BaseTestScoreNetwork: + """Base class defining common fixtures for all tests.""" + + @pytest.fixture(scope="class", autouse=True) + def set_seed(self): + """Set the random seed.""" + torch.manual_seed(234233) + + @pytest.fixture() + def score_network(self, *args): + raise NotImplementedError("This fixture must be implemented in the derived class.") + + @pytest.fixture() + def batch_size(self, *args, **kwargs): + return 16 + + @pytest.fixture() + def number_of_atoms(self): + return 8 + + @pytest.fixture() + def spatial_dimension(self): + return 3 + + @pytest.fixture() + def num_atom_types(self): + return 5 + + @pytest.fixture() + def atom_types(self, batch_size, number_of_atoms, num_atom_types): + atom_types = torch.randint(0, num_atom_types + 1, (batch_size, number_of_atoms)) + return atom_types diff --git a/tests/models/test_analytical_score_network.py b/tests/models/score_network/test_analytical_score_network.py similarity index 85% rename from tests/models/test_analytical_score_network.py rename to tests/models/score_network/test_analytical_score_network.py index b2d0af03..a0537d54 100644 --- a/tests/models/test_analytical_score_network.py +++ b/tests/models/score_network/test_analytical_score_network.py @@ -7,7 +7,9 @@ AnalyticalScoreNetwork, AnalyticalScoreNetworkParameters, TargetScoreBasedAnalyticalScoreNetwork) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) + AXL, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from tests.models.score_network.base_test_score_network import \ + BaseTestScoreNetwork def factorial(n): @@ -18,7 +20,7 @@ def factorial(n): return n * factorial(n - 1) -class TestAnalyticalScoreNetwork: +class TestAnalyticalScoreNetwork(BaseTestScoreNetwork): @pytest.fixture(scope="class", autouse=True) def set_default_type_to_float64(self): torch.set_default_dtype(torch.float64) @@ -27,14 +29,6 @@ def set_default_type_to_float64(self): # to not affect other tests. torch.set_default_dtype(torch.float32) - @pytest.fixture(scope="class", autouse=True) - def set_random_seed(self): - torch.manual_seed(23423423) - - @pytest.fixture - def batch_size(self): - return 4 - @pytest.fixture def kmax(self): # kmax has to be fairly large for the comparison test between the analytical score and the target based @@ -45,6 +39,10 @@ def kmax(self): def spatial_dimension(self, request): return request.param + @pytest.fixture + def num_atom_types(self): + return 1 + @pytest.fixture(params=[1, 2]) def number_of_atoms(self, request): return request.param @@ -53,6 +51,19 @@ def number_of_atoms(self, request): def equilibrium_relative_coordinates(self, number_of_atoms, spatial_dimension): return torch.rand(number_of_atoms, spatial_dimension) + """ + @pytest.fixture + def atom_types(self, batch_size, number_of_atoms, num_atom_types): + return torch.randint( + 0, + num_atom_types, + ( + batch_size, + number_of_atoms, + ), + ) + """ + @pytest.fixture(params=["finite", "zero"]) def variance_parameter(self, request): if request.param == "zero": @@ -63,7 +74,7 @@ def variance_parameter(self, request): return 1.0 / inverse_variance @pytest.fixture() - def batch(self, batch_size, number_of_atoms, spatial_dimension): + def batch(self, batch_size, number_of_atoms, spatial_dimension, atom_types): relative_coordinates = torch.rand( batch_size, number_of_atoms, spatial_dimension ) @@ -71,7 +82,9 @@ def batch(self, batch_size, number_of_atoms, spatial_dimension): noises = torch.rand(batch_size, 1) unit_cell = torch.rand(batch_size, spatial_dimension, spatial_dimension) return { - NOISY_RELATIVE_COORDINATES: relative_coordinates, + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, X=relative_coordinates, L=torch.zeros_like(atom_types) + ), TIME: times, NOISE: noises, UNIT_CELL: unit_cell, @@ -86,6 +99,7 @@ def score_network_parameters( equilibrium_relative_coordinates, variance_parameter, use_permutation_invariance, + num_atom_types ): hyper_params = AnalyticalScoreNetworkParameters( number_of_atoms=number_of_atoms, @@ -94,6 +108,7 @@ def score_network_parameters( equilibrium_relative_coordinates=equilibrium_relative_coordinates, variance_parameter=variance_parameter, use_permutation_invariance=use_permutation_invariance, + num_atom_types=num_atom_types ) return hyper_params @@ -146,7 +161,7 @@ def test_compute_unnormalized_log_probability( score_network, ): sigmas = batch[NOISE] # dimension: [batch_size, 1] - xt = batch[NOISY_RELATIVE_COORDINATES] + xt = batch[NOISY_AXL_COMPOSITION].X computed_log_prob = score_network._compute_unnormalized_log_probability( sigmas, xt, equilibrium_relative_coordinates ) @@ -174,6 +189,11 @@ def test_compute_unnormalized_log_probability( expected_log_prob[batch_idx] += torch.log(sum_on_k) + # Let's give a free pass to any problematic expected values, which are calculated with a fragile + # brute force approach + problem_mask = torch.logical_or(torch.isnan(expected_log_prob), torch.isinf(expected_log_prob)) + expected_log_prob[problem_mask] = computed_log_prob[problem_mask] + torch.testing.assert_close(expected_log_prob, computed_log_prob) @pytest.mark.parametrize( @@ -185,7 +205,7 @@ def test_analytical_score_network( ): normalized_scores = score_network.forward(batch) - assert normalized_scores.shape == ( + assert normalized_scores.X.shape == ( batch_size, number_of_atoms, spatial_dimension, diff --git a/tests/models/score_network/test_force_field_augmented_score_network.py b/tests/models/score_network/test_force_field_augmented_score_network.py index 573fe18e..b8a0e60e 100644 --- a/tests/models/score_network/test_force_field_augmented_score_network.py +++ b/tests/models/score_network/test_force_field_augmented_score_network.py @@ -6,33 +6,29 @@ from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import ( MLPScoreNetwork, MLPScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from tests.models.score_network.base_test_score_network import \ + BaseTestScoreNetwork @pytest.mark.parametrize("number_of_atoms", [4, 8, 16]) @pytest.mark.parametrize("radial_cutoff", [1.5, 2.0, 2.5]) -class TestForceFieldAugmentedScoreNetwork: - @pytest.fixture(scope="class", autouse=True) - def set_random_seed(self): - torch.manual_seed(345345345) - - @pytest.fixture() - def spatial_dimension(self): - return 3 - +class TestForceFieldAugmentedScoreNetwork(BaseTestScoreNetwork): @pytest.fixture() - def score_network_parameters(self, number_of_atoms, spatial_dimension): + def score_network( + self, number_of_atoms, spatial_dimension, num_atom_types + ): # Generate an arbitrary MLP-based score network. - return MLPScoreNetworkParameters( + score_network_parameters = MLPScoreNetworkParameters( spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, - embedding_dimensions_size=12, + num_atom_types=num_atom_types, + noise_embedding_dimensions_size=6, + time_embedding_dimensions_size=6, + atom_type_embedding_dimensions_size=12, n_hidden_dimensions=2, hidden_dimensions_size=16, ) - - @pytest.fixture() - def score_network(self, score_network_parameters): return MLPScoreNetwork(score_network_parameters) @pytest.fixture() @@ -48,10 +44,6 @@ def force_field_augmented_score_network( ) return augmented_score_network - @pytest.fixture() - def batch_size(self): - return 16 - @pytest.fixture def times(self, batch_size): times = torch.rand(batch_size, 1) @@ -94,10 +86,20 @@ def noises(self, batch_size): @pytest.fixture() def batch( - self, relative_coordinates, cartesian_forces, times, noises, basis_vectors + self, + relative_coordinates, + atom_types, + cartesian_forces, + times, + noises, + basis_vectors, ): return { - NOISY_RELATIVE_COORDINATES: relative_coordinates, + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, + X=relative_coordinates, + L=torch.zeros_like(atom_types), # TODO + ), TIME: times, UNIT_CELL: basis_vectors, NOISE: noises, @@ -144,7 +146,8 @@ def test_get_cartesian_pseudo_forces( ) ) cartesian_pseudo_force_contributions = ( - force_field_augmented_score_network._get_cartesian_pseudo_forces_contributions(cartesian_displacements)) + force_field_augmented_score_network._get_cartesian_pseudo_forces_contributions(cartesian_displacements) + ) computed_cartesian_pseudo_forces = ( force_field_augmented_score_network._get_cartesian_pseudo_forces( @@ -180,7 +183,7 @@ def test_augmented_scores( raw_scores = score_network(batch) augmented_scores = force_field_augmented_score_network(batch) - torch.testing.assert_allclose(augmented_scores - raw_scores, forces) + torch.testing.assert_allclose(augmented_scores.X - raw_scores.X, forces) def test_specific_scenario_sanity_check(): @@ -199,10 +202,15 @@ def test_specific_scenario_sanity_check(): # Put two atoms on a straight line relative_coordinates = torch.tensor([[[0.35, 0.5, 0.0], [0.65, 0.5, 0.0]]]) - + atom_types = torch.zeros_like(relative_coordinates[..., 0]) basis_vectors = torch.diag(torch.ones(spatial_dimension)).unsqueeze(0) - batch = {NOISY_RELATIVE_COORDINATES: relative_coordinates, UNIT_CELL: basis_vectors} + batch = { + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, X=relative_coordinates, L=torch.zeros_like(atom_types) + ), + UNIT_CELL: basis_vectors, + } forces = force_field_score_network.get_relative_coordinates_pseudo_force(batch) diff --git a/tests/models/score_network/test_score_network_basic_checks.py b/tests/models/score_network/test_score_network_basic_checks.py new file mode 100644 index 00000000..f64dee77 --- /dev/null +++ b/tests/models/score_network/test_score_network_basic_checks.py @@ -0,0 +1,174 @@ +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks import ( + ScoreNetwork, ScoreNetworkParameters) +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from tests.models.score_network.base_test_score_network import \ + BaseTestScoreNetwork + + +@pytest.mark.parametrize("spatial_dimension", [2, 3]) +class TestScoreNetworkBasicCheck(BaseTestScoreNetwork): + + @pytest.fixture() + def score_network(self, spatial_dimension, num_atom_types): + score_parameters = ScoreNetworkParameters( + architecture="dummy", + spatial_dimension=spatial_dimension, + num_atom_types=num_atom_types, + ) + + return ScoreNetwork(score_parameters) + + @pytest.fixture() + def good_batch(self, spatial_dimension, num_atom_types, number_of_atoms): + batch_size = 16 + relative_coordinates = torch.rand( + batch_size, number_of_atoms, spatial_dimension + ) + times = torch.rand(batch_size, 1) + noises = torch.rand(batch_size, 1) + unit_cell = torch.rand(batch_size, spatial_dimension, spatial_dimension) + atom_types = torch.randint(0, num_atom_types + 1, (batch_size, number_of_atoms)) + return { + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, X=relative_coordinates, L=torch.zeros_like(atom_types) + ), + TIME: times, + NOISE: noises, + UNIT_CELL: unit_cell, + } + + @pytest.fixture() + def bad_batch(self, good_batch, problem, num_atom_types): + + bad_batch_dict = dict(good_batch) + + match problem: + case "position_name": + bad_batch_dict["bad_position_name"] = bad_batch_dict[ + NOISY_AXL_COMPOSITION + ] + del bad_batch_dict[NOISY_AXL_COMPOSITION] + + case "position_shape": + shape = bad_batch_dict[NOISY_AXL_COMPOSITION].X.shape + bad_batch_dict[NOISY_AXL_COMPOSITION] = AXL( + A=bad_batch_dict[NOISY_AXL_COMPOSITION].A, + X=bad_batch_dict[NOISY_AXL_COMPOSITION].X.reshape( + shape[0], shape[1] // 2, shape[2] * 2 + ), + L=bad_batch_dict[NOISY_AXL_COMPOSITION].L, + ) + + case "position_range1": + bad_positions = bad_batch_dict[NOISY_AXL_COMPOSITION].X + bad_positions[0, 0, 0] = 1.01 + bad_batch_dict[NOISY_AXL_COMPOSITION] = AXL( + A=bad_batch_dict[NOISY_AXL_COMPOSITION].A, + X=bad_positions, + L=bad_batch_dict[NOISY_AXL_COMPOSITION].L, + ) + + case "position_range2": + bad_positions = bad_batch_dict[NOISY_AXL_COMPOSITION].X + bad_positions[1, 0, 0] = -0.01 + bad_batch_dict[NOISY_AXL_COMPOSITION] = AXL( + A=bad_batch_dict[NOISY_AXL_COMPOSITION].A, + X=bad_positions, + L=bad_batch_dict[NOISY_AXL_COMPOSITION].L, + ) + + case "atom_types_shape": + shape = bad_batch_dict[NOISY_AXL_COMPOSITION].A.shape + bad_batch_dict[NOISY_AXL_COMPOSITION] = AXL( + A=bad_batch_dict[NOISY_AXL_COMPOSITION].A.reshape( + shape[0] * 2, shape[1] // 2 + ), + X=bad_batch_dict[NOISY_AXL_COMPOSITION].X, + L=bad_batch_dict[NOISY_AXL_COMPOSITION].L, + ) + + case "atom_types_range1": + bad_types = bad_batch_dict[NOISY_AXL_COMPOSITION].A + bad_types[0, 0] = num_atom_types + 2 + bad_batch_dict[NOISY_AXL_COMPOSITION] = AXL( + A=bad_types, + X=bad_batch_dict[NOISY_AXL_COMPOSITION].X, + L=bad_batch_dict[NOISY_AXL_COMPOSITION].L, + ) + + case "atom_types_range2": + bad_types = bad_batch_dict[NOISY_AXL_COMPOSITION].A + bad_types[1, 0] = -1 + bad_batch_dict[NOISY_AXL_COMPOSITION] = AXL( + A=bad_types, + X=bad_batch_dict[NOISY_AXL_COMPOSITION].X, + L=bad_batch_dict[NOISY_AXL_COMPOSITION].L, + ) + + case "time_name": + bad_batch_dict["bad_time_name"] = bad_batch_dict[TIME] + del bad_batch_dict[TIME] + + case "time_shape": + shape = bad_batch_dict[TIME].shape + bad_batch_dict[TIME] = bad_batch_dict[TIME].reshape( + shape[0] // 2, shape[1] * 2 + ) + + case "noise_name": + bad_batch_dict["bad_noise_name"] = bad_batch_dict[NOISE] + del bad_batch_dict[NOISE] + + case "noise_shape": + shape = bad_batch_dict[NOISE].shape + bad_batch_dict[NOISE] = bad_batch_dict[NOISE].reshape( + shape[0] // 2, shape[1] * 2 + ) + + case "time_range1": + bad_batch_dict[TIME][5, 0] = 2.00 + case "time_range2": + bad_batch_dict[TIME][0, 0] = -0.05 + + case "cell_name": + bad_batch_dict["bad_unit_cell_key"] = bad_batch_dict[UNIT_CELL] + del bad_batch_dict[UNIT_CELL] + + case "cell_shape": + shape = bad_batch_dict[UNIT_CELL].shape + bad_batch_dict[UNIT_CELL] = bad_batch_dict[UNIT_CELL].reshape( + shape[0] // 2, shape[1] * 2, shape[2] + ) + + return bad_batch_dict + + def test_check_batch_good(self, score_network, good_batch): + score_network._check_batch(good_batch) + + @pytest.mark.parametrize( + "problem", + [ + "position_name", + "time_name", + "position_shape", + "atom_types_shape", + "time_shape", + "noise_name", + "noise_shape", + "position_range1", + "position_range2", + "atom_types_range1", + "atom_types_range2", + "time_range1", + "time_range2", + "cell_name", + "cell_shape", + ], + ) + def test_check_batch_bad(self, score_network, bad_batch): + with pytest.raises(AssertionError): + score_network._check_batch(bad_batch) diff --git a/tests/models/score_network/test_score_network_equivariance.py b/tests/models/score_network/test_score_network_equivariance.py new file mode 100644 index 00000000..6c012007 --- /dev/null +++ b/tests/models/score_network/test_score_network_equivariance.py @@ -0,0 +1,526 @@ +import einops +import pytest +import torch +from e3nn import o3 + +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.diffusion_mace_score_network import ( + DiffusionMACEScoreNetwork, DiffusionMACEScoreNetworkParameters) +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.egnn_score_network import ( + EGNNScoreNetwork, EGNNScoreNetworkParameters) +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mace_score_network import ( + MACEScoreNetwork, MACEScoreNetworkParameters) +from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_prediction_head import \ + MaceEquivariantScorePredictionHeadParameters +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, + NOISY_CARTESIAN_POSITIONS, TIME, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( + get_positions_from_coordinates, get_reciprocal_basis_vectors, + get_relative_coordinates_from_cartesian_positions, + map_relative_coordinates_to_unit_cell) +from diffusion_for_multi_scale_molecular_dynamics.utils.geometric_utils import \ + get_cubic_point_group_symmetries +from tests.models.score_network.base_test_score_network import \ + BaseTestScoreNetwork + + +class BaseTestScoreEquivariance(BaseTestScoreNetwork): + + @staticmethod + def apply_rotation_to_configuration(batch_rotation_matrices, batch_configuration): + """Apply rotations to configuration. + + Args: + batch_rotation_matrices : Dimension [batch_size, spatial_dimension, spatial_dimension] + batch_configuration : Dimension [batch_size, number_of_atoms, spatial_dimension] + + Returns: + rotated_batch_configuration : Dimension [batch_size, number_of_atoms, spatial_dimension] + """ + return einops.einsum( + batch_rotation_matrices, + batch_configuration, + "batch alpha beta, batch natoms beta -> batch natoms alpha", + ).contiguous() + + @staticmethod + def get_rotated_basis_vectors(batch_rotation_matrices, basis_vectors): + """Get rotated basis vectors. + + Basis vectors are assumed to be in ROW format, + + basis_vectors = [ --- a1 ---] + [---- a2 ---] + [---- a3 ---] + + Args: + batch_rotation_matrices : Dimension [batch_size, spatial_dimension, spatial_dimension] + basis_vectors : Dimension [batch_size, spatial_dimension, spatial_dimension] + + Returns: + rotated_basis_vectors : Dimension [batch_size, spatial_dimension, spatial_dimension] + """ + new_basis_vectors = einops.einsum( + batch_rotation_matrices, + basis_vectors, + "batch alpha beta, batch i beta -> batch i alpha", + ).contiguous() + return new_basis_vectors + + @staticmethod + def create_batch( + relative_coordinates, + cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ): + batch = { + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, + X=relative_coordinates, + L=torch.zeros_like(atom_types), # TODO + ), + NOISY_CARTESIAN_POSITIONS: cartesian_positions, + TIME: times, + NOISE: noises, + UNIT_CELL: basis_vectors, + CARTESIAN_FORCES: forces, + } + return batch + + @pytest.fixture(scope="class", autouse=True) + def set_default_type_to_float64(self): + torch.set_default_dtype(torch.float64) + yield + # this returns the default type to float32 at the end of all tests in this class in order + # to not affect other tests. + torch.set_default_dtype(torch.float32) + + @pytest.fixture() + def output(self, batch, score_network): + with torch.no_grad(): + return score_network(batch) + + @pytest.fixture() + def translated_output(self, translated_batch, score_network): + with torch.no_grad(): + return score_network(translated_batch) + + @pytest.fixture() + def rotated_output(self, rotated_batch, score_network): + with torch.no_grad(): + return score_network(rotated_batch) + + @pytest.fixture() + def permuted_output(self, permuted_batch, score_network): + with torch.no_grad(): + return score_network(permuted_batch) + + @pytest.fixture(params=[True, False]) + def are_basis_vectors_rotated(self, request): + # Should the basis vectors be rotated according to the point group operation? + return request.param + + @pytest.fixture(params=[True, False]) + def is_cell_cubic(self, request): + # Should the basis vectors form a cube? + return request.param + + @pytest.fixture(params=[True, False]) + def is_rotations_cubic_point_group(self, request): + # Should the rotations be the symmetries of a cube? + return request.param + + @pytest.fixture() + def batch_size(self, is_rotations_cubic_point_group): + if is_rotations_cubic_point_group: + return len(get_cubic_point_group_symmetries()) + else: + return 16 + + @pytest.fixture() + def basis_vectors(self, batch_size, spatial_dimension, is_cell_cubic): + if is_cell_cubic: + # Cubic unit cells. + basis_vectors = (5.0 + 5.0 * torch.rand(1)) * torch.eye( + spatial_dimension + ).repeat(batch_size, 1, 1) + else: + # orthogonal boxes with dimensions between 5 and 10. + orthogonal_boxes = torch.stack( + [ + torch.diag(5.0 + 5.0 * torch.rand(spatial_dimension)) + for _ in range(batch_size) + ] + ) + # add a bit of noise to make the vectors not quite orthogonal + basis_vectors = orthogonal_boxes + 0.1 * torch.randn( + batch_size, spatial_dimension, spatial_dimension + ) + + return basis_vectors + + @pytest.fixture() + def rotated_basis_vectors( + self, cartesian_rotations, basis_vectors, are_basis_vectors_rotated + ): + # The basis vectors are defined as ROWS. + if are_basis_vectors_rotated: + return self.get_rotated_basis_vectors(cartesian_rotations, basis_vectors) + else: + return basis_vectors + + @pytest.fixture() + def relative_coordinates(self, batch_size, number_of_atoms, spatial_dimension): + relative_coordinates = torch.rand( + batch_size, number_of_atoms, spatial_dimension + ) + return relative_coordinates + + @pytest.fixture() + def cartesian_positions(self, relative_coordinates, basis_vectors): + return get_positions_from_coordinates(relative_coordinates, basis_vectors) + + @pytest.fixture() + def times(self, batch_size): + return torch.rand(batch_size, 1) + + @pytest.fixture() + def noises(self, batch_size): + return 0.5 * torch.rand(batch_size, 1) + + @pytest.fixture() + def forces(self, batch_size, spatial_dimension): + return 0.5 * torch.rand(batch_size, spatial_dimension) + + @pytest.fixture() + def permutations(self, batch_size, number_of_atoms): + return torch.stack([torch.randperm(number_of_atoms) for _ in range(batch_size)]) + + @pytest.fixture() + def cartesian_rotations(self, batch_size, is_rotations_cubic_point_group): + if is_rotations_cubic_point_group: + return get_cubic_point_group_symmetries() + else: + return o3.rand_matrix(batch_size) + + @pytest.fixture() + def cartesian_translations( + self, batch_size, number_of_atoms, spatial_dimension, basis_vectors + ): + batch_relative_coordinates_translations = torch.rand( + batch_size, spatial_dimension + ) + + batch_cartesian_translations = [] + for t, cell in zip(batch_relative_coordinates_translations, basis_vectors): + batch_cartesian_translations.append(t @ cell) + + batch_cartesian_translations = torch.stack(batch_cartesian_translations) + + cartesian_translations = torch.repeat_interleave( + batch_cartesian_translations.unsqueeze(1), number_of_atoms, dim=1 + ) + return cartesian_translations + + @pytest.fixture() + def batch( + self, + relative_coordinates, + cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ): + return self.create_batch( + relative_coordinates, + cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ) + + @pytest.fixture() + def translated_batch( + self, + cartesian_translations, + relative_coordinates, + cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ): + translated_cartesian_positions = cartesian_positions + cartesian_translations + reciprocal_basis_vectors = get_reciprocal_basis_vectors(basis_vectors) + + new_relative_coordinates = map_relative_coordinates_to_unit_cell( + get_relative_coordinates_from_cartesian_positions( + translated_cartesian_positions, reciprocal_basis_vectors + ) + ) + new_cartesian_positions = get_positions_from_coordinates( + new_relative_coordinates, basis_vectors + ) + return self.create_batch( + new_relative_coordinates, + new_cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ) + + @pytest.fixture() + def rotated_batch( + self, + rotated_basis_vectors, + cartesian_rotations, + relative_coordinates, + cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ): + rotated_cartesian_positions = self.apply_rotation_to_configuration( + cartesian_rotations, cartesian_positions + ) + + rotated_reciprocal_basis_vectors = get_reciprocal_basis_vectors( + rotated_basis_vectors + ) + + rel_coords = get_relative_coordinates_from_cartesian_positions( + rotated_cartesian_positions, rotated_reciprocal_basis_vectors + ) + new_relative_coordinates = map_relative_coordinates_to_unit_cell(rel_coords) + new_cartesian_positions = get_positions_from_coordinates( + new_relative_coordinates, rotated_reciprocal_basis_vectors + ) + return self.create_batch( + new_relative_coordinates, + new_cartesian_positions, + atom_types, + rotated_basis_vectors, + times, + noises, + forces, + ) + + @pytest.fixture() + def permuted_batch( + self, + permutations, + relative_coordinates, + cartesian_positions, + atom_types, + basis_vectors, + times, + noises, + forces, + ): + batch_size = relative_coordinates.shape[0] + + new_cartesian_positions = torch.stack( + [ + cartesian_positions[batch_idx, permutations[batch_idx], :] + for batch_idx in range(batch_size) + ] + ) + + new_relative_coordinates = torch.stack( + [ + relative_coordinates[batch_idx, permutations[batch_idx], :] + for batch_idx in range(batch_size) + ] + ) + + new_atom_types = torch.stack( + [ + atom_types[batch_idx, permutations[batch_idx]] + for batch_idx in range(batch_size) + ] + ) + return self.create_batch( + new_relative_coordinates, + new_cartesian_positions, + new_atom_types, + basis_vectors, + times, + noises, + forces, + ) + + def test_translation_invariance(self, output, translated_output): + torch.testing.assert_close(output, translated_output) + + @pytest.fixture() + def rotated_scores_should_match( + self, is_rotations_cubic_point_group, is_cell_cubic, are_basis_vectors_rotated + ): + # The rotated scores should match the original scores if the basis vectors are rotated. + # If the basis vectors are NOT rotated, only a cubic unit cell (and cubic symmetries) should match. + should_match = are_basis_vectors_rotated or ( + is_cell_cubic and is_rotations_cubic_point_group + ) + return should_match + + @pytest.fixture() + def atom_output_should_be_tested_for_rotational_equivariance(self): + return True + + def test_rotation_equivariance( + self, + output, + rotated_output, + basis_vectors, + rotated_basis_vectors, + cartesian_rotations, + rotated_scores_should_match, + atom_output_should_be_tested_for_rotational_equivariance + ): + + # The score is ~ nabla_x ln P. There must a be a basis change to turn it into a cartesian score of the + # form ~ nabla_r ln P. + reciprocal_basis_vectors = get_reciprocal_basis_vectors(basis_vectors) + cartesian_scores = einops.einsum( + reciprocal_basis_vectors, + output.X, + "batch alpha i, batch natoms i -> batch natoms alpha", + ).contiguous() + + reciprocal_rotated_basis_vectors = get_reciprocal_basis_vectors( + rotated_basis_vectors + ) + rotated_cartesian_scores = einops.einsum( + reciprocal_rotated_basis_vectors, + rotated_output.X, + "batch alpha i, batch natoms i -> batch natoms alpha", + ).contiguous() + + expected_rotated_cartesian_scores = self.apply_rotation_to_configuration( + cartesian_rotations, cartesian_scores + ) + + if rotated_scores_should_match: + torch.testing.assert_close( + expected_rotated_cartesian_scores, rotated_cartesian_scores + ) + torch.testing.assert_close(output.L, rotated_output.L) + + if atom_output_should_be_tested_for_rotational_equivariance: + torch.testing.assert_close(output.A, rotated_output.A) + else: + with pytest.raises(AssertionError): + torch.testing.assert_close( + expected_rotated_cartesian_scores, rotated_cartesian_scores + ) + # TODO: it's not clear what the expectation should be for A and L in this case... + + def test_permutation_equivariance( + self, output, permuted_output, batch_size, permutations + ): + + expected_output_x = torch.stack( + [ + output.X[batch_idx, permutations[batch_idx], :] + for batch_idx in range(batch_size) + ] + ) + + expected_output_a = torch.stack( + [ + output.A[batch_idx, permutations[batch_idx]] + for batch_idx in range(batch_size) + ] + ) + + expected_permuted_output = AXL( + A=expected_output_a, X=expected_output_x, L=output.L + ) + + torch.testing.assert_close(expected_permuted_output, permuted_output) + + +class TestEquivarianceDiffusionMACE(BaseTestScoreEquivariance): + + @pytest.fixture() + def score_network_parameters( + self, number_of_atoms, num_atom_types, spatial_dimension + ): + return DiffusionMACEScoreNetworkParameters( + spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, + r_max=3.0, + num_bessel=4, + num_polynomial_cutoff=3, + hidden_irreps="8x0e + 8x1o", + mlp_irreps="8x0e", + number_of_mlp_layers=1, + correlation=2, + radial_MLP=[8, 8, 8], + ) + + @pytest.fixture() + def score_network(self, score_network_parameters): + return DiffusionMACEScoreNetwork(score_network_parameters) + + +# TODO: This model has not yet been adapted to multiple atom types, and so is not ready for atom_type related tests. +# This test should be updated if the model is adapted to multiple atom types. +class TestEquivarianceMaceWithEquivariantScorePredictionHead(BaseTestScoreEquivariance): + + @pytest.fixture() + def atom_output_should_be_tested_for_rotational_equivariance(self): + return False + + @pytest.fixture() + def score_network_parameters( + self, + spatial_dimension, + number_of_atoms, + num_atom_types, + ): + prediction_head_parameters = MaceEquivariantScorePredictionHeadParameters( + spatial_dimension=spatial_dimension, + number_of_layers=2, + ) + + return MACEScoreNetworkParameters( + spatial_dimension=spatial_dimension, + number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, + r_max=3.0, + prediction_head_parameters=prediction_head_parameters, + ) + + @pytest.fixture() + def score_network(self, score_network_parameters): + return MACEScoreNetwork(score_network_parameters) + + +class TestEquivarianceEGNN(BaseTestScoreEquivariance): + + @pytest.fixture(params=[("fully_connected", None), ("radial_cutoff", 3.0)]) + def score_network_parameters(self, request, num_atom_types): + edges, radial_cutoff = request.param + return EGNNScoreNetworkParameters( + edges=edges, radial_cutoff=radial_cutoff, num_atom_types=num_atom_types + ) + + @pytest.fixture() + def score_network(self, score_network_parameters): + score_network = EGNNScoreNetwork(score_network_parameters) + return score_network diff --git a/tests/models/score_network/test_score_network.py b/tests/models/score_network/test_score_network_general_tests.py similarity index 51% rename from tests/models/score_network/test_score_network.py rename to tests/models/score_network/test_score_network_general_tests.py index 2465523c..30ddc2b3 100644 --- a/tests/models/score_network/test_score_network.py +++ b/tests/models/score_network/test_score_network_general_tests.py @@ -1,5 +1,3 @@ -import itertools -from copy import deepcopy from dataclasses import asdict, dataclass, fields import einops @@ -15,156 +13,48 @@ MACEScoreNetwork, MACEScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import ( MLPScoreNetwork, MLPScoreNetworkParameters) -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network import ( - ScoreNetwork, ScoreNetworkParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network_factory import \ create_score_network_parameters from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_prediction_head import ( MaceEquivariantScorePredictionHeadParameters, MaceMLPScorePredictionHeadParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ - map_relative_coordinates_to_unit_cell + AXL, CARTESIAN_FORCES, NOISE, NOISY_AXL_COMPOSITION, TIME, UNIT_CELL) +from tests.fake_data_utils import generate_random_string +from tests.models.score_network.base_test_score_network import \ + BaseTestScoreNetwork -def assert_parameters_are_the_same(parameters1: dataclass, parameters2: dataclass): - """Compare dataclasses explicitly as a workaround for the potential presence of numpy arrays.""" - assert type(parameters1) is type(parameters2) +class BaseScoreNetworkGeneralTests(BaseTestScoreNetwork): + """Base score network general tests. - for field in fields(parameters1): - value1 = getattr(parameters1, field.name) - value2 = getattr(parameters2, field.name) - - assert type(value1) is type(value2) - - if type(value1) is np.ndarray: - np.testing.assert_array_equal(value1, value2) - else: - assert value1 == value2 - - -@pytest.mark.parametrize("spatial_dimension", [2, 3]) -class TestScoreNetworkCheck: - - @pytest.fixture(scope="class", autouse=True) - def set_random_seed(self): - torch.manual_seed(123) - - @pytest.fixture() - def base_score_network(self, spatial_dimension): - return ScoreNetwork( - ScoreNetworkParameters( - architecture="dummy", spatial_dimension=spatial_dimension - ) - ) - - @pytest.fixture() - def good_batch(self, spatial_dimension): - batch_size = 16 - relative_coordinates = torch.rand(batch_size, 8, spatial_dimension) - times = torch.rand(batch_size, 1) - noises = torch.rand(batch_size, 1) - unit_cell = torch.rand(batch_size, spatial_dimension, spatial_dimension) - return { - NOISY_RELATIVE_COORDINATES: relative_coordinates, - TIME: times, - NOISE: noises, - UNIT_CELL: unit_cell, - } - - @pytest.fixture() - def bad_batch(self, good_batch, problem): - - bad_batch_dict = dict(good_batch) - - match problem: - case "position_name": - bad_batch_dict["bad_position_name"] = bad_batch_dict[ - NOISY_RELATIVE_COORDINATES - ] - del bad_batch_dict[NOISY_RELATIVE_COORDINATES] - - case "position_shape": - shape = bad_batch_dict[NOISY_RELATIVE_COORDINATES].shape - bad_batch_dict[NOISY_RELATIVE_COORDINATES] = bad_batch_dict[ - NOISY_RELATIVE_COORDINATES - ].reshape(shape[0], shape[1] // 2, shape[2] * 2) - - case "position_range1": - bad_batch_dict[NOISY_RELATIVE_COORDINATES][0, 0, 0] = 1.01 - - case "position_range2": - bad_batch_dict[NOISY_RELATIVE_COORDINATES][1, 0, 0] = -0.01 - - case "time_name": - bad_batch_dict["bad_time_name"] = bad_batch_dict[TIME] - del bad_batch_dict[TIME] - - case "time_shape": - shape = bad_batch_dict[TIME].shape - bad_batch_dict[TIME] = bad_batch_dict[TIME].reshape( - shape[0] // 2, shape[1] * 2 - ) - - case "noise_name": - bad_batch_dict["bad_noise_name"] = bad_batch_dict[NOISE] - del bad_batch_dict[NOISE] - - case "noise_shape": - shape = bad_batch_dict[NOISE].shape - bad_batch_dict[NOISE] = bad_batch_dict[NOISE].reshape( - shape[0] // 2, shape[1] * 2 - ) - - case "time_range1": - bad_batch_dict[TIME][5, 0] = 2.00 - case "time_range2": - bad_batch_dict[TIME][0, 0] = -0.05 + Base class to run a battery of tests on a score network. To test a specific score network class, this base class + should be extended by implementing a 'score_network' fixture that instantiates the score network class of interest. + """ - case "cell_name": - bad_batch_dict["bad_unit_cell_key"] = bad_batch_dict[UNIT_CELL] - del bad_batch_dict[UNIT_CELL] + @staticmethod + def assert_parameters_are_the_same(parameters1: dataclass, parameters2: dataclass): + """Compare dataclasses explicitly as a workaround for the potential presence of numpy arrays.""" + assert type(parameters1) is type(parameters2) - case "cell_shape": - shape = bad_batch_dict[UNIT_CELL].shape - bad_batch_dict[UNIT_CELL] = bad_batch_dict[UNIT_CELL].reshape( - shape[0] // 2, shape[1] * 2, shape[2] - ) + for field in fields(parameters1): + value1 = getattr(parameters1, field.name) + value2 = getattr(parameters2, field.name) - return bad_batch_dict + assert type(value1) is type(value2) - def test_check_batch_good(self, base_score_network, good_batch): - base_score_network._check_batch(good_batch) + if type(value1) is np.ndarray: + np.testing.assert_array_equal(value1, value2) + else: + assert value1 == value2 - @pytest.mark.parametrize( - "problem", - [ - "position_name", - "time_name", - "position_shape", - "time_shape", - "noise_name", - "noise_shape", - "position_range1", - "position_range2", - "time_range1", - "time_range2", - "cell_name", - "cell_shape", - ], - ) - def test_check_batch_bad(self, base_score_network, bad_batch): - with pytest.raises(AssertionError): - base_score_network._check_batch(bad_batch) + @pytest.fixture(params=[2, 3, 16]) + def num_atom_types(self, request): + return request.param - -class BaseTestScoreNetwork: - """Base Test Score Network. - - Base class to run a battery of tests on a score network. To test a specific score network class, this base class - should be extended by implementing a 'score_network' fixture that instantiates the score network class of interest. - """ + @pytest.fixture + def unique_elements(self, num_atom_types): + return [generate_random_string(size=3) for _ in range(num_atom_types)] @pytest.fixture() def score_network_parameters(self, *args): @@ -172,24 +62,6 @@ def score_network_parameters(self, *args): "This fixture must be implemented in the derived class." ) - @pytest.fixture() - def score_network(self, *args): - raise NotImplementedError( - "This fixture must be implemented in the derived class." - ) - - @pytest.fixture(scope="class", autouse=True) - def set_random_seed(self): - torch.manual_seed(23423423) - - @pytest.fixture() - def batch_size(self): - return 16 - - @pytest.fixture() - def number_of_atoms(self): - return 8 - @pytest.fixture() def basis_vectors(self, batch_size, spatial_dimension): # orthogonal boxes with dimensions between 5 and 10. @@ -214,6 +86,24 @@ def relative_coordinates( ) return relative_coordinates + @pytest.fixture + def atom_types(self, batch_size, number_of_atoms, num_atom_types): + atom_types = torch.randint(0, num_atom_types + 1, (batch_size, number_of_atoms)) + return atom_types + + @pytest.fixture() + def expected_score_shape( + self, batch_size, number_of_atoms, spatial_dimension, num_atom_types + ): + first_dims = ( + batch_size, + number_of_atoms, + ) + return { + "X": first_dims + (spatial_dimension,), + "A": first_dims + (num_atom_types + 1,), + } + @pytest.fixture def cartesian_forces( self, batch_size, number_of_atoms, spatial_dimension, basis_vectors @@ -230,16 +120,22 @@ def times(self, batch_size): def noises(self, batch_size): return torch.rand(batch_size, 1) - @pytest.fixture() - def expected_score_shape(self, batch_size, number_of_atoms, spatial_dimension): - return batch_size, number_of_atoms, spatial_dimension - @pytest.fixture() def batch( - self, relative_coordinates, cartesian_forces, times, noises, basis_vectors + self, + relative_coordinates, + cartesian_forces, + times, + noises, + basis_vectors, + atom_types, ): return { - NOISY_RELATIVE_COORDINATES: relative_coordinates, + NOISY_AXL_COMPOSITION: AXL( + A=atom_types, + X=relative_coordinates, + L=torch.zeros_like(atom_types), # TODO + ), TIME: times, UNIT_CELL: basis_vectors, NOISE: noises, @@ -247,8 +143,8 @@ def batch( } @pytest.fixture() - def global_parameters_dictionary(self, spatial_dimension): - return dict(spatial_dimension=spatial_dimension, irrelevant=123) + def global_parameters_dictionary(self, spatial_dimension, unique_elements): + return dict(spatial_dimension=spatial_dimension, irrelevant=123, elements=unique_elements) @pytest.fixture() def score_network_dictionary( @@ -262,7 +158,8 @@ def score_network_dictionary( def test_output_shape(self, score_network, batch, expected_score_shape): scores = score_network(batch) - assert scores.shape == expected_score_shape + assert scores.X.shape == expected_score_shape["X"] + assert scores.A.shape == expected_score_shape["A"] def test_create_score_network_parameters( self, @@ -273,22 +170,43 @@ def test_create_score_network_parameters( computed_score_network_parameters = create_score_network_parameters( score_network_dictionary, global_parameters_dictionary ) - assert_parameters_are_the_same( + self.assert_parameters_are_the_same( computed_score_network_parameters, score_network_parameters ) + def test_consistent_output(self, batch, score_network): + # apply twice on the same input, get the same answer? + with torch.no_grad(): + output1 = score_network(batch) + output2 = score_network(batch) + + torch.testing.assert_close(output1, output2) + + def test_time_dependence(self, batch, score_network): + # Different times, different results? + new_time_batch = dict(batch) + new_time_batch[TIME] = torch.rand(batch[TIME].shape) + new_time_batch[NOISE] = torch.rand(batch[NOISE].shape) + with torch.no_grad(): + output1 = score_network(batch) + output2 = score_network(new_time_batch) + + with pytest.raises(AssertionError): + torch.testing.assert_close(output1, output2) + @pytest.mark.parametrize("spatial_dimension", [2, 3]) @pytest.mark.parametrize("n_hidden_dimensions", [1, 2, 3]) @pytest.mark.parametrize("hidden_dimensions_size", [8, 16]) @pytest.mark.parametrize("embedding_dimensions_size", [4, 12]) -class TestMLPScoreNetwork(BaseTestScoreNetwork): +class TestMLPScoreNetwork(BaseScoreNetworkGeneralTests): @pytest.fixture() def score_network_parameters( self, number_of_atoms, spatial_dimension, + num_atom_types, embedding_dimensions_size, n_hidden_dimensions, hidden_dimensions_size, @@ -296,7 +214,10 @@ def score_network_parameters( return MLPScoreNetworkParameters( spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, - embedding_dimensions_size=embedding_dimensions_size, + num_atom_types=num_atom_types, + noise_embedding_dimensions_size=embedding_dimensions_size, + time_embedding_dimensions_size=embedding_dimensions_size, + atom_type_embedding_dimensions_size=embedding_dimensions_size, n_hidden_dimensions=n_hidden_dimensions, hidden_dimensions_size=hidden_dimensions_size, ) @@ -306,10 +227,9 @@ def score_network(self, score_network_parameters): return MLPScoreNetwork(score_network_parameters) -@pytest.mark.parametrize("spatial_dimension", [3]) -@pytest.mark.parametrize("n_hidden_dimensions", [1, 2, 3]) -@pytest.mark.parametrize("hidden_dimensions_size", [8, 16]) -class TestMACEScoreNetworkMLPHead(BaseTestScoreNetwork): +@pytest.mark.parametrize("n_hidden_dimensions", [2]) +@pytest.mark.parametrize("hidden_dimensions_size", [8]) +class TestMACEScoreNetworkMLPHead(BaseScoreNetworkGeneralTests): @pytest.fixture() def prediction_head_parameters( @@ -324,11 +244,16 @@ def prediction_head_parameters( @pytest.fixture() def score_network_parameters( - self, number_of_atoms, spatial_dimension, prediction_head_parameters + self, + number_of_atoms, + spatial_dimension, + num_atom_types, + prediction_head_parameters, ): return MACEScoreNetworkParameters( spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, r_max=3.0, prediction_head_parameters=prediction_head_parameters, ) @@ -339,7 +264,7 @@ def score_network(self, score_network_parameters): @pytest.mark.parametrize("spatial_dimension", [3]) -class TestMACEScoreNetworkEquivariantHead(BaseTestScoreNetwork): +class TestMACEScoreNetworkEquivariantHead(BaseScoreNetworkGeneralTests): @pytest.fixture() def prediction_head_parameters(self, spatial_dimension): prediction_head_parameters = MaceEquivariantScorePredictionHeadParameters( @@ -349,11 +274,16 @@ def prediction_head_parameters(self, spatial_dimension): @pytest.fixture() def score_network_parameters( - self, number_of_atoms, spatial_dimension, prediction_head_parameters + self, + number_of_atoms, + spatial_dimension, + num_atom_types, + prediction_head_parameters, ): return MACEScoreNetworkParameters( spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, r_max=3.0, prediction_head_parameters=prediction_head_parameters, ) @@ -363,13 +293,15 @@ def score_network(self, score_network_parameters): return MACEScoreNetwork(score_network_parameters) -@pytest.mark.parametrize("spatial_dimension", [3]) -class TestDiffusionMACEScoreNetwork(BaseTestScoreNetwork): +class TestDiffusionMACEScoreNetwork(BaseScoreNetworkGeneralTests): @pytest.fixture() - def score_network_parameters(self, number_of_atoms, spatial_dimension): + def score_network_parameters( + self, number_of_atoms, num_atom_types, spatial_dimension + ): return DiffusionMACEScoreNetworkParameters( spatial_dimension=spatial_dimension, number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, r_max=3.0, num_bessel=4, num_polynomial_cutoff=3, @@ -385,68 +317,26 @@ def score_network(self, score_network_parameters): return DiffusionMACEScoreNetwork(score_network_parameters) -class TestEGNNScoreNetwork(BaseTestScoreNetwork): - - @pytest.fixture(scope="class", autouse=True) - def set_default_type_to_float64(self): - # Set the default type to float64 to make sure the tests are stringent. - torch.set_default_dtype(torch.float64) - yield - # this returns the default type to float32 at the end of all tests in this class in order - # to not affect other tests. - torch.set_default_dtype(torch.float32) - - @pytest.fixture() - def spatial_dimension(self): - return 3 - - @pytest.fixture() - def basis_vectors(self, batch_size, spatial_dimension): - # The basis vectors should form a cube in order to test the equivariance of the current implementation - # of the EGNN model. The octaheral point group only applies in this case! - acell = 5.5 - cubes = torch.stack( - [ - torch.diag(acell * torch.ones(spatial_dimension)) - for _ in range(batch_size) - ] - ) - return cubes +class TestEGNNScoreNetwork(BaseScoreNetworkGeneralTests): @pytest.fixture(params=[("fully_connected", None), ("radial_cutoff", 3.0)]) - def score_network_parameters(self, request): + def score_network_parameters(self, request, num_atom_types): edges, radial_cutoff = request.param - return EGNNScoreNetworkParameters(edges=edges, radial_cutoff=radial_cutoff) + return EGNNScoreNetworkParameters( + edges=edges, radial_cutoff=radial_cutoff, num_atom_types=num_atom_types + ) @pytest.fixture() def score_network(self, score_network_parameters): score_network = EGNNScoreNetwork(score_network_parameters) return score_network - @pytest.fixture() - def octahedral_point_group_symmetries(self): - permutations = [ - torch.diag(torch.ones(3))[[idx]] - for idx in itertools.permutations([0, 1, 2]) - ] - sign_changes = [ - torch.diag(torch.tensor(diag)) - for diag in itertools.product([-1.0, 1.0], repeat=3) - ] - - symmetries = [] - for permutation in permutations: - for sign_change in sign_changes: - symmetries.append(permutation @ sign_change) - - return symmetries - @pytest.mark.parametrize( "edges, radial_cutoff", [("fully_connected", 3.0), ("radial_cutoff", None)] ) - def test_score_network_parameters(self, edges, radial_cutoff): + def test_score_network_parameters(self, edges, radial_cutoff, num_atom_types): score_network_parameters = EGNNScoreNetworkParameters( - edges=edges, radial_cutoff=radial_cutoff + edges=edges, radial_cutoff=radial_cutoff, num_atom_types=num_atom_types ) with pytest.raises(AssertionError): # Check that the code crashes when inconsistent parameters are fed in. @@ -472,7 +362,7 @@ def test_create_block_diagonal_projection_matrices( @pytest.fixture() def flat_relative_coordinates(self, batch): - relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] + relative_coordinates = batch[NOISY_AXL_COMPOSITION].X flat_relative_coordinates = einops.rearrange( relative_coordinates, "batch natom space -> (batch natom) space" ) @@ -503,42 +393,3 @@ def test_get_euclidean_positions( torch.testing.assert_close( expected_euclidean_positions, computed_euclidean_positions ) - - @pytest.fixture() - def global_translations(self, batch_size, number_of_atoms, spatial_dimension): - translations = einops.repeat( - torch.rand(batch_size, spatial_dimension), - "batch spatial_dimension -> batch natoms spatial_dimension", - natoms=number_of_atoms, - ) - return translations - - def test_equivariance( - self, - score_network, - batch, - octahedral_point_group_symmetries, - global_translations, - ): - with torch.no_grad(): - normalized_scores = score_network(batch) - - for point_group_symmetry in octahedral_point_group_symmetries: - op = point_group_symmetry.transpose(1, 0) - modified_batch = deepcopy(batch) - relative_coordinates = modified_batch[NOISY_RELATIVE_COORDINATES] - - op_relative_coordinates = relative_coordinates @ op + global_translations - op_relative_coordinates = map_relative_coordinates_to_unit_cell( - op_relative_coordinates - ) - - modified_batch[NOISY_RELATIVE_COORDINATES] = op_relative_coordinates - with torch.no_grad(): - modified_normalized_scores = score_network(modified_batch) - - expected_modified_normalized_scores = normalized_scores @ op - - torch.testing.assert_close( - expected_modified_normalized_scores, modified_normalized_scores - ) diff --git a/tests/models/test_position_diffusion_lightning_model.py b/tests/models/test_axl_diffusion_lightning_model.py similarity index 72% rename from tests/models/test_position_diffusion_lightning_model.py rename to tests/models/test_axl_diffusion_lightning_model.py index 06314532..7120a8dd 100644 --- a/tests/models/test_position_diffusion_lightning_model.py +++ b/tests/models/test_axl_diffusion_lightning_model.py @@ -1,32 +1,56 @@ +from dataclasses import dataclass + +import numpy as np import pytest import torch from pytorch_lightning import LightningDataModule, Trainer from torch.utils.data import DataLoader, random_split -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ PredictorCorrectorSamplingParameters +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ + create_loss_parameters from diffusion_for_multi_scale_molecular_dynamics.metrics.sampling_metrics_parameters import \ SamplingMetricsParameters -from diffusion_for_multi_scale_molecular_dynamics.models.loss import \ - create_loss_parameters +from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import ( + AXLDiffusionLightningModel, AXLDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import \ OptimizerParameters -from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import ( - PositionDiffusionLightningModel, PositionDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.scheduler import ( CosineAnnealingLRSchedulerParameters, ReduceLROnPlateauSchedulerParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import \ MLPScoreNetworkParameters from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, RELATIVE_COORDINATES) -from diffusion_for_multi_scale_molecular_dynamics.samples.diffusion_sampling_parameters import \ + ATOM_TYPES, AXL_COMPOSITION, CARTESIAN_FORCES, RELATIVE_COORDINATES) +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle import ( + EnergyOracle, OracleParameters) +from diffusion_for_multi_scale_molecular_dynamics.oracle.energy_oracle_factory import ( + ENERGY_ORACLE_BY_NAME, ORACLE_PARAMETERS_BY_NAME) +from diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling_parameters import \ DiffusionSamplingParameters from diffusion_for_multi_scale_molecular_dynamics.score.wrapped_gaussian_score import \ get_sigma_normalized_score_brute_force from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ broadcast_batch_tensor_to_all_dimensions -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters +from tests.fake_data_utils import generate_random_string + + +@dataclass(kw_only=True) +class FakeOracleParameters(OracleParameters): + name = "test" + + +class FakeEnergyOracle(EnergyOracle): + + def _compute_one_configuration_energy( + self, + cartesian_positions: np.ndarray, + basis_vectors: np.ndarray, + atom_types: np.ndarray, + ) -> float: + return np.random.rand() class FakePositionsDataModule(LightningDataModule): @@ -36,20 +60,29 @@ def __init__( dataset_size: int = 33, number_of_atoms: int = 8, spatial_dimension: int = 2, + num_atom_types: int = 2, ): super().__init__() self.batch_size = batch_size all_relative_coordinates = torch.rand( dataset_size, number_of_atoms, spatial_dimension ) + potential_energies = torch.rand(dataset_size) + all_atom_types = torch.randint( + 0, num_atom_types, (dataset_size, number_of_atoms) + ) box = torch.rand(spatial_dimension) self.data = [ { - RELATIVE_COORDINATES: configuration, + RELATIVE_COORDINATES: coordinate_configuration, + ATOM_TYPES: atom_configuration, "box": box, - CARTESIAN_FORCES: torch.zeros_like(configuration), + CARTESIAN_FORCES: torch.zeros_like(coordinate_configuration), + "potential_energy": potential_energy } - for configuration in all_relative_coordinates + for coordinate_configuration, atom_configuration, potential_energy in zip( + all_relative_coordinates, all_atom_types, potential_energies + ) ] self.train_data, self.val_data, self.test_data = None, None, None @@ -72,7 +105,7 @@ def test_dataloader(self): class TestPositionDiffusionLightningModel: @pytest.fixture(scope="class", autouse=True) def set_random_seed(self): - torch.manual_seed(2345234) + torch.manual_seed(234523) @pytest.fixture() def batch_size(self): @@ -82,6 +115,14 @@ def batch_size(self): def number_of_atoms(self): return 8 + @pytest.fixture() + def num_atom_types(self): + return 4 + + @pytest.fixture + def unique_elements(self, num_atom_types): + return [generate_random_string(size=8) for _ in range(num_atom_types)] + @pytest.fixture() def unit_cell_size(self): return 10.1 @@ -112,7 +153,7 @@ def scheduler_parameters(self, request): @pytest.fixture(params=["mse", "weighted_mse"]) def loss_parameters(self, request): - model_dict = dict(loss=dict(algorithm=request.param)) + model_dict = dict(loss=dict(coordinates_algorithm=request.param)) return create_loss_parameters(model_dictionary=model_dict) @pytest.fixture() @@ -125,13 +166,19 @@ def cell_dimensions(self, unit_cell_size, spatial_dimension): @pytest.fixture() def sampling_parameters( - self, number_of_atoms, spatial_dimension, number_of_samples, cell_dimensions + self, + number_of_atoms, + spatial_dimension, + number_of_samples, + cell_dimensions, + num_atom_types, ): sampling_parameters = PredictorCorrectorSamplingParameters( number_of_atoms=number_of_atoms, spatial_dimension=spatial_dimension, number_of_samples=number_of_samples, cell_dimensions=cell_dimensions, + num_atom_types=num_atom_types, ) return sampling_parameters @@ -139,8 +186,9 @@ def sampling_parameters( def diffusion_sampling_parameters(self, sampling_parameters): noise_parameters = NoiseParameters(total_time_steps=5) metrics_parameters = SamplingMetricsParameters( - structure_factor_max_distance=1.0 - ) + structure_factor_max_distance=1.0, + compute_energies=True, + compute_structure_factor=False) diffusion_sampling_parameters = DiffusionSamplingParameters( sampling_parameters=sampling_parameters, noise_parameters=noise_parameters, @@ -152,6 +200,8 @@ def diffusion_sampling_parameters(self, sampling_parameters): def hyper_params( self, number_of_atoms, + unique_elements, + num_atom_types, spatial_dimension, optimizer_parameters, scheduler_parameters, @@ -161,21 +211,27 @@ def hyper_params( ): score_network_parameters = MLPScoreNetworkParameters( number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, n_hidden_dimensions=3, - embedding_dimensions_size=8, + noise_embedding_dimensions_size=8, + time_embedding_dimensions_size=8, + atom_type_embedding_dimensions_size=8, hidden_dimensions_size=8, spatial_dimension=spatial_dimension, ) noise_parameters = NoiseParameters(total_time_steps=15) - hyper_params = PositionDiffusionParameters( + oracle_parameters = OracleParameters(name='test', elements=unique_elements) + + hyper_params = AXLDiffusionParameters( score_network_parameters=score_network_parameters, optimizer_parameters=optimizer_parameters, scheduler_parameters=scheduler_parameters, noise_parameters=noise_parameters, loss_parameters=loss_parameters, diffusion_sampling_parameters=diffusion_sampling_parameters, + oracle_parameters=oracle_parameters ) return hyper_params @@ -196,11 +252,14 @@ def noisy_relative_coordinates( return noisy_relative_coordinates @pytest.fixture() - def fake_datamodule(self, batch_size, number_of_atoms, spatial_dimension): + def fake_datamodule( + self, batch_size, number_of_atoms, spatial_dimension, num_atom_types + ): data_module = FakePositionsDataModule( batch_size=batch_size, number_of_atoms=number_of_atoms, spatial_dimension=spatial_dimension, + num_atom_types=num_atom_types, ) return data_module @@ -218,8 +277,14 @@ def sigmas(self, batch_size, number_of_atoms, spatial_dimension): return sigmas @pytest.fixture() - def lightning_model(self, hyper_params): - lightning_model = PositionDiffusionLightningModel(hyper_params) + def lightning_model(self, mocker, hyper_params): + fake_oracle_parameters_by_name = dict(test=FakeOracleParameters) + fake_energy_oracle_by_name = dict(test=FakeEnergyOracle) + + mocker.patch.dict(ORACLE_PARAMETERS_BY_NAME, fake_oracle_parameters_by_name) + mocker.patch.dict(ENERGY_ORACLE_BY_NAME, fake_energy_oracle_by_name) + + lightning_model = AXLDiffusionLightningModel(hyper_params) return lightning_model @pytest.fixture() @@ -270,7 +335,7 @@ def test_get_target_normalized_score( unit_cell_sample, ): computed_target_normalized_scores = ( - lightning_model._get_target_normalized_score( + lightning_model._get_coordinates_target_normalized_score( noisy_relative_coordinates, real_relative_coordinates, sigmas ) ) @@ -291,8 +356,12 @@ def test_generate_sample( self, lightning_model, number_of_samples, number_of_atoms, spatial_dimension ): samples_batch = lightning_model.generate_samples() - assert samples_batch[RELATIVE_COORDINATES].shape == ( + assert samples_batch[AXL_COMPOSITION].X.shape == ( number_of_samples, number_of_atoms, spatial_dimension, ) + assert samples_batch[AXL_COMPOSITION].A.shape == ( + number_of_samples, + number_of_atoms, + ) diff --git a/tests/models/test_diffusion_mace.py b/tests/models/test_diffusion_mace.py deleted file mode 100644 index 4f43f356..00000000 --- a/tests/models/test_diffusion_mace.py +++ /dev/null @@ -1,377 +0,0 @@ -import pytest -import torch -from e3nn import o3 -from mace.modules import gate_dict, interaction_classes - -from diffusion_for_multi_scale_molecular_dynamics.models.diffusion_mace import ( - DiffusionMACE, LinearVectorReadoutBlock, input_to_diffusion_mace) -from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_FORCES, NOISE, NOISY_CARTESIAN_POSITIONS, - NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( - get_positions_from_coordinates, get_reciprocal_basis_vectors, - get_relative_coordinates_from_cartesian_positions, - map_relative_coordinates_to_unit_cell) - - -def test_linear_vector_readout_block(): - - batch_size = 10 - vector_output_dimension = 3 - irreps_in = o3.Irreps("16x0e + 12x1o + 14x2e") - - vector_readout = LinearVectorReadoutBlock(irreps_in) - - input_features = irreps_in.randn(batch_size, -1) - - output_features = vector_readout(input_features) - - assert output_features.shape == (batch_size, vector_output_dimension) - - -class TestDiffusionMace: - @pytest.fixture(scope="class", autouse=True) - def set_default_type_to_float64(self): - torch.set_default_dtype(torch.float64) - yield - # this returns the default type to float32 at the end of all tests in this class in order - # to not affect other tests. - torch.set_default_dtype(torch.float32) - - @pytest.fixture(scope="class", autouse=True) - def set_seed(self): - """Set the random seed.""" - torch.manual_seed(234233) - - @pytest.fixture(scope="class") - def batch_size(self): - return 4 - - @pytest.fixture(scope="class") - def number_of_atoms(self): - return 8 - - @pytest.fixture(scope="class") - def spatial_dimension(self): - return 3 - - @pytest.fixture(scope="class") - def basis_vectors(self, batch_size, spatial_dimension): - # orthogonal boxes with dimensions between 5 and 10. - orthogonal_boxes = torch.stack( - [ - torch.diag(5.0 + 5.0 * torch.rand(spatial_dimension)) - for _ in range(batch_size) - ] - ) - # add a bit of noise to make the vectors not quite orthogonal - basis_vectors = orthogonal_boxes + 0.1 * torch.randn( - batch_size, spatial_dimension, spatial_dimension - ) - return basis_vectors - - @pytest.fixture(scope="class") - def reciprocal_basis_vectors(self, basis_vectors): - return get_reciprocal_basis_vectors(basis_vectors) - - @pytest.fixture(scope="class") - def relative_coordinates(self, batch_size, number_of_atoms, spatial_dimension): - relative_coordinates = torch.rand( - batch_size, number_of_atoms, spatial_dimension - ) - return relative_coordinates - - @pytest.fixture(scope="class") - def cartesian_positions(self, relative_coordinates, basis_vectors): - return get_positions_from_coordinates(relative_coordinates, basis_vectors) - - @pytest.fixture(scope="class") - def times(self, batch_size): - return torch.rand(batch_size, 1) - - @pytest.fixture(scope="class") - def noises(self, batch_size): - return 0.5 * torch.rand(batch_size, 1) - - @pytest.fixture(scope="class") - def forces(self, batch_size, spatial_dimension): - return 0.5 * torch.rand(batch_size, spatial_dimension) - - @pytest.fixture(scope="class") - def batch( - self, - relative_coordinates, - cartesian_positions, - basis_vectors, - times, - noises, - forces, - ): - batch = { - NOISY_RELATIVE_COORDINATES: relative_coordinates, - NOISY_CARTESIAN_POSITIONS: cartesian_positions, - TIME: times, - NOISE: noises, - UNIT_CELL: basis_vectors, - CARTESIAN_FORCES: forces, - } - return batch - - @pytest.fixture(scope="class") - def cartesian_rotations(self, batch_size): - return o3.rand_matrix(batch_size) - - @pytest.fixture(scope="class") - def permutations(self, batch_size, number_of_atoms): - return torch.stack([torch.randperm(number_of_atoms) for _ in range(batch_size)]) - - @pytest.fixture(scope="class") - def cartesian_translations( - self, batch_size, number_of_atoms, spatial_dimension, basis_vectors - ): - batch_relative_coordinates_translations = torch.rand( - batch_size, spatial_dimension - ) - - batch_cartesian_translations = [] - for t, cell in zip(batch_relative_coordinates_translations, basis_vectors): - batch_cartesian_translations.append(t @ cell) - - batch_cartesian_translations = torch.stack(batch_cartesian_translations) - - cartesian_translations = torch.repeat_interleave( - batch_cartesian_translations.unsqueeze(1), number_of_atoms, dim=1 - ) - return cartesian_translations - - @pytest.fixture() - def r_max(self): - return 3.0 - - @pytest.fixture() - def hyperparameters(self, r_max): - - hps = dict( - r_max=r_max, - num_bessel=8, - num_polynomial_cutoff=5, - num_edge_hidden_layers=0, - edge_hidden_irreps=o3.Irreps("8x0e"), - max_ell=2, - num_elements=1, - atomic_numbers=[14], - interaction_cls=interaction_classes["RealAgnosticResidualInteractionBlock"], - interaction_cls_first=interaction_classes["RealAgnosticInteractionBlock"], - num_interactions=2, - hidden_irreps=o3.Irreps("8x0e + 8x1o + 8x2e"), - mlp_irreps=o3.Irreps("8x0e"), - number_of_mlp_layers=2, - avg_num_neighbors=1, - correlation=2, - gate=gate_dict["silu"], - radial_MLP=[8, 8, 8], - radial_type="bessel", - ) - return hps - - @pytest.fixture() - def diffusion_mace(self, hyperparameters): - diffusion_mace = DiffusionMACE(**hyperparameters) - diffusion_mace.eval() - return diffusion_mace - - @pytest.fixture() - def graph_input(self, batch, r_max): - return input_to_diffusion_mace(batch, radial_cutoff=r_max) - - @pytest.fixture() - def cartesian_scores( - self, - graph_input, - diffusion_mace, - batch_size, - number_of_atoms, - spatial_dimension, - ): - flat_cartesian_scores = diffusion_mace(graph_input) - return flat_cartesian_scores.reshape( - batch_size, number_of_atoms, spatial_dimension - ) - - @pytest.fixture() - def translated_graph_input( - self, - batch, - r_max, - basis_vectors, - reciprocal_basis_vectors, - cartesian_translations, - ): - - translated_batch = dict(batch) - - original_cartesian_positions = translated_batch[NOISY_CARTESIAN_POSITIONS] - translated_cartesian_positions = ( - original_cartesian_positions + cartesian_translations - ) - - rel_coords = get_relative_coordinates_from_cartesian_positions( - translated_cartesian_positions, reciprocal_basis_vectors - ) - new_relative_coordinates = map_relative_coordinates_to_unit_cell(rel_coords) - - new_cartesian_positions = get_positions_from_coordinates( - new_relative_coordinates, basis_vectors - ) - - translated_batch[NOISY_CARTESIAN_POSITIONS] = new_cartesian_positions - translated_batch[NOISY_RELATIVE_COORDINATES] = new_relative_coordinates - - return input_to_diffusion_mace(translated_batch, radial_cutoff=r_max) - - @pytest.fixture() - def translated_cartesian_scores( - self, - diffusion_mace, - batch_size, - number_of_atoms, - spatial_dimension, - basis_vectors, - translated_graph_input, - ): - flat_translated_cartesian_scores = diffusion_mace(translated_graph_input) - return flat_translated_cartesian_scores.reshape( - batch_size, number_of_atoms, spatial_dimension - ) - - @pytest.fixture() - def rotated_graph_input( - self, batch, r_max, basis_vectors, reciprocal_basis_vectors, cartesian_rotations - ): - rotated_batch = dict(batch) - - original_cartesian_positions = rotated_batch[NOISY_CARTESIAN_POSITIONS] - original_basis_vectors = rotated_batch[UNIT_CELL] - - rotated_cartesian_positions = torch.matmul( - original_cartesian_positions, cartesian_rotations.transpose(2, 1) - ) - - rotated_basis_vectors = torch.matmul( - original_basis_vectors, cartesian_rotations.transpose(2, 1) - ) - rotated_reciprocal_basis_vectors = get_reciprocal_basis_vectors( - rotated_basis_vectors - ) - - rel_coords = get_relative_coordinates_from_cartesian_positions( - rotated_cartesian_positions, rotated_reciprocal_basis_vectors - ) - new_relative_coordinates = map_relative_coordinates_to_unit_cell(rel_coords) - new_cartesian_positions = get_positions_from_coordinates( - new_relative_coordinates, rotated_basis_vectors - ) - - rotated_batch[NOISY_CARTESIAN_POSITIONS] = new_cartesian_positions - rotated_batch[NOISY_RELATIVE_COORDINATES] = new_relative_coordinates - rotated_batch[UNIT_CELL] = rotated_basis_vectors - - return input_to_diffusion_mace(rotated_batch, radial_cutoff=r_max) - - @pytest.fixture() - def rotated_cartesian_scores( - self, - diffusion_mace, - batch_size, - number_of_atoms, - spatial_dimension, - rotated_graph_input, - ): - flat_rotated_cartesian_scores = diffusion_mace(rotated_graph_input) - return flat_rotated_cartesian_scores.reshape( - batch_size, number_of_atoms, spatial_dimension - ) - - @pytest.fixture() - def permuted_graph_input(self, batch_size, batch, r_max, permutations): - permuted_batch = dict(batch) - - for position_key in [NOISY_CARTESIAN_POSITIONS, NOISY_RELATIVE_COORDINATES]: - pos = permuted_batch[position_key] - permuted_pos = torch.stack( - [ - pos[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - permuted_batch[position_key] = permuted_pos - - return input_to_diffusion_mace(permuted_batch, radial_cutoff=r_max) - - @pytest.fixture() - def permuted_cartesian_scores( - self, - diffusion_mace, - batch_size, - number_of_atoms, - spatial_dimension, - permuted_graph_input, - ): - flat_permuted_cartesian_scores = diffusion_mace(permuted_graph_input) - return flat_permuted_cartesian_scores.reshape( - batch_size, number_of_atoms, spatial_dimension - ) - - def test_translation_invariance( - self, cartesian_scores, translated_cartesian_scores - ): - torch.testing.assert_close(translated_cartesian_scores, cartesian_scores) - - def test_rotation_equivariance( - self, cartesian_scores, rotated_cartesian_scores, cartesian_rotations - ): - vector_irreps = o3.Irreps("1o") - d_matrices = vector_irreps.D_from_matrix(cartesian_rotations) - - expected_rotated_cartesian_scores = torch.matmul( - cartesian_scores, d_matrices.transpose(2, 1) - ) - torch.testing.assert_close( - expected_rotated_cartesian_scores, rotated_cartesian_scores - ) - - def test_permutation_equivariance( - self, cartesian_scores, permuted_cartesian_scores, batch_size, permutations - ): - - expected_permuted_cartesian_scores = torch.stack( - [ - cartesian_scores[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - - torch.testing.assert_close( - expected_permuted_cartesian_scores, permuted_cartesian_scores - ) - - def test_time_dependence(self, batch, r_max, diffusion_mace): - - graph_input = input_to_diffusion_mace(batch, radial_cutoff=r_max) - flat_cartesian_scores1 = diffusion_mace(graph_input) - flat_cartesian_scores2 = diffusion_mace(graph_input) - - # apply twice on the same input, get the same answer? - torch.testing.assert_close(flat_cartesian_scores1, flat_cartesian_scores2) - - new_time_batch = dict(batch) - new_time_batch[TIME] = torch.rand(batch[TIME].shape) - new_time_batch[NOISE] = torch.rand(batch[NOISE].shape) - new_graph_input = input_to_diffusion_mace(new_time_batch, radial_cutoff=r_max) - new_flat_cartesian_scores = diffusion_mace(new_graph_input) - - # Different times, different results? - with pytest.raises(AssertionError): - torch.testing.assert_close( - new_flat_cartesian_scores, flat_cartesian_scores1 - ) diff --git a/tests/models/test_egcl.py b/tests/models/test_egcl.py new file mode 100644 index 00000000..642ebe23 --- /dev/null +++ b/tests/models/test_egcl.py @@ -0,0 +1,59 @@ +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.models.egnn import E_GCL + + +class TestEGCL: + + @pytest.fixture(scope="class") + def spatial_dimension(self): + return 3 + + @pytest.fixture(scope="class") + def node_features_size(self): + return 5 + + @pytest.fixture(scope="class") + def egcl_hyperparameters(self, node_features_size): + hps = dict( + input_size=node_features_size, + message_n_hidden_dimensions=1, + message_hidden_dimensions_size=4, + node_n_hidden_dimensions=1, + node_hidden_dimensions_size=4, + coordinate_n_hidden_dimensions=1, + coordinate_hidden_dimensions_size=4, + output_size=node_features_size) + return hps + + @pytest.fixture() + def egcl(self, egcl_hyperparameters): + model = E_GCL(**egcl_hyperparameters) + model.eval() + return model + + @pytest.fixture(scope="class") + def single_edge(self): + return torch.Tensor([1, 0]).unsqueeze(0).long() + + @pytest.fixture(scope="class") + def fixed_distance(self): + return 0.4 + + @pytest.fixture(scope="class") + def simple_pair_coord(self, fixed_distance, spatial_dimension): + coord = torch.zeros(2, spatial_dimension) + coord[1, 0] = fixed_distance + return coord + + def test_egcl_coord2radial( + self, single_edge, fixed_distance, simple_pair_coord, egcl + ): + computed_distance_squared, computed_displacement = egcl.coord2radial( + single_edge, simple_pair_coord + ) + torch.testing.assert_close(computed_distance_squared.item(), fixed_distance**2) + torch.testing.assert_close( + computed_displacement, simple_pair_coord[1, :].unsqueeze(0) + ) diff --git a/tests/models/test_egnn.py b/tests/models/test_egnn.py deleted file mode 100644 index be93f0d4..00000000 --- a/tests/models/test_egnn.py +++ /dev/null @@ -1,262 +0,0 @@ -import math -from copy import copy - -import pytest -import torch - -from diffusion_for_multi_scale_molecular_dynamics.models.egnn import (E_GCL, - EGNN) - - -class TestEGNN: - @pytest.fixture(scope="class", autouse=True) - def set_default_type_to_float64(self): - """Set the random seed.""" - torch.set_default_dtype(torch.float64) - yield - # this returns the default type to float32 at the end of all tests in this class in order - # to not affect other tests. - torch.set_default_dtype(torch.float32) - - @pytest.fixture(scope="class", autouse=True) - def set_seed(self): - """Set the random seed.""" - torch.manual_seed(234233) - - @pytest.fixture(scope="class") - def batch_size(self): - return 4 - - @pytest.fixture(scope="class") - def number_of_atoms(self): - return 8 - - @pytest.fixture(scope="class") - def spatial_dimension(self): - return 3 - - @pytest.fixture(scope="class") - def relative_coordinates(self, batch_size, number_of_atoms, spatial_dimension): - relative_coordinates = torch.rand( - batch_size, number_of_atoms, spatial_dimension - ) - return relative_coordinates - - @pytest.fixture(scope="class") - def node_features_size(self): - return 5 - - @pytest.fixture(scope="class") - def node_features(self, batch_size, number_of_atoms, node_features_size): - node_features = torch.randn(batch_size, number_of_atoms, node_features_size) - return node_features - - @pytest.fixture(scope="class") - def num_edges(self, number_of_atoms): - return math.floor(number_of_atoms * 1.5) - - @pytest.fixture(scope="class") - def edges(self, batch_size, number_of_atoms, num_edges): - all_edges = [] - for b in range(batch_size): - batch_edges = torch.Tensor( - [(i, j) for i in range(number_of_atoms) for j in range(number_of_atoms)] - ) - # select num_edges randomly - indices = torch.randperm(len(batch_edges)) - shuffled_edges = batch_edges[indices] + b * number_of_atoms - all_edges.append(shuffled_edges[:num_edges]) - return torch.cat(all_edges, dim=0).long() - - @pytest.fixture(scope="class") - def batch( - self, relative_coordinates, node_features, edges, batch_size, number_of_atoms - ): - batch = { - "coord": relative_coordinates.view(batch_size * number_of_atoms, -1), - "node_features": node_features.view(batch_size * number_of_atoms, -1), - "edges": edges, - } - return batch - - @pytest.fixture(scope="class") - def generic_hyperparameters(self, node_features_size): - hps = dict( - input_size=node_features_size, - message_n_hidden_dimensions=1, - message_hidden_dimensions_size=4, - node_n_hidden_dimensions=1, - node_hidden_dimensions_size=4, - coordinate_n_hidden_dimensions=1, - coordinate_hidden_dimensions_size=4, - ) - return hps - - @pytest.fixture() - def egnn_hyperparameters(self, generic_hyperparameters): - hps = copy(generic_hyperparameters) - hps["n_layers"] = 2 - return hps - - @pytest.fixture() - def egcl_hyperparameters(self, generic_hyperparameters, node_features_size): - hps = copy(generic_hyperparameters) - hps["output_size"] = node_features_size - return hps - - @pytest.fixture() - def egcl(self, egcl_hyperparameters): - model = E_GCL(**egcl_hyperparameters) - model.eval() - return model - - @pytest.fixture() - def egnn(self, egnn_hyperparameters): - model = EGNN(**egnn_hyperparameters) - model.eval() - return model - - @pytest.fixture() - def egnn_scores(self, batch, egnn, batch_size, number_of_atoms, spatial_dimension): - egnn_scores = egnn(batch["node_features"], batch["edges"], batch["coord"]) - return egnn_scores.reshape(batch_size, number_of_atoms, spatial_dimension) - - @pytest.fixture() - def egcl_scores(self, batch, egcl, batch_size, number_of_atoms): - egcl_h, egcl_x = egcl(batch["node_features"], batch["edges"], batch["coord"]) - return egcl_h.reshape(batch_size, number_of_atoms, -1), egcl_x.reshape( - batch_size, number_of_atoms, -1 - ) - - @pytest.fixture(scope="class") - def permutations(self, batch_size, number_of_atoms): - return torch.stack([torch.randperm(number_of_atoms) for _ in range(batch_size)]) - - @pytest.fixture(scope="class") - def permuted_coordinates(self, batch_size, number_of_atoms, batch, permutations): - permuted_batch = batch - pos = permuted_batch["coord"].view(batch_size, number_of_atoms, -1) - permuted_pos = torch.stack( - [ - pos[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - return permuted_pos.view(batch_size * number_of_atoms, -1) - - @pytest.fixture(scope="class") - def permuted_node_features(self, batch_size, number_of_atoms, batch, permutations): - permuted_batch = batch - - h = permuted_batch["node_features"].view(batch_size, number_of_atoms, -1) - permuted_h = torch.stack( - [ - h[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - return permuted_h.view(batch_size * number_of_atoms, -1) - - @pytest.fixture(scope="class") - def permuted_edges(self, batch_size, batch, permutations, number_of_atoms): - edges = batch["edges"] - permuted_edges = edges.clone() - for b in range(batch_size): - for atom in range(number_of_atoms): - new_atom_idx = permutations[b, atom] + b * number_of_atoms - permuted_edges[edges == new_atom_idx] = atom + b * number_of_atoms - return permuted_edges.long() - - @pytest.fixture() - def permuted_batch( - self, permuted_coordinates, permuted_edges, permuted_node_features - ): - permuted_batch = { - "coord": permuted_coordinates, - "node_features": permuted_node_features, - "edges": permuted_edges, - } - return permuted_batch - - @pytest.fixture() - def permuted_egnn_scores( - self, permuted_batch, egnn, batch_size, number_of_atoms, spatial_dimension - ): - egnn_scores = egnn( - permuted_batch["node_features"], - permuted_batch["edges"], - permuted_batch["coord"], - ) - return egnn_scores.reshape(batch_size, number_of_atoms, spatial_dimension) - - @pytest.fixture() - def permuted_egcl_scores(self, permuted_batch, egcl, batch_size, number_of_atoms): - egcl_h, egcl_x = egcl( - permuted_batch["node_features"], - permuted_batch["edges"], - permuted_batch["coord"], - ) - return egcl_h.reshape(batch_size, number_of_atoms, -1), egcl_x.reshape( - batch_size, number_of_atoms, -1 - ) - - def test_egcl_permutation_equivariance( - self, egcl_scores, permuted_egcl_scores, batch_size, permutations - ): - permuted_egcl_h, permuted_egcl_x = permuted_egcl_scores - egcl_h, egcl_x = egcl_scores - - expected_permuted_h = torch.stack( - [ - egcl_h[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - - torch.testing.assert_close(expected_permuted_h, permuted_egcl_h) - - expected_permuted_x = torch.stack( - [ - egcl_x[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - - torch.testing.assert_close(expected_permuted_x, permuted_egcl_x) - - def test_egnn_permutation_equivariance( - self, egnn_scores, permuted_egnn_scores, batch_size, permutations - ): - expected_permuted_scores = torch.stack( - [ - egnn_scores[batch_idx, permutations[batch_idx], :] - for batch_idx in range(batch_size) - ] - ) - - torch.testing.assert_close(expected_permuted_scores, permuted_egnn_scores) - - @pytest.fixture(scope="class") - def single_edge(self): - return torch.Tensor([1, 0]).unsqueeze(0).long() - - @pytest.fixture(scope="class") - def fixed_distance(self): - return 0.4 - - @pytest.fixture(scope="class") - def simple_pair_coord(self, fixed_distance, spatial_dimension): - coord = torch.zeros(2, spatial_dimension) - coord[1, 0] = fixed_distance - return coord - - def test_egcl_coord2radial( - self, single_edge, fixed_distance, simple_pair_coord, egcl - ): - computed_distance_squared, computed_displacement = egcl.coord2radial( - single_edge, simple_pair_coord - ) - torch.testing.assert_close(computed_distance_squared.item(), fixed_distance**2) - torch.testing.assert_close( - computed_displacement, simple_pair_coord[1, :].unsqueeze(0) - ) diff --git a/tests/models/test_score_fokker_planck_error.py b/tests/models/test_score_fokker_planck_error.py deleted file mode 100644 index 095a7c2a..00000000 --- a/tests/models/test_score_fokker_planck_error.py +++ /dev/null @@ -1,286 +0,0 @@ -from typing import Callable - -import einops -import pytest -import torch - -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.egnn_score_network import \ - EGNNScoreNetworkParameters -from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.score_network_factory import \ - create_score_network -from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - NOISE, NOISY_RELATIVE_COORDINATES, TIME, UNIT_CELL) -from diffusion_for_multi_scale_molecular_dynamics.samplers.exploding_variance import \ - ExplodingVariance -from src.diffusion_for_multi_scale_molecular_dynamics.models.normalized_score_fokker_planck_error import \ - NormalizedScoreFokkerPlanckError -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ - NoiseParameters - - -def get_finite_difference_time_derivative( - tensor_function: Callable, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - epsilon: float = 1.0e-8, -): - """Compute the finite difference of a tensor function with respect to time.""" - h = epsilon * torch.ones_like(times) - f_hp = tensor_function(relative_coordinates, times + h, unit_cells) - f_hm = tensor_function(relative_coordinates, times - h, unit_cells) - - batch_size, natoms, spatial_dimension = relative_coordinates.shape - denominator = einops.repeat(2 * h, "b 1 -> b n s", n=natoms, s=spatial_dimension) - time_derivative = (f_hp - f_hm) / denominator - return time_derivative - - -def get_finite_difference_gradient( - scalar_function: Callable, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - epsilon: float = 1.0e-6, -): - """Compute the gradient of a scalar function using finite difference.""" - batch_size, natoms, spatial_dimension = relative_coordinates.shape - - x = relative_coordinates - - gradient = torch.zeros_like(relative_coordinates) - for atom_idx in range(natoms): - for space_idx in range(spatial_dimension): - dx = torch.zeros_like(relative_coordinates) - dx[:, atom_idx, space_idx] = epsilon - - f_p = scalar_function(x + dx, times, unit_cells) - f_m = scalar_function(x - dx, times, unit_cells) - - gradient[:, atom_idx, space_idx] = (f_p - f_m) / (2.0 * epsilon) - - return gradient - - -def get_finite_difference_divergence( - tensor_function: Callable, - relative_coordinates: torch.Tensor, - times: torch.Tensor, - unit_cells: torch.Tensor, - epsilon: float = 1.0e-8, -): - """Compute the finite difference divergence of a tensor function.""" - batch_size, natoms, spatial_dimension = relative_coordinates.shape - - x = relative_coordinates - finite_difference_divergence = torch.zeros(batch_size) - - for atom_idx in range(natoms): - for space_idx in range(spatial_dimension): - dx = torch.zeros_like(relative_coordinates) - dx[:, atom_idx, space_idx] = epsilon - vec_hp = tensor_function(x + dx, times, unit_cells) - vec_hm = tensor_function(x - dx, times, unit_cells) - div_contribution = ( - vec_hp[:, atom_idx, space_idx] - vec_hm[:, atom_idx, space_idx] - ) / (2.0 * epsilon) - finite_difference_divergence += div_contribution - - return finite_difference_divergence - - -class TestScoreFokkerPlanckError: - @pytest.fixture(scope="class", autouse=True) - def set_default_type_to_float64(self): - torch.set_default_dtype(torch.float64) - yield - # this returns the default type to float32 at the end of all tests in this class in order - # to not affect other tests. - torch.set_default_dtype(torch.float32) - - @pytest.fixture(scope="class", autouse=True) - def set_random_seed(self): - torch.manual_seed(23423423) - - @pytest.fixture - def batch_size(self): - return 5 - - @pytest.fixture - def spatial_dimension(self): - return 3 - - @pytest.fixture(params=[True, False]) - def inference_mode(self, request): - return request.param - - @pytest.fixture(params=[2, 4]) - def number_of_atoms(self, request): - return request.param - - @pytest.fixture - def relative_coordinates(self, batch_size, number_of_atoms, spatial_dimension): - return torch.rand(batch_size, number_of_atoms, spatial_dimension) - - @pytest.fixture - def times(self, batch_size): - times = torch.rand(batch_size, 1) - return times - - @pytest.fixture - def unit_cells(self, batch_size, spatial_dimension): - return torch.rand(batch_size, spatial_dimension, spatial_dimension) - - @pytest.fixture() - def score_network_parameters(self, number_of_atoms, spatial_dimension): - # Let's test with a "real" model to identify any snag in the diff engine. - score_network_parameters = EGNNScoreNetworkParameters( - spatial_dimension=spatial_dimension, - message_n_hidden_dimensions=2, - node_n_hidden_dimensions=2, - coordinate_n_hidden_dimensions=2, - n_layers=2, - ) - return score_network_parameters - - @pytest.fixture() - def noise_parameters(self): - return NoiseParameters(total_time_steps=10, sigma_min=0.1, sigma_max=0.5) - - @pytest.fixture() - def batch(self, relative_coordinates, times, unit_cells, noise_parameters): - return { - NOISY_RELATIVE_COORDINATES: relative_coordinates, - TIME: times, - NOISE: ExplodingVariance(noise_parameters).get_sigma(times), - UNIT_CELL: unit_cells, - } - - @pytest.fixture() - def sigma_normalized_score_network(self, score_network_parameters, inference_mode): - score_network = create_score_network(score_network_parameters) - if inference_mode: - for parameter in score_network.parameters(): - parameter.requires_grad_(False) - - return score_network - - @pytest.fixture() - def expected_normalized_scores(self, sigma_normalized_score_network, batch): - return sigma_normalized_score_network(batch) - - @pytest.fixture - def normalized_score_fokker_planck_error( - self, sigma_normalized_score_network, noise_parameters - ): - return NormalizedScoreFokkerPlanckError( - sigma_normalized_score_network, noise_parameters - ) - - def test_normalized_scores_function( - self, expected_normalized_scores, normalized_score_fokker_planck_error, batch - ): - computed_normalized_scores = ( - normalized_score_fokker_planck_error._normalized_scores_function( - relative_coordinates=batch[NOISY_RELATIVE_COORDINATES], - times=batch[TIME], - unit_cells=batch[UNIT_CELL], - ) - ) - - torch.testing.assert_allclose( - expected_normalized_scores, computed_normalized_scores - ) - - def test_normalized_scores_square_norm_function( - self, expected_normalized_scores, normalized_score_fokker_planck_error, batch - ): - flat_scores = einops.rearrange( - expected_normalized_scores, "batch natoms space -> batch (natoms space)" - ) - - expected_squared_norms = (flat_scores**2).sum(dim=1) - - computed_squared_norms = normalized_score_fokker_planck_error._normalized_scores_square_norm_function( - relative_coordinates=batch[NOISY_RELATIVE_COORDINATES], - times=batch[TIME], - unit_cells=batch[UNIT_CELL], - ) - - torch.testing.assert_allclose(expected_squared_norms, computed_squared_norms) - - def test_get_dn_dt( - self, - normalized_score_fokker_planck_error, - relative_coordinates, - times, - unit_cells, - ): - finite_difference_dn_dt = get_finite_difference_time_derivative( - normalized_score_fokker_planck_error._normalized_scores_function, - relative_coordinates, - times, - unit_cells, - ) - - computed_dn_dt = normalized_score_fokker_planck_error._get_dn_dt( - relative_coordinates, times, unit_cells - ) - torch.testing.assert_close(computed_dn_dt, finite_difference_dn_dt) - - def test_divergence_function( - self, - normalized_score_fokker_planck_error, - relative_coordinates, - times, - unit_cells, - ): - finite_difference_divergence = get_finite_difference_divergence( - normalized_score_fokker_planck_error._normalized_scores_function, - relative_coordinates, - times, - unit_cells, - ) - - computed_divergence = normalized_score_fokker_planck_error._divergence_function( - relative_coordinates, times, unit_cells - ) - - torch.testing.assert_close(computed_divergence, finite_difference_divergence) - - def test_get_gradient( - self, - normalized_score_fokker_planck_error, - relative_coordinates, - times, - unit_cells, - ): - for callable in [ - normalized_score_fokker_planck_error._divergence_function, - normalized_score_fokker_planck_error._normalized_scores_square_norm_function, - ]: - computed_grads = normalized_score_fokker_planck_error._get_gradient( - callable, relative_coordinates, times, unit_cells - ) - finite_difference_grads = get_finite_difference_gradient( - callable, relative_coordinates, times, unit_cells - ) - - torch.testing.assert_close(computed_grads, finite_difference_grads) - - def test_get_normalized_score_fokker_planck_error( - self, - normalized_score_fokker_planck_error, - relative_coordinates, - times, - unit_cells, - ): - errors1 = normalized_score_fokker_planck_error.get_normalized_score_fokker_planck_error( - relative_coordinates, times, unit_cells - ) - - errors2 = normalized_score_fokker_planck_error.get_normalized_score_fokker_planck_error_by_iterating_over_batch( - relative_coordinates, times, unit_cells - ) - - torch.testing.assert_allclose(errors1, errors2) diff --git a/tests/noise_schedulers/__init__.py b/tests/noise_schedulers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/samplers/test_exploding_variance.py b/tests/noise_schedulers/test_exploding_variance.py similarity index 89% rename from tests/samplers/test_exploding_variance.py rename to tests/noise_schedulers/test_exploding_variance.py index e588e31a..895a88d4 100644 --- a/tests/samplers/test_exploding_variance.py +++ b/tests/noise_schedulers/test_exploding_variance.py @@ -1,9 +1,9 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.samplers.exploding_variance import \ - ExplodingVariance -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.exploding_variance import \ + VarianceScheduler +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters @@ -33,7 +33,7 @@ def times(self): @pytest.fixture() def exploding_variance(self, noise_parameters): - return ExplodingVariance(noise_parameters) + return VarianceScheduler(noise_parameters) @pytest.fixture() def expected_sigmas(self, noise_parameters, times): diff --git a/tests/samplers/test_variance_sampler.py b/tests/noise_schedulers/test_noise_scheduler.py similarity index 57% rename from tests/samplers/test_variance_sampler.py rename to tests/noise_schedulers/test_noise_scheduler.py index bfe4c5ea..b6950f06 100644 --- a/tests/samplers/test_variance_sampler.py +++ b/tests/noise_schedulers/test_noise_scheduler.py @@ -1,14 +1,18 @@ +import einops import pytest import torch -from src.diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import ( - ExplodingVarianceSampler, NoiseParameters) +from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from src.diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler @pytest.mark.parametrize("total_time_steps", [3, 10, 17]) @pytest.mark.parametrize("time_delta", [1e-5, 0.1]) @pytest.mark.parametrize("sigma_min", [0.005, 0.1]) @pytest.mark.parametrize("corrector_step_epsilon", [2e-5, 0.1]) +@pytest.mark.parametrize("num_classes", [4]) class TestExplodingVarianceSampler: @pytest.fixture() def noise_parameters( @@ -22,8 +26,10 @@ def noise_parameters( ) @pytest.fixture() - def variance_sampler(self, noise_parameters): - return ExplodingVarianceSampler(noise_parameters=noise_parameters) + def variance_sampler(self, noise_parameters, num_classes): + return NoiseScheduler( + noise_parameters=noise_parameters, num_classes=num_classes + ) @pytest.fixture() def expected_times(self, total_time_steps, time_delta): @@ -56,6 +62,41 @@ def expected_epsilons(self, expected_sigmas, noise_parameters): return torch.tensor(epsilons) + @pytest.fixture() + def expected_betas(self, expected_times, noise_parameters): + betas = [] + for i in range(noise_parameters.total_time_steps): + betas.append(1.0 / (noise_parameters.total_time_steps - i)) + return torch.tensor(betas) + + @pytest.fixture() + def expected_alphas(self, expected_betas): + alphas = [1 - expected_betas[0]] + for beta in expected_betas[1:]: + alphas.append(alphas[-1] * (1 - beta.item())) + return torch.tensor(alphas) + + @pytest.fixture() + def expected_q_matrix(self, expected_betas, num_classes): + expected_qs = [] + for beta in expected_betas: + q = torch.zeros(1, num_classes, num_classes) + for i in range(num_classes): + q[0, i, i] = 1 - beta.item() + q[0, :-1, -1] = beta.item() + q[0, -1, -1] = 1 + expected_qs.append(q) + return torch.concatenate(expected_qs, dim=0) + + @pytest.fixture() + def expected_q_bar_matrix(self, expected_q_matrix): + expected_qbars = [expected_q_matrix[0]] + for qmat in expected_q_matrix[1:]: + expected_qbars.append( + einops.einsum(expected_qbars[-1], qmat, "i j, j k -> i k") + ) + return torch.stack(expected_qbars, dim=0) + @pytest.fixture() def indices(self, time_sampler, shape): return time_sampler.get_random_time_step_indices(shape) @@ -101,9 +142,23 @@ def test_get_random_time_step_indices(self, variance_sampler, total_time_steps): assert torch.all(random_indices >= 0) assert torch.all(random_indices < total_time_steps) + def test_create_beta_array(self, variance_sampler, expected_betas): + assert torch.allclose(variance_sampler._beta_array, expected_betas) + + def test_create_alpha_bar_array(self, variance_sampler, expected_alphas): + assert torch.allclose(variance_sampler._alpha_bar_array, expected_alphas) + + def test_create_q_matrix_array(self, variance_sampler, expected_q_matrix): + assert torch.allclose(variance_sampler._q_matrix_array, expected_q_matrix) + + def test_create_q_bar_matrix_array(self, variance_sampler, expected_q_bar_matrix): + assert torch.allclose( + variance_sampler._q_bar_matrix_array, expected_q_bar_matrix + ) + @pytest.mark.parametrize("batch_size", [1, 10, 100]) def test_get_random_noise_parameter_sample( - self, mocker, variance_sampler, batch_size + self, mocker, variance_sampler, batch_size, num_classes ): random_indices = variance_sampler._get_random_time_step_indices(shape=(1000,)) mocker.patch.object( @@ -121,12 +176,34 @@ def test_get_random_noise_parameter_sample( ) expected_gs = variance_sampler._g_array.take(random_indices) expected_gs_squared = variance_sampler._g_squared_array.take(random_indices) + expected_betas = variance_sampler._beta_array.take(random_indices) + expected_alpha_bars = variance_sampler._alpha_bar_array.take(random_indices) + expected_q_matrices = variance_sampler._q_matrix_array.index_select( + dim=0, index=random_indices + ) + expected_q_bar_matrices = variance_sampler._q_bar_matrix_array.index_select( + dim=0, index=random_indices + ) + expected_q_bar_tm1_matrices = torch.where( + random_indices.view(-1, 1, 1) == 0, + torch.eye(num_classes).unsqueeze(0), # replace t=0 with identity matrix + variance_sampler._q_bar_matrix_array.index_select( + dim=0, index=(random_indices - 1).clip(min=0) + ), + ) torch.testing.assert_close(noise_sample.time, expected_times) torch.testing.assert_close(noise_sample.sigma, expected_sigmas) torch.testing.assert_close(noise_sample.sigma_squared, expected_sigmas_squared) torch.testing.assert_close(noise_sample.g, expected_gs) torch.testing.assert_close(noise_sample.g_squared, expected_gs_squared) + torch.testing.assert_close(noise_sample.beta, expected_betas) + torch.testing.assert_close(noise_sample.alpha_bar, expected_alpha_bars) + torch.testing.assert_close(noise_sample.q_matrix, expected_q_matrices) + torch.testing.assert_close(noise_sample.q_bar_matrix, expected_q_bar_matrices) + torch.testing.assert_close( + noise_sample.q_bar_tm1_matrix, expected_q_bar_tm1_matrices + ) def test_get_all_sampling_parameters(self, variance_sampler): noise, langevin_dynamics = variance_sampler.get_all_sampling_parameters() @@ -144,3 +221,13 @@ def test_get_all_sampling_parameters(self, variance_sampler): torch.testing.assert_close( langevin_dynamics.sqrt_2_epsilon, variance_sampler._sqrt_two_epsilon_array ) + + torch.testing.assert_close(noise.beta, variance_sampler._beta_array) + torch.testing.assert_close(noise.alpha_bar, variance_sampler._alpha_bar_array) + torch.testing.assert_close(noise.q_matrix, variance_sampler._q_matrix_array) + torch.testing.assert_close( + noise.q_bar_matrix, variance_sampler._q_bar_matrix_array + ) + torch.testing.assert_close( + noise.q_bar_tm1_matrix[1:], variance_sampler._q_bar_matrix_array[:-1] + ) diff --git a/tests/noisers/__init__.py b/tests/noisers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/noisers/test_atom_types_noiser.py b/tests/noisers/test_atom_types_noiser.py new file mode 100644 index 00000000..4780a6aa --- /dev/null +++ b/tests/noisers/test_atom_types_noiser.py @@ -0,0 +1,76 @@ +import einops +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.noisers.atom_types_noiser import \ + AtomTypesNoiser + + +@pytest.mark.parametrize("shape", [(10, 1), (4, 5, 3), (2, 2, 2, 2)]) +class TestNoisyAtomTypesSampler: + + @pytest.fixture(scope="class", autouse=True) + def set_random_seed(self): + torch.manual_seed(23423) + + @pytest.fixture() + def num_classes(self): + return 4 + + @pytest.fixture() + def real_atom_types(self, shape, num_classes): + return torch.randint(0, num_classes, shape).long() + + @pytest.fixture() + def real_atom_types_one_hot(self, real_atom_types, num_classes): + return torch.nn.functional.one_hot(real_atom_types, num_classes=num_classes) + + @pytest.fixture() + def q_bar_matrices(self, shape, num_classes): + return torch.rand(shape + (num_classes, num_classes)) + + @pytest.fixture() + def computed_noisy_atom_types(self, real_atom_types_one_hot, q_bar_matrices): + return AtomTypesNoiser.get_noisy_atom_types_sample( + real_atom_types_one_hot, q_bar_matrices + ) + + @pytest.fixture() + def fake_uniform_noise(self, shape, num_classes): + return torch.rand(shape + (num_classes,)) + + def test_shape(self, computed_noisy_atom_types, shape): + assert computed_noisy_atom_types.shape == shape + + def test_range(self, computed_noisy_atom_types, num_classes): + assert torch.all(computed_noisy_atom_types >= 0) + assert torch.all(computed_noisy_atom_types < num_classes) + + def test_get_noisy_relative_coordinates_sample( + self, mocker, real_atom_types_one_hot, q_bar_matrices, fake_uniform_noise + ): + mocker.patch.object( + AtomTypesNoiser, + "_get_uniform_noise", + return_value=fake_uniform_noise, + ) + computed_samples = AtomTypesNoiser.get_noisy_atom_types_sample( + real_atom_types_one_hot, q_bar_matrices + ) + + flat_q_matrices = q_bar_matrices.flatten(end_dim=-3) + flat_atom_types = real_atom_types_one_hot.flatten(end_dim=-2).float() + flat_computed_samples = computed_samples.flatten() + flat_fake_noise = fake_uniform_noise.flatten(end_dim=-2) + + for qmat, x0, computed_sample, epsilon in zip( + flat_q_matrices, + flat_atom_types, + flat_computed_samples, + flat_fake_noise, + ): + post_q = einops.einsum(x0, qmat, "... j, ... j i -> ... i") + expected_sample = torch.log(post_q) - torch.log(-torch.log(epsilon)) + expected_sample = torch.argmax(expected_sample, dim=-1) + + assert torch.all(computed_sample == expected_sample) diff --git a/tests/samplers/test_noisy_relative_coordinates_sampler.py b/tests/noisers/test_relative_coordinates_noiser.py similarity index 85% rename from tests/samplers/test_noisy_relative_coordinates_sampler.py rename to tests/noisers/test_relative_coordinates_noiser.py index b18b5f76..8f8a3c62 100644 --- a/tests/samplers/test_noisy_relative_coordinates_sampler.py +++ b/tests/noisers/test_relative_coordinates_noiser.py @@ -2,8 +2,8 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.samplers.noisy_relative_coordinates_sampler import \ - NoisyRelativeCoordinatesSampler +from diffusion_for_multi_scale_molecular_dynamics.noisers.relative_coordinates_noiser import \ + RelativeCoordinatesNoiser @pytest.mark.parametrize("shape", [(10, 1), (4, 5, 3), (2, 2, 2, 2)]) @@ -23,7 +23,7 @@ def sigmas(self, shape): @pytest.fixture() def computed_noisy_relative_coordinates(self, real_relative_coordinates, sigmas): - return NoisyRelativeCoordinatesSampler.get_noisy_relative_coordinates_sample( + return RelativeCoordinatesNoiser.get_noisy_relative_coordinates_sample( real_relative_coordinates, sigmas ) @@ -43,13 +43,13 @@ def test_get_noisy_relative_coordinates_sample( self, mocker, real_relative_coordinates, sigmas, fake_gaussian_sample ): mocker.patch.object( - NoisyRelativeCoordinatesSampler, + RelativeCoordinatesNoiser, "_get_gaussian_noise", return_value=fake_gaussian_sample, ) computed_samples = ( - NoisyRelativeCoordinatesSampler.get_noisy_relative_coordinates_sample( + RelativeCoordinatesNoiser.get_noisy_relative_coordinates_sample( real_relative_coordinates, sigmas ) ) diff --git a/tests/oracle/test_lammps.py b/tests/oracle/test_lammps.py deleted file mode 100644 index 77947b49..00000000 --- a/tests/oracle/test_lammps.py +++ /dev/null @@ -1,44 +0,0 @@ -import numpy as np -import pytest - -from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps import \ - get_energy_and_forces_from_lammps - - -@pytest.fixture -def high_symmetry_lattice(): - box = np.eye(3) * 4 - return box - - -@pytest.fixture -def high_symmetry_positions(): - positions = np.array([[0, 0, 0], [2, 2, 2]]) - return positions - - -# do not run on github because no lammps -@pytest.mark.not_on_github -def test_high_symmetry(high_symmetry_positions, high_symmetry_lattice): - energy, forces = get_energy_and_forces_from_lammps( - high_symmetry_positions, high_symmetry_lattice, atom_types=np.array([1, 1]) - ) - for x in ["x", "y", "z"]: - assert np.allclose(forces[f"f{x}"], [0, 0]) - assert energy < 0 - - -@pytest.fixture -def low_symmetry_positions(): - positions = np.array([[0.23, 1.2, 2.01], [3.2, 0.9, 3.87]]) - return positions - - -@pytest.mark.not_on_github -def test_low_symmetry(low_symmetry_positions, high_symmetry_lattice): - energy, forces = get_energy_and_forces_from_lammps( - low_symmetry_positions, high_symmetry_lattice, atom_types=np.array([1, 1]) - ) - for x in ["x", "y", "z"]: - assert not np.allclose(forces[f"f{x}"], [0, 0]) - assert energy < 0 diff --git a/tests/oracle/test_lammps_energy_oracle.py b/tests/oracle/test_lammps_energy_oracle.py new file mode 100644 index 00000000..0b0d0666 --- /dev/null +++ b/tests/oracle/test_lammps_energy_oracle.py @@ -0,0 +1,118 @@ +import einops +import numpy as np +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.data.element_types import \ + ElementTypes +from diffusion_for_multi_scale_molecular_dynamics.namespace import ( + AXL, AXL_COMPOSITION, CARTESIAN_POSITIONS, UNIT_CELL) +from diffusion_for_multi_scale_molecular_dynamics.oracle.lammps_energy_oracle import ( + LammpsEnergyOracle, LammpsOracleParameters) + + +@pytest.mark.not_on_github +class TestLammpsEnergyOracle: + + @pytest.fixture(scope="class", autouse=True) + def set_seed(self): + """Set the random seed.""" + np.random.seed(2311331423) + + @pytest.fixture() + def spatial_dimension(self): + return 3 + + @pytest.fixture(params=[8, 12, 16]) + def num_atoms(self, request): + return request.param + + @pytest.fixture() + def acell(self): + return 5.5 + + @pytest.fixture() + def box(self, spatial_dimension, acell): + return np.diag(spatial_dimension * [acell]) + + @pytest.fixture() + def cartesian_positions(self, num_atoms, spatial_dimension, box): + x = np.random.rand(num_atoms, spatial_dimension) + return einops.einsum(box, x, "d1 d2, natoms d2 -> natoms d1") + + @pytest.fixture(params=[1, 2]) + def number_of_unique_elements(self, request): + return request.param + + @pytest.fixture() + def unique_elements(self, number_of_unique_elements): + match number_of_unique_elements: + case 1: + elements = ['Si'] + case 2: + elements = ['Si', 'Ge'] + case _: + raise NotImplementedError() + + return elements + + @pytest.fixture() + def lammps_oracle_parameters(self, number_of_unique_elements, unique_elements): + match number_of_unique_elements: + case 1: + sw_coeff_filename = 'Si.sw' + case 2: + sw_coeff_filename = 'SiGe.sw' + case _: + raise NotImplementedError() + + return LammpsOracleParameters(sw_coeff_filename=sw_coeff_filename, elements=unique_elements) + + @pytest.fixture() + def element_types(self, unique_elements): + return ElementTypes(unique_elements) + + @pytest.fixture() + def atom_types(self, element_types, num_atoms): + return np.random.choice(element_types.element_ids, num_atoms, replace=True) + + @pytest.fixture() + def batch_size(self): + return 4 + + @pytest.fixture() + def samples(self, batch_size, num_atoms, spatial_dimension, element_types): + + list_acells = 5. + 5.0 * torch.rand(batch_size) + basis_vectors = torch.stack([acell * torch.eye(spatial_dimension) for acell in list_acells]) + + relative_coordinates = torch.rand(batch_size, num_atoms, spatial_dimension) + cartesian_positions = einops.einsum(basis_vectors, relative_coordinates, + "batch d1 d2, batch natoms d2 -> batch natoms d1") + + atom_types = torch.randint(element_types.number_of_atom_types, (batch_size, num_atoms)) + + axl_composition = AXL(X=relative_coordinates, A=atom_types, L=basis_vectors) + + return {UNIT_CELL: basis_vectors, + CARTESIAN_POSITIONS: cartesian_positions, + AXL_COMPOSITION: axl_composition} + + @pytest.fixture() + def oracle(self, element_types, lammps_oracle_parameters): + return LammpsEnergyOracle(lammps_oracle_parameters=lammps_oracle_parameters) + + def test_compute_energy_and_forces(self, oracle, element_types, cartesian_positions, box, atom_types, tmp_path): + + dump_file_path = tmp_path / "dump.yaml" + energy, forces = oracle._compute_energy_and_forces(cartesian_positions, box, atom_types, dump_file_path) + + np.testing.assert_allclose(cartesian_positions, forces[['x', 'y', 'z']].values, rtol=1e-5) + + expected_atoms = [element_types.get_element(id) for id in atom_types] + computed_atoms = forces['element'].to_list() + assert expected_atoms == computed_atoms + + def test_compute_oracle_energies(self, oracle, samples, batch_size): + energies = oracle.compute_oracle_energies(samples) + assert len(energies) == batch_size diff --git a/tests/sampling/__init__.py b/tests/sampling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/samples_and_metrics/test_sampling.py b/tests/sampling/test_diffusion_sampling.py similarity index 73% rename from tests/samples_and_metrics/test_sampling.py rename to tests/sampling/test_diffusion_sampling.py index c7cdc993..3eefe000 100644 --- a/tests/samples_and_metrics/test_sampling.py +++ b/tests/sampling/test_diffusion_sampling.py @@ -2,17 +2,17 @@ import pytest import torch +from diffusion_for_multi_scale_molecular_dynamics.generators.axl_generator import ( + AXLGenerator, SamplingParameters) from diffusion_for_multi_scale_molecular_dynamics.namespace import ( - CARTESIAN_POSITIONS, RELATIVE_COORDINATES, UNIT_CELL) + AXL, AXL_COMPOSITION, CARTESIAN_POSITIONS, UNIT_CELL) from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import \ get_positions_from_coordinates -from src.diffusion_for_multi_scale_molecular_dynamics.generators.position_generator import ( - PositionGenerator, SamplingParameters) -from src.diffusion_for_multi_scale_molecular_dynamics.samples.sampling import \ +from src.diffusion_for_multi_scale_molecular_dynamics.sampling.diffusion_sampling import \ create_batch_of_samples -class DummyGenerator(PositionGenerator): +class DummyGenerator(AXLGenerator): def __init__(self, relative_coordinates): self._relative_coordinates = relative_coordinates self._counter = 0 @@ -24,9 +24,14 @@ def sample( self, number_of_samples: int, device: torch.device, unit_cell: torch.Tensor ) -> torch.Tensor: self._counter += number_of_samples - return self._relative_coordinates[ - self._counter - number_of_samples: self._counter + rel_coordinates = self._relative_coordinates[ + self._counter - number_of_samples:self._counter ] + return AXL( + A=torch.zeros_like(rel_coordinates[..., 0]).long(), + X=rel_coordinates, + L=torch.zeros_like(rel_coordinates), + ) @pytest.fixture @@ -49,6 +54,11 @@ def spatial_dimensions(): return 3 +@pytest.fixture +def num_atom_types(): + return 4 + + @pytest.fixture def relative_coordinates(number_of_samples, number_of_atoms, spatial_dimensions): return torch.rand(number_of_samples, number_of_atoms, spatial_dimensions) @@ -66,7 +76,11 @@ def generator(relative_coordinates): @pytest.fixture def sampling_parameters( - spatial_dimensions, number_of_atoms, number_of_samples, cell_dimensions + spatial_dimensions, + number_of_atoms, + number_of_samples, + cell_dimensions, + num_atom_types, ): return SamplingParameters( algorithm="dummy", @@ -75,6 +89,7 @@ def sampling_parameters( number_of_samples=number_of_samples, sample_batchsize=2, cell_dimensions=cell_dimensions, + num_atom_types=num_atom_types, ) @@ -94,7 +109,7 @@ def test_create_batch_of_samples( ) torch.testing.assert_allclose( - computed_samples[RELATIVE_COORDINATES], relative_coordinates + computed_samples[AXL_COMPOSITION].X, relative_coordinates ) torch.testing.assert_allclose(computed_samples[UNIT_CELL], expected_basis_vectors) torch.testing.assert_allclose( diff --git a/tests/test_sample_diffusion.py b/tests/test_sample_diffusion.py index d6f8d0ab..5f3560db 100644 --- a/tests/test_sample_diffusion.py +++ b/tests/test_sample_diffusion.py @@ -5,19 +5,19 @@ import yaml from diffusion_for_multi_scale_molecular_dynamics import sample_diffusion -from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_position_generator import \ +from diffusion_for_multi_scale_molecular_dynamics.generators.predictor_corrector_axl_generator import \ PredictorCorrectorSamplingParameters -from diffusion_for_multi_scale_molecular_dynamics.models.loss import \ +from diffusion_for_multi_scale_molecular_dynamics.loss.loss_parameters import \ MSELossParameters +from diffusion_for_multi_scale_molecular_dynamics.models.axl_diffusion_lightning_model import ( + AXLDiffusionLightningModel, AXLDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.optimizer import \ OptimizerParameters -from diffusion_for_multi_scale_molecular_dynamics.models.position_diffusion_lightning_model import ( - PositionDiffusionLightningModel, PositionDiffusionParameters) from diffusion_for_multi_scale_molecular_dynamics.models.score_networks.mlp_score_network import \ MLPScoreNetworkParameters from diffusion_for_multi_scale_molecular_dynamics.namespace import \ - RELATIVE_COORDINATES -from diffusion_for_multi_scale_molecular_dynamics.samplers.variance_sampler import \ + AXL_COMPOSITION +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ NoiseParameters @@ -31,6 +31,11 @@ def number_of_atoms(): return 8 +@pytest.fixture() +def num_atom_types(): + return 3 + + @pytest.fixture() def number_of_samples(): return 12 @@ -58,6 +63,7 @@ def sampling_parameters( number_of_samples, cell_dimensions, record_samples, + num_atom_types, ): return PredictorCorrectorSamplingParameters( number_of_corrector_steps=1, @@ -66,19 +72,23 @@ def sampling_parameters( number_of_samples=number_of_samples, cell_dimensions=cell_dimensions, record_samples=record_samples, + num_atom_types=num_atom_types, ) @pytest.fixture() -def sigma_normalized_score_network(number_of_atoms, noise_parameters): +def axl_network(number_of_atoms, noise_parameters, num_atom_types): score_network_parameters = MLPScoreNetworkParameters( number_of_atoms=number_of_atoms, - embedding_dimensions_size=8, + num_atom_types=num_atom_types, + noise_embedding_dimensions_size=8, + time_embedding_dimensions_size=8, + atom_type_embedding_dimensions_size=8, n_hidden_dimensions=2, hidden_dimensions_size=16, ) - diffusion_params = PositionDiffusionParameters( + diffusion_params = AXLDiffusionParameters( score_network_parameters=score_network_parameters, loss_parameters=MSELossParameters(), optimizer_parameters=OptimizerParameters(name="adam", learning_rate=1e-3), @@ -87,8 +97,8 @@ def sigma_normalized_score_network(number_of_atoms, noise_parameters): diffusion_sampling_parameters=None, ) - model = PositionDiffusionLightningModel(diffusion_params) - return model.sigma_normalized_score_network + model = AXLDiffusionLightningModel(diffusion_params) + return model.axl_network @pytest.fixture() @@ -136,7 +146,7 @@ def args(config_path, checkpoint_path, output_path): def test_sample_diffusion( mocker, args, - sigma_normalized_score_network, + axl_network, output_path, number_of_samples, number_of_atoms, @@ -144,18 +154,22 @@ def test_sample_diffusion( record_samples, ): mocker.patch( - "diffusion_for_multi_scale_molecular_dynamics.sample_diffusion.get_sigma_normalized_score_network", - return_value=sigma_normalized_score_network, + "diffusion_for_multi_scale_molecular_dynamics.sample_diffusion.get_axl_network", + return_value=axl_network, ) sample_diffusion.main(args) assert (output_path / "samples.pt").exists() samples = torch.load(output_path / "samples.pt") - assert samples[RELATIVE_COORDINATES].shape == ( + assert samples[AXL_COMPOSITION].X.shape == ( number_of_samples, number_of_atoms, spatial_dimension, ) + assert samples[AXL_COMPOSITION].A.shape == ( + number_of_samples, + number_of_atoms, + ) assert (output_path / "trajectories.pt").exists() == record_samples diff --git a/tests/test_train_diffusion.py b/tests/test_train_diffusion.py index 16f10d23..1858f913 100644 --- a/tests/test_train_diffusion.py +++ b/tests/test_train_diffusion.py @@ -8,7 +8,7 @@ import glob import os import re -from typing import Union +from typing import List, Union import numpy as np import pytest @@ -57,14 +57,20 @@ def get_prediction_head_parameters(name: str): def get_score_network( - architecture: str, head_name: Union[str, None], number_of_atoms: int + architecture: str, + head_name: Union[str, None], + number_of_atoms: int, + num_atom_types: int, ): if architecture == "mlp": assert head_name is None, "There are no head options for a MLP score network." score_network = dict( architecture="mlp", number_of_atoms=number_of_atoms, - embedding_dimensions_size=8, + num_atom_types=num_atom_types, + noise_embedding_dimensions_size=8, + time_embedding_dimensions_size=8, + atom_type_embedding_dimensions_size=8, n_hidden_dimensions=2, hidden_dimensions_size=16, ) @@ -77,6 +83,7 @@ def get_score_network( number_of_atoms=number_of_atoms, radial_MLP=[4, 4, 4], prediction_head_parameters=get_prediction_head_parameters(head_name), + num_atom_types=num_atom_types, ) elif architecture == "diffusion_mace": @@ -90,10 +97,11 @@ def get_score_network( number_of_mlp_layers=1, number_of_atoms=number_of_atoms, radial_MLP=[4, 4, 4], + num_atom_types=num_atom_types, ) elif architecture == "egnn": - score_network = dict(architecture="egnn") + score_network = dict(architecture="egnn", num_atom_types=num_atom_types) else: raise NotImplementedError("This score network is not implemented") return score_network @@ -101,6 +109,8 @@ def get_score_network( def get_config( number_of_atoms: int, + num_atom_types: int, + unique_elements: List[str], max_epoch: int, architecture: str, head_name: Union[str, None], @@ -109,8 +119,10 @@ def get_config( data_config = dict(batch_size=4, num_workers=0, max_atom=number_of_atoms) model_config = dict( - score_network=get_score_network(architecture, head_name, number_of_atoms), - loss={"algorithm": "mse"}, + score_network=get_score_network( + architecture, head_name, number_of_atoms, num_atom_types + ), + loss={"coordinates_algorithm": "mse"}, noise={"total_time_steps": 10}, ) @@ -121,6 +133,7 @@ def get_config( algorithm=sampling_algorithm, spatial_dimension=3, number_of_atoms=number_of_atoms, + num_atom_types=num_atom_types, number_of_samples=4, record_samples=True, cell_dimensions=[10.0, 10.0, 10.0], @@ -147,6 +160,7 @@ def get_config( exp_name="smoke_test", seed=9999, spatial_dimension=3, + elements=unique_elements, data=data_config, model=model_config, optimizer=optimizer_config, @@ -172,16 +186,35 @@ def get_config( ], ) class TestTrainDiffusion(TestDiffusionDataBase): + + @pytest.fixture(autouse=True) + def skip_mps_accelerator(self, accelerator): + if accelerator == 'mps': + pytest.skip("Skipping MPS accelerator: it is incompatible with KeOps and leads to segfaults") + @pytest.fixture() def max_epoch(self): return 5 + @pytest.fixture() + def num_atom_types(self): + return 3 + @pytest.fixture() def config( - self, number_of_atoms, max_epoch, architecture, head_name, sampling_algorithm + self, + number_of_atoms, + num_atom_types, + unique_elements, + max_epoch, + architecture, + head_name, + sampling_algorithm, ): return get_config( number_of_atoms, + num_atom_types=num_atom_types, + unique_elements=unique_elements, max_epoch=max_epoch, architecture=architecture, head_name=head_name, diff --git a/tests/utils/test_basis_transformations.py b/tests/utils/test_basis_transformations.py index e14bc7a8..363fba15 100644 --- a/tests/utils/test_basis_transformations.py +++ b/tests/utils/test_basis_transformations.py @@ -1,10 +1,11 @@ import pytest import torch +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL from diffusion_for_multi_scale_molecular_dynamics.utils.basis_transformations import ( get_positions_from_coordinates, get_reciprocal_basis_vectors, get_relative_coordinates_from_cartesian_positions, - map_relative_coordinates_to_unit_cell) + map_axl_composition_to_unit_cell, map_relative_coordinates_to_unit_cell) @pytest.fixture @@ -22,6 +23,11 @@ def relative_coordinates(batch_size, number_of_atoms): return torch.rand(batch_size, number_of_atoms, 3) +@pytest.fixture +def num_atom_types(): + return 5 + + def test_get_reciprocal_basis_vectors(basis_vectors): reciprocal_basis_vectors = get_reciprocal_basis_vectors(basis_vectors) assert reciprocal_basis_vectors.shape == basis_vectors.shape @@ -74,7 +80,7 @@ def test_remainder_failure(): @pytest.mark.parametrize("shape", [(10,), (10, 20), (3, 4, 5)]) def test_map_relative_coordinates_to_unit_cell_hard(shape): - relative_coordinates = 1e-8 * (torch.rand((10,)) - 0.5) + relative_coordinates = 1e-8 * (torch.rand(shape) - 0.5) computed_relative_coordinates = map_relative_coordinates_to_unit_cell( relative_coordinates ) @@ -95,7 +101,32 @@ def test_map_relative_coordinates_to_unit_cell_hard(shape): @pytest.mark.parametrize("shape", [(100, 8, 16)]) def test_map_relative_coordinates_to_unit_cell_easy(shape): # Very unlikely to hit the edge cases. - relative_coordinates = 10.0 * (torch.rand((10,)) - 0.5) + relative_coordinates = 10.0 * (torch.rand(shape) - 0.5) expected_values = torch.remainder(relative_coordinates, 1.0) computed_values = map_relative_coordinates_to_unit_cell(relative_coordinates) torch.testing.assert_close(computed_values, expected_values) + + +@pytest.mark.parametrize("shape", [(10,), (10, 20), (3, 4, 5)]) +def test_map_axl_to_unit_cell_hard(shape, num_atom_types): + atom_types = torch.randint(0, num_atom_types + 1, shape) + relative_coordinates = 1e-8 * (torch.rand(shape) - 0.5) + axl_composition = AXL(A=atom_types, X=relative_coordinates, L=torch.rand(shape)) + + computed_axl_composition = map_axl_composition_to_unit_cell( + axl_composition, device=torch.device("cpu") + ) + + positive_relative_coordinates_mask = relative_coordinates >= 0.0 + assert torch.all( + relative_coordinates[positive_relative_coordinates_mask] + == computed_axl_composition.X[positive_relative_coordinates_mask] + ) + torch.testing.assert_close( + computed_axl_composition.X[~positive_relative_coordinates_mask], + torch.zeros_like( + computed_axl_composition.X[~positive_relative_coordinates_mask] + ), + ) + assert torch.all(computed_axl_composition.A == axl_composition.A) + assert torch.all(computed_axl_composition.L == axl_composition.L) diff --git a/tests/utils/test_d3pm_utils.py b/tests/utils/test_d3pm_utils.py new file mode 100644 index 00000000..8cd4e7af --- /dev/null +++ b/tests/utils/test_d3pm_utils.py @@ -0,0 +1,372 @@ +from copy import copy + +import pytest +import torch + +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_parameters import \ + NoiseParameters +from diffusion_for_multi_scale_molecular_dynamics.noise_schedulers.noise_scheduler import \ + NoiseScheduler +from diffusion_for_multi_scale_molecular_dynamics.utils.d3pm_utils import ( + class_index_to_onehot, compute_q_at_given_a0, compute_q_at_given_atm1, + get_probability_at_previous_time_step, get_probability_from_logits) +from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ + broadcast_batch_matrix_tensor_to_all_dimensions + + +@pytest.fixture(scope="module", autouse=True) +def set_random_seed(): + torch.manual_seed(2345234) + + +@pytest.fixture() +def final_shape(batch_size, number_of_dimensions): + shape = torch.randint(low=1, high=5, size=(number_of_dimensions,)) + shape[0] = batch_size + return tuple(shape.numpy()) + + +@pytest.fixture() +def batch_values(final_shape, num_classes): + return torch.randint(0, num_classes, final_shape) + + +@pytest.fixture() +def q_t(final_shape, num_classes): + return torch.randn(final_shape + (num_classes, num_classes)) + + +@pytest.fixture() +def one_hot_x(batch_values, num_classes): + return torch.nn.functional.one_hot(batch_values.long(), num_classes) + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("number_of_dimensions", [4, 8]) +@pytest.mark.parametrize("num_classes", [1, 2, 3]) +def test_class_index_to_onehot(batch_size, batch_values, final_shape, num_classes): + computed_onehot_encoded = class_index_to_onehot(batch_values, num_classes) + + expected_encoding = torch.zeros(final_shape + (num_classes,)) + for i in range(num_classes): + expected_encoding[..., i] += torch.where(batch_values == i, 1, 0) + assert torch.all(expected_encoding == computed_onehot_encoded) + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("number_of_dimensions", [4, 8]) +@pytest.mark.parametrize("num_classes", [1, 2, 3]) +def test_compute_q_xt_bar_xo(q_t, one_hot_x, num_classes): + computed_q_xtxo = compute_q_at_given_a0(one_hot_x, q_t) + expected_q_xtxo = torch.zeros_like(one_hot_x.float()) + for i in range(num_classes): + for j in range(num_classes): + expected_q_xtxo[..., i] += one_hot_x[..., j].float() * q_t[..., j, i] + torch.testing.assert_allclose(computed_q_xtxo, expected_q_xtxo) + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("number_of_dimensions", [4, 8]) +@pytest.mark.parametrize("num_classes", [1, 2, 3]) +def test_compute_q_xt_bar_xtm1(q_t, one_hot_x, num_classes): + computed_q_xtxtm1 = compute_q_at_given_atm1(one_hot_x, q_t) + expected_q_xtxtm1 = torch.zeros_like(one_hot_x.float()) + for i in range(num_classes): + for j in range(num_classes): + expected_q_xtxtm1[..., i] += one_hot_x[..., j].float() * q_t[..., j, i] + torch.testing.assert_allclose(computed_q_xtxtm1, expected_q_xtxtm1) + + +@pytest.fixture +def batch_size(): + return 4 + + +@pytest.fixture +def number_of_atoms(): + return 8 + + +@pytest.fixture +def num_atom_types(): + return 5 + + +@pytest.fixture +def num_classes(num_atom_types): + return num_atom_types + 1 + + +@pytest.fixture +def predicted_logits(batch_size, number_of_atoms, num_classes): + logits = 10 * (torch.randn(batch_size, number_of_atoms, num_classes) - 0.5) + logits[:, :, -1] = -torch.inf # force the model to never predict MASK + return logits + + +@pytest.fixture +def predicted_p_a0_given_at(predicted_logits): + return torch.nn.functional.softmax(predicted_logits, dim=-1) + + +@pytest.fixture +def one_hot_at(batch_size, number_of_atoms, num_atom_types, num_classes): + # at CAN be MASK. + one_hot_indices = torch.randint( + 0, + num_classes, + ( + batch_size, + number_of_atoms, + ), + ) + one_hots = class_index_to_onehot(one_hot_indices, num_classes=num_classes) + return one_hots + + +@pytest.fixture +def q_matrices(batch_size, number_of_atoms, num_classes): + random_q_matrices = torch.rand(batch_size, num_classes, num_classes) + final_shape = (batch_size, number_of_atoms) + broadcast_q_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + random_q_matrices, final_shape=final_shape + ) + return broadcast_q_matrices + + +@pytest.fixture +def q_bar_matrices(batch_size, number_of_atoms, num_classes): + random_q_bar_matrices = torch.rand(batch_size, num_classes, num_classes) + final_shape = (batch_size, number_of_atoms) + broadcast_q_bar_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + random_q_bar_matrices, final_shape=final_shape + ) + return broadcast_q_bar_matrices + + +@pytest.fixture +def q_bar_tm1_matrices(batch_size, number_of_atoms, num_classes): + random_q_bar_tm1_matrices = torch.rand(batch_size, num_classes, num_classes) + final_shape = (batch_size, number_of_atoms) + broadcast_q_bar_tm1_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + random_q_bar_tm1_matrices, final_shape=final_shape + ) + return broadcast_q_bar_tm1_matrices + + +@pytest.fixture +def loss_eps(): + return 1.0e-12 + + +@pytest.fixture +def expected_p_atm1_given_at_from_logits( + predicted_p_a0_given_at, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, +): + batch_size, natoms, num_classes = predicted_p_a0_given_at.shape + + denominator = torch.zeros(batch_size, natoms) + numerator1 = torch.zeros(batch_size, natoms, num_classes) + numerator2 = torch.zeros(batch_size, natoms, num_classes) + + for i in range(num_classes): + for j in range(num_classes): + denominator[:, :] += ( + predicted_p_a0_given_at[:, :, i] + * q_bar_matrices[:, :, i, j] + * one_hot_at[:, :, j] + ) + numerator1[:, :, i] += ( + predicted_p_a0_given_at[:, :, j] * q_bar_tm1_matrices[:, :, j, i] + ) + numerator2[:, :, i] += q_matrices[:, :, i, j] * one_hot_at[:, :, j] + + numerator = numerator1 * numerator2 + + expected_p = torch.zeros(batch_size, natoms, num_classes) + for i in range(num_classes): + expected_p[:, :, i] = numerator[:, :, i] / denominator[:, :] + + # Note that the expected_p_atm1_given_at is not really a probability (and thus does not sum to 1) because + # the Q matrices are random. + return expected_p + + +@pytest.fixture +def one_hot_a0(batch_size, number_of_atoms, num_atom_types, num_classes): + # a0 CANNOT be MASK. + one_hot_indices = torch.randint( + 0, + num_atom_types, + ( + batch_size, + number_of_atoms, + ), + ) + one_hots = class_index_to_onehot(one_hot_indices, num_classes=num_classes) + return one_hots + + +@pytest.fixture +def expected_p_atm1_given_at_from_onehot( + one_hot_a0, one_hot_at, q_matrices, q_bar_matrices, q_bar_tm1_matrices +): + batch_size, natoms, num_classes = one_hot_a0.shape + + denominator = torch.zeros(batch_size, natoms) + numerator1 = torch.zeros(batch_size, natoms, num_classes) + numerator2 = torch.zeros(batch_size, natoms, num_classes) + + for i in range(num_classes): + for j in range(num_classes): + denominator[:, :] += ( + one_hot_a0[:, :, i] * q_bar_matrices[:, :, i, j] * one_hot_at[:, :, j] + ) + numerator1[:, :, i] += one_hot_a0[:, :, j] * q_bar_tm1_matrices[:, :, j, i] + numerator2[:, :, i] += q_matrices[:, :, i, j] * one_hot_at[:, :, j] + + numerator = numerator1 * numerator2 + + expected_q = torch.zeros(batch_size, natoms, num_classes) + for i in range(num_classes): + expected_q[:, :, i] = numerator[:, :, i] / denominator[:, :] + + return expected_q + + +def test_get_probability_at_previous_time_step_from_logits( + predicted_logits, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + loss_eps, + expected_p_atm1_given_at_from_logits, +): + computed_p_atm1_given_at = get_probability_at_previous_time_step( + predicted_logits, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + small_epsilon=loss_eps, + probability_at_zeroth_timestep_are_logits=True, + ) + + assert torch.allclose( + computed_p_atm1_given_at, expected_p_atm1_given_at_from_logits + ) + + +def test_get_probability_at_previous_time_step_from_one_hot_probabilities( + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + loss_eps, + expected_p_atm1_given_at_from_onehot, +): + computed_q_atm1_given_at_and_a0 = get_probability_at_previous_time_step( + one_hot_a0, + one_hot_at, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + small_epsilon=loss_eps, + probability_at_zeroth_timestep_are_logits=False, + ) + + assert torch.allclose( + computed_q_atm1_given_at_and_a0, expected_p_atm1_given_at_from_onehot + ) + + +@pytest.mark.parametrize("total_time_steps", [2, 5, 10]) +def test_prob_a0_given_a1_is_never_mask(number_of_atoms, num_classes, total_time_steps, loss_eps): + noise_parameters = NoiseParameters(total_time_steps=total_time_steps) + noise_scheduler = NoiseScheduler(noise_parameters=noise_parameters, num_classes=num_classes) + + logits = torch.rand(1, number_of_atoms, num_classes) + logits[..., -1] = -torch.inf + + atom_shape = (1, number_of_atoms) + q_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=noise_scheduler._q_matrix_array[0].unsqueeze(0), final_shape=atom_shape + ) + + q_bar_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=noise_scheduler._q_bar_matrix_array[0].unsqueeze(0), final_shape=atom_shape + ) + + q_bar_tm1_matrices = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_values=noise_scheduler._q_bar_tm1_matrix_array[0].unsqueeze(0), final_shape=atom_shape + ) + + a1 = torch.randint(0, num_classes, (1, number_of_atoms)) + a1_onehot = class_index_to_onehot(a1, num_classes) + + p_a0_given_a1 = get_probability_at_previous_time_step(logits, + a1_onehot, + q_matrices, + q_bar_matrices, + q_bar_tm1_matrices, + small_epsilon=loss_eps, + probability_at_zeroth_timestep_are_logits=True) + + mask_probability = p_a0_given_a1[..., -1] + torch.testing.assert_allclose(mask_probability, torch.zeros_like(mask_probability)) + + total_probability = p_a0_given_a1.sum(dim=-1) + torch.testing.assert_allclose(total_probability, torch.ones_like(total_probability)) + + +@pytest.fixture() +def logits(batch_size, num_atom_types, num_classes): + return torch.rand(batch_size, num_atom_types, num_classes) + + +@pytest.mark.parametrize("lowest_probability_value", [1e-12, 1e-8, 1e-3]) +def test_get_probability_from_logits_general(logits, lowest_probability_value): + probabilities = get_probability_from_logits(logits, lowest_probability_value) + + approximate_probabilities = torch.nn.functional.softmax(logits, dim=-1) + + torch.testing.assert_close(probabilities, approximate_probabilities) + + computed_sums = probabilities.sum(dim=-1) + torch.testing.assert_close(computed_sums, torch.ones_like(computed_sums)) + + +@pytest.mark.parametrize("lowest_probability_value", [1e-12, 1e-8, 1e-3]) +def test_get_probability_from_logits_some_zero_probabilities(logits, lowest_probability_value): + + mask = torch.randint(0, 2, logits.shape).to(torch.bool) + mask[:, :, 0] = False # make sure no mask is all True. + + edge_case_logits = copy(logits) + edge_case_logits[mask] = -torch.inf + + computed_probabilities = get_probability_from_logits(edge_case_logits, lowest_probability_value) + + computed_sums = computed_probabilities.sum(dim=-1) + torch.testing.assert_close(computed_sums, torch.ones_like(computed_sums)) + + assert torch.all(computed_probabilities[mask] > 0.1 * lowest_probability_value) + + +@pytest.mark.parametrize("lowest_probability_value", [1e-12, 1e-8, 1e-3]) +def test_get_probability_from_logits_pathological(logits, lowest_probability_value): + + mask = torch.randint(0, 2, logits.shape).to(torch.bool) + mask[0, 0, :] = True # All bad logits + + bad_logits = copy(logits) + bad_logits[mask] = -torch.inf + + with pytest.raises(AssertionError): + _ = get_probability_from_logits(bad_logits, lowest_probability_value) diff --git a/tests/utils/test_sample_trajectory.py b/tests/utils/test_sample_trajectory.py index e8699dee..a3c45164 100644 --- a/tests/utils/test_sample_trajectory.py +++ b/tests/utils/test_sample_trajectory.py @@ -1,11 +1,11 @@ from copy import deepcopy -import einops import pytest import torch +from diffusion_for_multi_scale_molecular_dynamics.namespace import AXL from diffusion_for_multi_scale_molecular_dynamics.utils.sample_trajectory import \ - PredictorCorrectorSampleTrajectory + SampleTrajectory @pytest.fixture(autouse=True, scope="module") @@ -38,6 +38,11 @@ def spatial_dimension(): return 3 +@pytest.fixture(scope="module") +def num_classes(): + return 5 + + @pytest.fixture(scope="module") def basis_vectors(batch_size): # orthogonal boxes with dimensions between 5 and 10. @@ -50,22 +55,40 @@ def basis_vectors(batch_size): @pytest.fixture(scope="module") -def list_i_indices(number_of_predictor_steps): +def list_time_indices(number_of_predictor_steps): return torch.arange(number_of_predictor_steps - 1, -1, -1) @pytest.fixture(scope="module") -def list_sigmas(number_of_predictor_steps): - return torch.rand(number_of_predictor_steps) +def predictor_model_outputs( + number_of_predictor_steps, + batch_size, + number_of_atoms, + spatial_dimension, + num_classes, +): + list_scores = [ + AXL( + A=torch.rand(batch_size, number_of_atoms, num_classes), + X=torch.rand(batch_size, number_of_atoms, spatial_dimension), + L=torch.zeros( + batch_size, number_of_atoms, spatial_dimension * (spatial_dimension - 1) + ), # TODO placeholder + ) + for _ in range(number_of_predictor_steps) + ] + return list_scores @pytest.fixture(scope="module") -def list_times(number_of_predictor_steps): - return torch.rand(number_of_predictor_steps) +def list_x_i(number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension): + return torch.rand( + number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension + ) @pytest.fixture(scope="module") -def predictor_scores( +def list_x_im1( number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension ): return torch.rand( @@ -74,31 +97,61 @@ def predictor_scores( @pytest.fixture(scope="module") -def list_x_i(number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension): - return torch.rand( - number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension +def list_atom_types_i( + number_of_predictor_steps, batch_size, number_of_atoms, num_classes +): + return torch.randint( + 0, num_classes, (number_of_predictor_steps, batch_size, number_of_atoms) ) @pytest.fixture(scope="module") -def list_x_im1( - number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension +def list_atom_types_im1( + number_of_predictor_steps, batch_size, number_of_atoms, num_classes ): - return torch.rand( - number_of_predictor_steps, batch_size, number_of_atoms, spatial_dimension + return torch.randint( + 0, num_classes, (number_of_predictor_steps, batch_size, number_of_atoms) ) @pytest.fixture(scope="module") -def corrector_scores( +def list_axl_i(list_x_i, list_atom_types_i): + list_axl = [ + AXL(A=atom_types_i, X=x_i, L=torch.zeros_like(x_i)) + for atom_types_i, x_i in zip(list_atom_types_i, list_x_i) + ] + return list_axl + + +@pytest.fixture(scope="module") +def list_axl_im1(list_x_im1, list_atom_types_im1): + list_axl = [ + AXL(A=atom_types_im1, X=x_im1, L=torch.zeros_like(x_im1)) + for atom_types_im1, x_im1 in zip(list_atom_types_im1, list_x_im1) + ] + return list_axl + + +@pytest.fixture(scope="module") +def corrector_model_outputs( number_of_predictor_steps, number_of_corrector_steps, batch_size, number_of_atoms, spatial_dimension, + num_classes, ): - number_of_scores = number_of_predictor_steps * number_of_corrector_steps - return torch.rand(number_of_scores, batch_size, number_of_atoms, spatial_dimension) + list_scores = [ + AXL( + A=torch.rand(batch_size, number_of_atoms, num_classes), + X=torch.rand(batch_size, number_of_atoms, spatial_dimension), + L=torch.zeros( + batch_size, number_of_atoms, spatial_dimension * (spatial_dimension - 1) + ), # TODO placeholder + ) + for _ in range(number_of_predictor_steps * number_of_corrector_steps) + ] + return list_scores @pytest.fixture(scope="module") @@ -113,6 +166,29 @@ def list_x_i_corr( return torch.rand(number_of_scores, batch_size, number_of_atoms, spatial_dimension) +@pytest.fixture(scope="module") +def list_atom_types_i_corr( + number_of_predictor_steps, + number_of_corrector_steps, + batch_size, + number_of_atoms, + num_classes, +): + number_of_scores = number_of_predictor_steps * number_of_corrector_steps + return torch.randint( + 0, num_classes, (number_of_scores, batch_size, number_of_atoms) + ) + + +@pytest.fixture(scope="module") +def list_axl_i_corr(list_x_i_corr, list_atom_types_i_corr): + list_axl = [ + AXL(A=atom_types_i_corr, X=x_i_corr, L=torch.zeros_like(x_i_corr)) + for atom_types_i_corr, x_i_corr in zip(list_atom_types_i_corr, list_x_i_corr) + ] + return list_axl + + @pytest.fixture(scope="module") def list_corrected_x_i( number_of_predictor_steps, @@ -126,132 +202,132 @@ def list_corrected_x_i( @pytest.fixture(scope="module") -def sample_trajectory( +def list_corrected_atom_types_i( + number_of_predictor_steps, number_of_corrector_steps, - list_i_indices, - list_times, - list_sigmas, - basis_vectors, - list_x_i, - list_x_im1, - predictor_scores, - list_x_i_corr, - list_corrected_x_i, - corrector_scores, + batch_size, + number_of_atoms, + num_classes, ): - sample_trajectory = PredictorCorrectorSampleTrajectory() - sample_trajectory.record_unit_cell(basis_vectors) + number_of_scores = number_of_predictor_steps * number_of_corrector_steps + return torch.randint( + 0, num_classes, (number_of_scores, batch_size, number_of_atoms) + ) + +@pytest.fixture(scope="module") +def list_corrected_axl_i(list_corrected_x_i, list_corrected_atom_types_i): + list_axl = [ + AXL( + A=corrected_atom_types_i, X=corrected_x_i, L=torch.zeros_like(corrected_x_i) + ) + for corrected_atom_types_i, corrected_x_i in zip( + list_corrected_atom_types_i, list_corrected_x_i + ) + ] + return list_axl + + +@pytest.fixture(scope="module") +def sample_trajectory( + number_of_corrector_steps, + list_time_indices, + basis_vectors, + list_axl_i, + list_axl_im1, + predictor_model_outputs, + list_axl_i_corr, + list_corrected_axl_i, + corrector_model_outputs, +): + sample_trajectory_recorder = SampleTrajectory() total_corrector_index = 0 - for i_index, time, sigma, x_i, x_im1, scores in zip( - list_i_indices, list_times, list_sigmas, list_x_i, list_x_im1, predictor_scores + for time_step_index, axl_i, axl_im1, model_predictions_i in zip( + list_time_indices, + list_axl_i, + list_axl_im1, + predictor_model_outputs, ): - sample_trajectory.record_predictor_step( - i_index=i_index, time=time, sigma=sigma, x_i=x_i, x_im1=x_im1, scores=scores - ) + entry = dict(time_step_index=time_step_index, + composition_i=axl_i, + composition_im1=axl_im1, + model_predictions_i=model_predictions_i) + sample_trajectory_recorder.record(key="predictor_step", entry=entry) for _ in range(number_of_corrector_steps): - x_i = list_x_i_corr[total_corrector_index] - corrected_x_i = list_corrected_x_i[total_corrector_index] - scores = corrector_scores[total_corrector_index] - sample_trajectory.record_corrector_step( - i_index=i_index, - time=time, - sigma=sigma, - x_i=x_i, - corrected_x_i=corrected_x_i, - scores=scores, - ) - total_corrector_index += 1 + axl_i = list_axl_i_corr[total_corrector_index] + corrected_axl_i = list_corrected_axl_i[total_corrector_index] + model_predictions_i = corrector_model_outputs[total_corrector_index] + entry = dict(time_step_index=time_step_index, + composition_i=axl_i, + corrected_composition_i=corrected_axl_i, + model_predictions_i=model_predictions_i) + sample_trajectory_recorder.record(key="corrector_step", entry=entry) - return sample_trajectory + total_corrector_index += 1 + return sample_trajectory_recorder -def test_sample_trajectory_unit_cell(sample_trajectory, basis_vectors): - torch.testing.assert_close(sample_trajectory.data["unit_cell"], basis_vectors) +@pytest.fixture(scope="module") +def pickle_data(sample_trajectory, tmp_path_factory): + path_to_pickle = tmp_path_factory.mktemp("sample_trajectory") / "test.pkl" + sample_trajectory.write_to_pickle(path_to_pickle) + data = torch.load(path_to_pickle) + return data -def test_record_predictor( - sample_trajectory, list_times, list_sigmas, list_x_i, list_x_im1, predictor_scores -): - torch.testing.assert_close( - torch.tensor(sample_trajectory.data["predictor_time"]), list_times - ) - torch.testing.assert_close( - torch.tensor(sample_trajectory.data["predictor_sigma"]), list_sigmas - ) - torch.testing.assert_close( - torch.stack(sample_trajectory.data["predictor_x_i"], dim=0), list_x_i - ) - torch.testing.assert_close( - torch.stack(sample_trajectory.data["predictor_x_im1"], dim=0), list_x_im1 - ) - torch.testing.assert_close( - torch.stack(sample_trajectory.data["predictor_scores"], dim=0), predictor_scores - ) +def test_predictor_step(number_of_predictor_steps, + pickle_data, + list_time_indices, + list_axl_i, + list_axl_im1, + predictor_model_outputs): + assert "predictor_step" in pickle_data + predictor_step_data = pickle_data["predictor_step"] -def test_record_corrector( - sample_trajectory, - number_of_corrector_steps, - list_times, - list_sigmas, - list_x_i_corr, - list_corrected_x_i, - corrector_scores, -): + assert len(predictor_step_data) == number_of_predictor_steps - torch.testing.assert_close( - torch.tensor(sample_trajectory.data["corrector_time"]), - torch.repeat_interleave(list_times, number_of_corrector_steps), - ) - torch.testing.assert_close( - torch.tensor(sample_trajectory.data["corrector_sigma"]), - torch.repeat_interleave(list_sigmas, number_of_corrector_steps), - ) - torch.testing.assert_close( - torch.stack(sample_trajectory.data["corrector_x_i"], dim=0), list_x_i_corr - ) - torch.testing.assert_close( - torch.stack(sample_trajectory.data["corrector_corrected_x_i"], dim=0), - list_corrected_x_i, - ) - torch.testing.assert_close( - torch.stack(sample_trajectory.data["corrector_scores"], dim=0), corrector_scores - ) + for step_idx in range(number_of_predictor_steps): + entry = predictor_step_data[step_idx] + assert entry['time_step_index'] == list_time_indices[step_idx] + torch.testing.assert_close(entry['composition_i'], list_axl_i[step_idx]) + torch.testing.assert_close(entry['composition_im1'], list_axl_im1[step_idx]) + torch.testing.assert_close(entry['model_predictions_i'], predictor_model_outputs[step_idx]) -def test_standardize_data_and_write_pickle( - sample_trajectory, - basis_vectors, - list_times, - list_sigmas, - list_x_i, - predictor_scores, - tmp_path, +def test_corrector_step( + number_of_predictor_steps, + number_of_corrector_steps, + pickle_data, + list_time_indices, + list_axl_i_corr, + list_corrected_axl_i, + corrector_model_outputs, ): - pickle_path = str(tmp_path / "test_pickle_path.pkl") - sample_trajectory.write_to_pickle(pickle_path) - with open(pickle_path, "rb") as fd: - standardized_data = torch.load(fd) + assert "corrector_step" in pickle_data + corrector_step_data = pickle_data["corrector_step"] - reordered_scores = einops.rearrange(predictor_scores, "t b n d -> b t n d") - reordered_relative_coordinates = einops.rearrange(list_x_i, "t b n d -> b t n d") + assert len(corrector_step_data) == number_of_predictor_steps * number_of_corrector_steps - torch.testing.assert_close(standardized_data["unit_cell"], basis_vectors) - torch.testing.assert_close(standardized_data["time"], list_times) - torch.testing.assert_close(standardized_data["sigma"], list_sigmas) - torch.testing.assert_close( - standardized_data["relative_coordinates"], reordered_relative_coordinates - ) - torch.testing.assert_close(standardized_data["normalized_scores"], reordered_scores) + global_step_idx = 0 + for predictor_step_idx in range(number_of_predictor_steps): + expected_time_index = list_time_indices[predictor_step_idx] + + for corrector_step_idx in range(number_of_corrector_steps): + entry = corrector_step_data[global_step_idx] + assert entry['time_step_index'] == expected_time_index + torch.testing.assert_close(entry['composition_i'], list_axl_i_corr[global_step_idx]) + torch.testing.assert_close(entry['corrected_composition_i'], list_corrected_axl_i[global_step_idx]) + torch.testing.assert_close(entry['model_predictions_i'], corrector_model_outputs[global_step_idx]) + global_step_idx += 1 -def test_reset(sample_trajectory, tmp_path): +def test_reset(sample_trajectory): # We don't want to affect other tests! copied_sample_trajectory = deepcopy(sample_trajectory) - assert len(copied_sample_trajectory.data.keys()) != 0 + assert len(copied_sample_trajectory._internal_data.keys()) != 0 copied_sample_trajectory.reset() - assert len(copied_sample_trajectory.data.keys()) == 0 + assert len(copied_sample_trajectory._internal_data.keys()) == 0 diff --git a/tests/utils/test_tensor_utils.py b/tests/utils/test_tensor_utils.py index 4d2d5253..a854cf5c 100644 --- a/tests/utils/test_tensor_utils.py +++ b/tests/utils/test_tensor_utils.py @@ -1,8 +1,9 @@ import pytest import torch -from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import \ - broadcast_batch_tensor_to_all_dimensions +from diffusion_for_multi_scale_molecular_dynamics.utils.tensor_utils import ( + broadcast_batch_matrix_tensor_to_all_dimensions, + broadcast_batch_tensor_to_all_dimensions) @pytest.fixture(scope="module", autouse=True) @@ -15,6 +16,11 @@ def batch_values(batch_size): return torch.rand(batch_size) +@pytest.fixture() +def batch_matrix_values(batch_size, num_classes): + return torch.rand(batch_size, num_classes, num_classes) + + @pytest.fixture() def final_shape(batch_size, number_of_dimensions): shape = torch.randint(low=1, high=5, size=(number_of_dimensions,)) @@ -36,3 +42,20 @@ def test_broadcast_batch_tensor_to_all_dimensions( for expected_value, computed_values in zip(batch_values, value_arrays): expected_values = torch.ones_like(computed_values) * expected_value torch.testing.assert_close(expected_values, computed_values) + + +@pytest.mark.parametrize("batch_size", [4, 8]) +@pytest.mark.parametrize("number_of_dimensions", [1, 2, 3]) +@pytest.mark.parametrize("num_classes", [1, 2, 4]) +def test_broadcast_batch_matrix_tensor_to_all_dimensions( + batch_size, batch_matrix_values, final_shape, num_classes +): + broadcast_values = broadcast_batch_matrix_tensor_to_all_dimensions( + batch_matrix_values, final_shape + ) + + value_arrays = broadcast_values.reshape(batch_size, -1, num_classes, num_classes) + + for expected_value, computed_values in zip(batch_matrix_values, value_arrays): + expected_values = torch.ones_like(computed_values) * expected_value + torch.testing.assert_close(expected_values, computed_values)