From e8b91eaba957b82af4330d9a822b791ce19f9bd3 Mon Sep 17 00:00:00 2001 From: Simon Blackburn Date: Thu, 28 Mar 2024 15:10:19 -0400 Subject: [PATCH] missing fnct in model_loader that should be removed in the future --- crystal_diffusion/models/model_loader.py | 26 +++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/crystal_diffusion/models/model_loader.py b/crystal_diffusion/models/model_loader.py index 4a13bcb0..e4d1c1c3 100644 --- a/crystal_diffusion/models/model_loader.py +++ b/crystal_diffusion/models/model_loader.py @@ -6,7 +6,8 @@ ValidOptimizerNames) from crystal_diffusion.models.position_diffusion_lightning_model import ( PositionDiffusionLightningModel, PositionDiffusionParameters) -from crystal_diffusion.models.score_network import MLPScoreNetworkParameters +from crystal_diffusion.models.score_network import (MLPScoreNetwork, + MLPScoreNetworkParameters) from crystal_diffusion.samplers.variance_sampler import NoiseParameters logger = logging.getLogger(__name__) @@ -45,3 +46,26 @@ def load_diffusion_model(hyper_params: Dict[AnyStr, Any]) -> PositionDiffusionLi logger.info('model info:\n' + str(model) + '\n') return model + + +def load_model(hyper_params): # pragma: no cover + """Instantiate a model. + + Args: + hyper_params (dict): hyper parameters from the config file + + Returns: + model (obj): A neural network model object. + """ + architecture = hyper_params['architecture'] + # __TODO__ fix architecture list + if architecture == 'simple_mlp': + model_class = MLPScoreNetwork + else: + raise ValueError('architecture {} not supported'.format(architecture)) + logger.info('selected architecture: {}'.format(architecture)) + + model = model_class(hyper_params) + logger.info('model info:\n' + str(model) + '\n') + + return model