From c3fc3a3b9bba82e9459a5d047b0381aef2980afc Mon Sep 17 00:00:00 2001 From: Blazej Banaszewski Date: Mon, 5 Aug 2024 08:46:26 +0000 Subject: [PATCH] enable batched featurisation + add batch_size parameter --- minimol/model.py | 19 +++++++++++-------- notebooks/downstream_adaptation.ipynb | 20 +++++++------------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/minimol/model.py b/minimol/model.py index 018a0ff..f3df179 100644 --- a/minimol/model.py +++ b/minimol/model.py @@ -4,6 +4,7 @@ from omegaconf import OmegaConf from typing import Union import pkg_resources +from contextlib import redirect_stdout, redirect_stderr from torch_geometric.nn import global_max_pool @@ -23,12 +24,12 @@ class Minimol: - def __init__(self): + def __init__(self, batch_size: int = 100): + self.batch_size = batch_size # handle the paths state_dict_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'state_dict.pth') config_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'config.yaml') base_shape_path = pkg_resources.resource_filename('minimol.ckpts.minimol_v1', 'base_shape.yaml') - # Load the config cfg = self.load_config(os.path.basename(config_path)) cfg = OmegaConf.to_container(cfg, resolve=True) @@ -58,6 +59,7 @@ def __init__(self): self.predictor = Fingerprinter(predictor, 'gnn:15') self.predictor.setup() + def load_config(self, config_name): hydra.initialize('ckpts/minimol_v1/', version_base=None) cfg = hydra.compose(config_name=config_name) @@ -65,15 +67,16 @@ def load_config(self, config_name): def __call__(self, smiles: Union[str,list]) -> torch.Tensor: smiles = [smiles] if not isinstance(smiles, list) else smiles - - input_features, _ = self.datamodule._featurize_molecules(smiles) - input_features = self.to_fp32(input_features) - batch_size = min(100, len(input_features)) + batch_size = min(self.batch_size, len(smiles)) results = [] - for i in tqdm(range(0, len(input_features), batch_size)): - batch = Batch.from_data_list(input_features[i:(i + batch_size)]) + for i in tqdm(range(0, len(smiles), batch_size)): + with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output + input_features, _ = self.datamodule._featurize_molecules(smiles[i:(i + batch_size)]) + input_features = self.to_fp32(input_features) + + batch = Batch.from_data_list(input_features) batch = {"features": batch, "batch_indices": batch.batch} node_features = self.predictor.get_fingerprints_for_batch(batch) fingerprint_graph = global_max_pool(node_features, batch['batch_indices']) diff --git a/notebooks/downstream_adaptation.ipynb b/notebooks/downstream_adaptation.ipynb index 8ea5431..63c33d6 100644 --- a/notebooks/downstream_adaptation.ipynb +++ b/notebooks/downstream_adaptation.ipynb @@ -42,7 +42,7 @@ "Found local copy...\n", "generating training, validation splits...\n", "generating training, validation splits...\n", - "100%|██████████| 461/461 [00:00<00:00, 3529.71it/s]\n" + "100%|██████████| 461/461 [00:00<00:00, 3648.38it/s]\n" ] } ], @@ -112,33 +112,27 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from minimol import Minimol\n", "\n", - "featuriser = Minimol()" + "featuriser = Minimol(batch_size=50)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "featurizing_smiles, batch=1: 100%|██████████| 58/58 [00:04<00:00, 11.97it/s]\n", - "Casting to FP32: 100%|██████████| 58/58 [00:00<00:00, 11466.87it/s]\n", - "100%|██████████| 1/1 [00:00<00:00, 5.32it/s]\n", - "featurizing_smiles, batch=3: 100%|██████████| 39/39 [00:00<00:00, 416.25it/s]\n", - "Casting to FP32: 100%|██████████| 117/117 [00:00<00:00, 20868.96it/s]\n", - "100%|██████████| 2/2 [00:00<00:00, 8.08it/s]\n", - "featurizing_smiles, batch=13: 100%|██████████| 31/31 [00:00<00:00, 104.34it/s]\n", - "Casting to FP32: 100%|██████████| 403/403 [00:00<00:00, 15718.65it/s]\n", - "100%|██████████| 5/5 [00:00<00:00, 6.81it/s]\n" + "100%|██████████| 12/12 [00:25<00:00, 2.14s/it]\n", + "100%|██████████| 24/24 [00:01<00:00, 14.06it/s]\n", + "100%|██████████| 81/81 [00:05<00:00, 13.51it/s]\n" ] } ],