Merge pull request #13 from mila-iqia/diffusion_training
Diffusion training
sblackburn86 authored Apr 4, 2024
2 parents b2b94d4 + c908069 commit c8eab3d
Showing 8 changed files with 368 additions and 28 deletions.
crystal_diffusion/data/diffusion/data_loader.py (11 changes: 5 additions & 6 deletions)
@@ -13,7 +13,6 @@

from crystal_diffusion.data.diffusion.data_preprocess import \
LammpsProcessorForDiffusion
from crystal_diffusion.utils.hp_utils import check_and_log_hp

logger = logging.getLogger(__name__)

@@ -49,7 +48,7 @@ def __init__(
lot of disk space. Defaults to None.
"""
super().__init__()
check_and_log_hp(["batch_size", "num_workers"], hyper_params) # validate the hyperparameters
# check_and_log_hp(["batch_size", "num_workers"], hyper_params) # validate the hyperparameters
# TODO add the padding parameters for number of atoms
self.lammps_run_dir = lammps_run_dir
self.processed_dataset_dir = processed_dataset_dir
@@ -68,7 +67,7 @@ def dataset_transform(x: Dict[typing.AnyStr, typing.Any], spatial_dim: int = 3)
Args:
x: raw columns from the processed data files. Should contain natom, box, type, position and
reduced_position.
relative_positions.
spatial_dim (optional): number of spatial dimensions. Defaults to 3.
Returns:
@@ -78,7 +77,7 @@ def dataset_transform(x: Dict[typing.AnyStr, typing.Any], spatial_dim: int = 3)
transformed_x['natom'] = torch.as_tensor(x['natom']).long() # resulting tensor size: (batchsize, )
bsize = transformed_x['natom'].size(0)
transformed_x['box'] = torch.as_tensor(x['box']) # size: (batchsize, spatial dimension)
for pos in ['position', 'reduced_position']:
for pos in ['position', 'relative_positions']:
transformed_x[pos] = torch.as_tensor(x[pos]).view(bsize, -1, spatial_dim)
transformed_x['type'] = torch.as_tensor(x['type']).long() # size: (batchsize, max atom)

@@ -89,7 +88,7 @@ def pad_samples(x: Dict[typing.AnyStr, typing.Any], max_atom: int, spatial_dim:
"""Pad a sample for batching.
Args:
x: initial sample from the dataset. Should contain natom, position, reduced_position and type.
x: initial sample from the dataset. Should contain natom, position, relative_positions and type.
max_atom: maximum number of atoms to pad to
spatial_dim (optional): number of spatial dimensions. Defaults to 3.
@@ -100,7 +99,7 @@ def pad_samples(x: Dict[typing.AnyStr, typing.Any], max_atom: int, spatial_dim:
if natom > max_atom:
raise ValueError(f"Hyper-parameter max_atom is smaller than an example in the dataset with {natom} atoms.")
x['type'] = F.pad(torch.as_tensor(x['type']).long(), (0, max_atom - natom), 'constant', -1)
for pos in ['position', 'reduced_position']:
for pos in ['position', 'relative_positions']:
x[pos] = F.pad(torch.as_tensor(x[pos]).float(), (0, spatial_dim * (max_atom - natom)), 'constant',
torch.nan)
return x
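
Illustrative usage sketch (not part of the diff above; the sample values, the max_atom of 4, and the assumption that pad_samples and dataset_transform are reachable as plain callables and that dataset_transform returns its transformed dictionary are all invented for illustration):

import torch

# Toy structure with 2 atoms in a cubic box, padded up to max_atom = 4 (numbers invented).
sample = {
    "natom": 2,
    "box": [5.0, 5.0, 5.0],
    "type": [1, 1],
    "position": [0.0, 0.0, 0.0, 2.5, 2.5, 2.5],
    "relative_positions": [0.0, 0.0, 0.0, 0.5, 0.5, 0.5],
}
padded = pad_samples(sample, max_atom=4, spatial_dim=3)  # 'type' padded with -1, positions with NaN

# dataset_transform expects batched columns (one list per key across samples),
# so wrap the single padded sample into one-element columns before converting to tensors.
columns = {k: [v.tolist() if torch.is_tensor(v) else v] for k, v in padded.items()}
batch = dataset_transform(columns, spatial_dim=3)
# batch["type"] has shape (1, 4); batch["relative_positions"] has shape (1, 4, 3)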
crystal_diffusion/data/diffusion/data_preprocess.py (18 changes: 9 additions & 9 deletions)
@@ -55,30 +55,30 @@ def prepare_data(self, raw_data_dir: str, mode: str = 'train') -> List[str]:
return list_files

@staticmethod
def _convert_coords_to_reduced(row: pd.Series) -> List[float]:
"""Convert a dataframe row to reduced coordinates.
def _convert_coords_to_relative(row: pd.Series) -> List[float]:
"""Convert a dataframe row to relative coordinates.
Args:
row: entry in the dataframe. Should contain box, x, y and z
Returns:
x, y and z in reduced coordinates
x, y and z in relative (reduced) coordinates
"""
x_lim, y_lim, z_lim = row['box']
coord_red = [coord for triple in zip(row['x'], row['y'], row['z']) for coord in
(triple[0] / x_lim, triple[1] / y_lim, triple[2] / z_lim)]
return coord_red

def get_x_reduced(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add a column with reduced x,y, z coordinates.
def get_x_relative(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add a column with relative x,y, z coordinates.
Args:
df: dataframe with atomic positions. Should contain box, x, y and z.
Returns:
dataframe with added column of reduced positions [x1, y1, z1, x2, y2, ...]
dataframe with added column of relative positions [x1, y1, z1, x2, y2, ...]
"""
df['reduced_position'] = df.apply(lambda x: self._convert_coords_to_reduced(x), axis=1)
df['relative_positions'] = df.apply(lambda x: self._convert_coords_to_relative(x), axis=1)
return df

def parse_lammps_run(self, run_dir: str) -> Optional[pd.DataFrame]:
@@ -114,11 +114,11 @@ def parse_lammps_run(self, run_dir: str) -> Optional[pd.DataFrame]:
# TODO consider filtering out samples with large forces and MD steps that are too similar
# TODO large force and similar are to be defined
df = df[['type', 'x', 'y', 'z', 'box']]
df = self.get_x_reduced(df) # add reduced coordinates
df = self.get_x_relative(df) # add relative coordinates
df['natom'] = df['type'].apply(lambda x: len(x)) # count number of atoms in a structure
# naive implementation: a list of list which is converted into a 2d array by torch later
# but a list of list is not ok with the writing on files with parquet
df['position'] = df.apply(lambda x: [j for i in ['x', 'y', 'z'] for j in x[i]], axis=1) # position as 3d array
# position is natom * 3 array
# TODO unit test to check the order after reshape
return df[['natom', 'box', 'type', 'position', 'reduced_position']]
return df[['natom', 'box', 'type', 'position', 'relative_positions']]
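
For illustration (numbers invented, not part of the diff above), the cartesian-to-relative conversion performed by _convert_coords_to_relative simply divides each coordinate by the box length along its axis, so atoms inside the box land in [0, 1):

import pandas as pd

row = pd.Series({
    "box": [10.0, 10.0, 10.0],  # x, y, z box lengths
    "x": [1.0, 7.5], "y": [2.0, 5.0], "z": [3.0, 2.5],
})
x_lim, y_lim, z_lim = row["box"]
relative = [coord for triple in zip(row["x"], row["y"], row["z"])
            for coord in (triple[0] / x_lim, triple[1] / y_lim, triple[2] / z_lim)]
# relative == [0.1, 0.2, 0.3, 0.75, 0.5, 0.25]  -> flat layout [x1, y1, z1, x2, y2, z2]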
crystal_diffusion/models/model_loader.py (45 changes: 44 additions & 1 deletion)
@@ -1,10 +1,53 @@
"""Functions to instantiate a model based on the provided hyperparameters."""
import logging
from typing import Any, AnyStr, Dict

from crystal_diffusion.models.score_network import MLPScoreNetwork
from crystal_diffusion.models.optimizer import (OptimizerParameters,
ValidOptimizerNames)
from crystal_diffusion.models.position_diffusion_lightning_model import (
PositionDiffusionLightningModel, PositionDiffusionParameters)
from crystal_diffusion.models.score_network import (MLPScoreNetwork,
MLPScoreNetworkParameters)
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

logger = logging.getLogger(__name__)


def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> PositionDiffusionLightningModel:
"""Load a position diffusion model from the hyperparameters.
Args:
hyper_params: dictionary of hyperparameters loaded from a config file
Returns:
Diffusion model randomly initialized
"""
score_network_parameters = MLPScoreNetworkParameters(
number_of_atoms=hyper_params['data']['max_atom'],
**hyper_params['model']['score_network']
)
score_network_parameters.spatial_dimension = hyper_params.get('spatial_dimension', 3)

hyper_params['optimizer']['name'] = ValidOptimizerNames(hyper_params['optimizer']['name'])

optimizer_parameters = OptimizerParameters(
**hyper_params['optimizer']
)

noise_parameters = NoiseParameters(**hyper_params['model']['noise'])

diffusion_params = PositionDiffusionParameters(
score_network_parameters=score_network_parameters,
optimizer_parameters=optimizer_parameters,
noise_parameters=noise_parameters,
)

model = PositionDiffusionLightningModel(diffusion_params)
logger.info('model info:\n' + str(model) + '\n')

return model
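
The configuration nesting that load_diffusion_model reads can be sketched as follows (not part of the diff above). Only the top-level keys come from the function body itself; every inner field name and value is a placeholder and must match the actual MLPScoreNetworkParameters, NoiseParameters, OptimizerParameters, and ValidOptimizerNames definitions in the repository.

hyper_params = {
    "spatial_dimension": 3,        # optional; the loader falls back to 3
    "data": {"max_atom": 64},      # mapped to MLPScoreNetworkParameters.number_of_atoms
    "model": {
        "score_network": {},       # extra kwargs for MLPScoreNetworkParameters (placeholder)
        "noise": {},               # kwargs for NoiseParameters (placeholder)
    },
    # 'name' is cast to a ValidOptimizerNames member ("adam" assumed here);
    # any other OptimizerParameters fields would sit alongside it.
    "optimizer": {"name": "adam"},
}
model = load_diffusion_model(hyper_params)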


def load_model(hyper_params): # pragma: no cover
"""Instantiate a model.
