Skip to content

Commit

Permalink
missing fnct in model_loader that should be removed in the future
Browse files Browse the repository at this point in the history
  • Loading branch information
sblackburn-mila committed Mar 28, 2024
1 parent 5dfffac commit e8b91ea
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion crystal_diffusion/models/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
ValidOptimizerNames)
from crystal_diffusion.models.position_diffusion_lightning_model import (
PositionDiffusionLightningModel, PositionDiffusionParameters)
from crystal_diffusion.models.score_network import MLPScoreNetworkParameters
from crystal_diffusion.models.score_network import (MLPScoreNetwork,
MLPScoreNetworkParameters)
from crystal_diffusion.samplers.variance_sampler import NoiseParameters

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,3 +46,26 @@ def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> PositionDiffusionLi
logger.info('model info:\n' + str(model) + '\n')

return model


def load_model(hyper_params): # pragma: no cover
"""Instantiate a model.
Args:
hyper_params (dict): hyper parameters from the config file
Returns:
model (obj): A neural network model object.
"""
architecture = hyper_params['architecture']
# __TODO__ fix architecture list
if architecture == 'simple_mlp':
model_class = MLPScoreNetwork
else:
raise ValueError('architecture {} not supported'.format(architecture))
logger.info('selected architecture: {}'.format(architecture))

model = model_class(hyper_params)
logger.info('model info:\n' + str(model) + '\n')

return model

0 comments on commit e8b91ea

Please sign in to comment.