diff --git a/minimol/model.py b/minimol/model.py index f3df179..8be7f6b 100644 --- a/minimol/model.py +++ b/minimol/model.py @@ -8,6 +8,7 @@ from torch_geometric.nn import global_max_pool +from graphium.trainer.predictor import PredictorModule from graphium.finetuning.fingerprinting import Fingerprinter from graphium.config._loader import ( load_accelerator, @@ -24,12 +25,17 @@ class Minimol: - def __init__(self, batch_size: int = 100): + def __init__(self, batch_size: int = 100, base_path: str = None): self.batch_size = batch_size # handle the paths - state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth') - config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml') - base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml') + if base_path is None: + state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth') + config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml') + base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml') + else: + ckpt_path = os.path.join(base_path, 'model.ckpt') + config_path = os.path.join(base_path, 'config.yaml') + base_shape_path = os.path.join(base_path, 'base_shape.yaml') # Load the config cfg = self.load_config(os.path.basename(config_path)) cfg = OmegaConf.to_container(cfg, resolve=True) @@ -42,21 +48,26 @@ def __init__(self, batch_size: int = 100): # Load the model model_class, model_kwargs = load_architecture(cfg, in_dims=self.datamodule.in_dims) metrics = load_metrics(self.cfg) - predictor = load_predictor( - config=self.cfg, - model_class=model_class, - model_kwargs=model_kwargs, - metrics=metrics, - task_levels=self.datamodule.get_task_levels(), - accelerator_type=accelerator_type, - featurization=self.datamodule.featurization, - task_norms=self.datamodule.task_norms, - replicas=1, - gradient_acc=1, - global_bs=self.datamodule.batch_size_training, - ) - predictor.load_state_dict(torch.load(state_dict_path), strict=False) - self.predictor = Fingerprinter(predictor, 'gnn:15') + + if base_path is None: + predictor = load_predictor( + config=self.cfg, + model_class=model_class, + model_kwargs=model_kwargs, + metrics=metrics, + task_levels=self.datamodule.get_task_levels(), + accelerator_type=accelerator_type, + featurization=self.datamodule.featurization, + task_norms=self.datamodule.task_norms, + replicas=1, + gradient_acc=1, + global_bs=self.datamodule.batch_size_training, + ) + predictor.load_state_dict(torch.load(state_dict_path), strict=False) + else: + predictor = PredictorModule.load_pretrained_model(name_or_path=ckpt_path, device=accelerator_type) + + self.predictor = Fingerprinter(predictor, 'graph_output_nn-graph:0') self.predictor.setup() @@ -78,8 +89,9 @@ def __call__(self, smiles: Union[str,list]) -> torch.Tensor: batch = Batch.from_data_list(input_features) batch = {"features": batch, "batch_indices": batch.batch} - node_features = self.predictor.get_fingerprints_for_batch(batch) - fingerprint_graph = global_max_pool(node_features, batch['batch_indices']) + fingerprint_graph = self.predictor.get_fingerprints_for_batch(batch) + # node_features = self.predictor.get_fingerprints_for_batch(batch) + # fingerprint_graph = global_max_pool(node_features, batch['batch_indices']) num_molecules = min(batch_size, fingerprint_graph.shape[0]) results += [fingerprint_graph[i] for i in range(num_molecules)] diff --git a/tdc_leaderboard_submission.py b/tdc_leaderboard_submission.py index 865cab4..a9cec15 100644 --- a/tdc_leaderboard_submission.py +++ b/tdc_leaderboard_submission.py @@ -142,7 +142,7 @@ def __getitem__(self, idx): EPOCHS = 25 REPETITIONS = 5 ENSEMBLE_SIZE = 5 -RESULTS_FILE_PATH = 'results.pkl' +RESULTS_FILE_PATH = 'results_gpu.pkl' SWEEP_RESULTS = { 'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, 'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003}, @@ -175,7 +175,7 @@ def __getitem__(self, idx): predictions_list = [] group = admet_group(path='admet_data/') -featuriser = Minimol() +featuriser = Minimol(base_path='/nethome/blazejb/minimol/minimol/ckpts/minimol_v1_gpu') # LOOP 1: repetitions for rep_i, seed1 in enumerate(range(1, REPETITIONS+1)):