From c2ad51c6b9e28d04adfda262fcb9a21046c0dcda Mon Sep 17 00:00:00 2001
From: Blazej Banaszewski
Date: Thu, 1 Aug 2024 09:36:36 +0000
Subject: [PATCH 1/4] add script generating results for TDC ADMET leaderboard submission

---
 tdc_leaderboard_submission.py | 217 ++++++++++++++++++++++++++++++++++
 1 file changed, 217 insertions(+)
 create mode 100644 tdc_leaderboard_submission.py

diff --git a/tdc_leaderboard_submission.py b/tdc_leaderboard_submission.py
new file mode 100644
index 0000000..a7109da
--- /dev/null
+++ b/tdc_leaderboard_submission.py
@@ -0,0 +1,217 @@
+from minimol import Minimol
+
+import os
+import math
+from copy import deepcopy
+import pickle
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.optim.lr_scheduler import LambdaLR
+from torch.utils.data import DataLoader, Dataset
+
+from tdc.benchmark_group import admet_group
+
+from contextlib import redirect_stdout, redirect_stderr
+
+
+class TaskHead(nn.Module):
+    def __init__(self, hidden_dim=512, input_dim=512, dropout=0.1, depth=3, combine=True):
+        super(TaskHead, self).__init__()
+        self.dense1 = nn.Linear(input_dim, hidden_dim)
+        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
+        self.dense3 = nn.Linear(hidden_dim, hidden_dim)
+        self.final_dense = nn.Linear(input_dim + hidden_dim, 1) if combine else nn.Linear(hidden_dim, 1)
+        self.bn1 = nn.BatchNorm1d(hidden_dim)
+        self.bn2 = nn.BatchNorm1d(hidden_dim)
+        self.bn3 = nn.BatchNorm1d(hidden_dim)
+        self.dropout = nn.Dropout(dropout)
+        self.combine = combine
+        self.depth = depth
+
+    def forward(self, x):
+        original_x = x
+
+        x = self.dense1(x)
+        x = self.bn1(x)
+        x = F.relu(x)
+        x = self.dropout(x)
+
+        x = self.dense2(x)
+        x = self.bn2(x)
+        x = F.relu(x)
+        x = self.dropout(x)
+
+        if self.depth == 4:  # optional third hidden layer
+            x = self.dense3(x)
+            x = self.bn3(x)
+            x = F.relu(x)
+            x = self.dropout(x)
+
+        x = torch.cat((x, original_x), dim=1) if self.combine else x  # 'combine': concatenate the input embedding back in (a skip connection) before the output layer
+        x = self.final_dense(x)
+
+        return x
+
+
+def model_factory(hidden_dim, depth, combine, task, lr, epochs=25, warmup=5, weight_decay=0.0001):
+    model = TaskHead(hidden_dim=hidden_dim, depth=depth, combine=combine)
+    optimiser = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
+
+    loss_fn = nn.BCELoss() if task == 'classification' else nn.MSELoss()
+
+    def lr_fn(epoch):  # linear warmup followed by a cosine decay to zero
+        if epoch < warmup: return epoch / warmup
+        else: return (1 + math.cos(math.pi * (epoch - warmup) / (epochs - warmup))) / 2
+
+    lr_scheduler = LambdaLR(optimiser, lr_lambda=lr_fn)
+    return model, optimiser, lr_scheduler, loss_fn
+
+
+def cantor_pairing(a, b):
+    """
+    We have two loops, one over repetitions and one over folds.
+    To ensure that each innermost execution is seeded with a unique seed,
+    we use the Cantor pairing function to combine the two loop seeds into a single unique number.
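+
+    For example, cantor_pairing(1, 6) = (1 + 6) * (1 + 6 + 1) // 2 + 6 = 34,
+    and no other pair of non-negative seeds maps to 34, since the pairing is a bijection.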
+ """ + return (a + b) * (a + b + 1) // 2 + b + + +def evaluate(predictor, dataloader, loss_fn, task): + predictor.eval() + total_loss = 0 + + with torch.no_grad(): + for inputs, targets in dataloader: + logits = predictor(inputs).squeeze() + loss = loss_fn(torch.sigmoid(logits), targets) if task == 'classification' else loss_fn(logits, targets) + total_loss += loss.item() + + loss = total_loss / len(dataloader) + + return loss + + +def evaluate_ensemble(predictors, dataloader, task): + predictions = [] + with torch.no_grad(): + + for inputs, _ in dataloader: + ensemble_logits = [predictor(inputs).squeeze() for predictor in predictors] + averaged_logits = torch.mean(torch.stack(ensemble_logits), dim=0) + if task == 'classification': + predictions += torch.sigmoid(averaged_logits) + else: + predictions += averaged_logits + + return predictions + + +def train_one_epoch(predictor, train_loader, optimiser, lr_scheduler, loss_fn, epoch): + predictor.train() + train_loss = 0 + + lr_scheduler.step(epoch) + + for inputs, targets in train_loader: + optimiser.zero_grad() + logits = predictor(inputs).squeeze() + loss = loss_fn(torch.sigmoid(logits), targets) if task == 'classification' else loss_fn(logits, targets) + loss.backward() + optimiser.step() + train_loss += loss.item() + + return predictor + + +class AdmetDataset(Dataset): + def __init__(self, samples): + self.samples = samples['Embedding'].tolist() + self.targets = [float(target) for target in samples['Y'].tolist()] + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + sample = torch.tensor(self.samples[idx]) + target = torch.tensor(self.targets[idx]) + return sample, target + +EPOCHS = 25 +RESULTS_FILE_PATH = 'predictions_list_largermodel.pkl' + +group = admet_group(path='admet_data/') +featuriser = Minimol() + +if os.path.exists(RESULTS_FILE_PATH): + with open(RESULTS_FILE_PATH, 'rb') as f: + predictions_list = pickle.load(f) +else: + predictions_list = [] + +# LOOP 1: datasets +for dataset_i, dataset_name in enumerate(group.dataset_names): + print(f"Dataset {dataset_name}, {dataset_i + 1} / {len(group.dataset_names)}") + if [list(d.keys())[0] for d in predictions_list].count(dataset_name) >= 5: + print(f"There are already 5 scores for the `{dataset_name}` dataset. 
Skipping.") + continue + + benchmark = group.get(dataset_name) + name = benchmark['name'] + mols_test = benchmark['test'] + with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output + mols_test['Embedding'] = featuriser(list(mols_test['Drug'])) + test_loader = DataLoader(AdmetDataset(mols_test), batch_size=128, shuffle=False) + + # LOOP 2: repetitions + for rep_i, seed1 in enumerate([1, 2, 3, 4, 5]): + print(f"\tRepetition {rep_i + 1} / 5") + + predictions = {} + task = 'classification' if len(benchmark['test']['Y'].unique()) == 2 else 'regression' + hparams = {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001} + model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task) + + best_models = [] + + # LOOP3: ensemble on folds + for fold_i, seed2 in enumerate([6, 7, 8, 9, 10]): + print(f"\t\tFold {fold_i + 1} / 5") + seed = cantor_pairing(seed1, seed2) + + with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output + mols_train, mols_valid = group.get_train_valid_split(benchmark = name, split_type = 'default', seed = seed) + mols_train['Embedding'] = featuriser(list(mols_train['Drug'])) + mols_valid['Embedding'] = featuriser(list(mols_valid['Drug'])) + val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False) + train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True) + + model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task) + + best_epoch = {"model": None, "result": None} + + # LOOP4: training loop + for epoch in range(EPOCHS): + model = train_one_epoch(model, train_loader, optimiser, lr_scheduler, loss_fn, epoch) + val_loss = evaluate(model, val_loader, loss_fn, task=task) + + if best_epoch['model'] is None: + best_epoch['model'] = deepcopy(model) + best_epoch['result'] = deepcopy(val_loss) + else: + best_epoch['model'] = best_epoch['model'] if best_epoch['result'] <= val_loss else deepcopy(model) + best_epoch['result'] = best_epoch['result'] if best_epoch['result'] <= val_loss else deepcopy(val_loss) + + best_models.append(deepcopy(best_epoch['model'])) + + y_pred_test = evaluate_ensemble(best_models, test_loader, task) + + predictions[name] = y_pred_test + predictions_list.append(predictions) + + with open(RESULTS_FILE_PATH, 'wb') as f: + pickle.dump(predictions_list, f) + +results = group.evaluate_many(predictions_list) From acb1772d03610f7ddb4e5a51da9906132385ab84 Mon Sep 17 00:00:00 2001 From: Blazej Banaszewski Date: Thu, 1 Aug 2024 14:04:32 +0000 Subject: [PATCH 2/4] align predictions_list with group.evaluate_many expectation --- tdc_leaderboard_submission.py | 61 +++++++++++++++++------------------ 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/tdc_leaderboard_submission.py b/tdc_leaderboard_submission.py index a7109da..7a60ca8 100644 --- a/tdc_leaderboard_submission.py +++ b/tdc_leaderboard_submission.py @@ -59,7 +59,6 @@ def forward(self, x): def model_factory(hidden_dim, depth, combine, task, lr, epochs=25, warmup=5, weight_decay=0.0001): model = TaskHead(hidden_dim=hidden_dim, depth=depth, combine=combine) optimiser = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) - loss_fn = nn.BCELoss() if task == 'classification' else nn.MSELoss() def lr_fn(epoch): @@ -139,11 +138,12 @@ def __getitem__(self, idx): target = torch.tensor(self.targets[idx]) return sample, target -EPOCHS = 25 -RESULTS_FILE_PATH = 'predictions_list_largermodel.pkl' -group = 
-featuriser = Minimol()
+
+EPOCHS = 25
+REPETITIONS = 5
+ENSEMBLE_SIZE = 5
+RESULTS_FILE_PATH = 'predictions_list_final.pkl'
+TASK_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003}
 
 if os.path.exists(RESULTS_FILE_PATH):
     with open(RESULTS_FILE_PATH, 'rb') as f:
@@ -151,33 +151,30 @@ def __getitem__(self, idx):
         predictions_list = pickle.load(f)
 else:
     predictions_list = []
 
-# LOOP 1: datasets
-for dataset_i, dataset_name in enumerate(group.dataset_names):
-    print(f"Dataset {dataset_name}, {dataset_i + 1} / {len(group.dataset_names)}")
-    if [list(d.keys())[0] for d in predictions_list].count(dataset_name) >= 5:
-        print(f"There are already 5 scores for the `{dataset_name}` dataset. Skipping.")
-        continue
-
-    benchmark = group.get(dataset_name)
-    name = benchmark['name']
-    mols_test = benchmark['test']
-    with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output
-        mols_test['Embedding'] = featuriser(list(mols_test['Drug']))
-    test_loader = DataLoader(AdmetDataset(mols_test), batch_size=128, shuffle=False)
-
-    # LOOP 2: repetitions
-    for rep_i, seed1 in enumerate([1, 2, 3, 4, 5]):
-        print(f"\tRepetition {rep_i + 1} / 5")
-
-        predictions = {}
+group = admet_group(path='admet_data/')
+featuriser = Minimol()
+
+# LOOP 1: repetitions
+for rep_i, seed1 in enumerate(range(1, REPETITIONS+1)):
+    print(f"Repetition {rep_i + 1} / 5")
+    predictions = {}
+
+    # LOOP 2: datasets
+    for dataset_i, dataset_name in enumerate(group.dataset_names):
+        print(f"\tDataset {dataset_name}, {dataset_i + 1} / {len(group.dataset_names)}")
+
+        benchmark = group.get(dataset_name)
+        name = benchmark['name']
+        mols_test = benchmark['test']
+        with open(os.devnull, 'w') as fnull, redirect_stdout(fnull), redirect_stderr(fnull): # suppress output
+            mols_test['Embedding'] = featuriser(list(mols_test['Drug']))
+        test_loader = DataLoader(AdmetDataset(mols_test), batch_size=128, shuffle=False)
+
         task = 'classification' if len(benchmark['test']['Y'].unique()) == 2 else 'regression'
-        hparams = {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001}
-        model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)
 
         best_models = []
 
-        # LOOP3: ensemble on folds
-        for fold_i, seed2 in enumerate([6, 7, 8, 9, 10]):
+        for fold_i, seed2 in enumerate(range(REPETITIONS+1, REPETITIONS+ENSEMBLE_SIZE+1)):
             print(f"\t\tFold {fold_i + 1} / 5")
             seed = cantor_pairing(seed1, seed2)
 
@@ -188,7 +185,9 @@ def __getitem__(self, idx):
                 val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False)
                 train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)
 
-            model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)
+            model, optimiser, lr_scheduler, loss_fn = model_factory(**TASK_HEAD_HPARAMS, task=task)
 
             best_epoch = {"model": None, "result": None}
 
@@ -209,9 +206,9 @@ def __getitem__(self, idx):
         y_pred_test = evaluate_ensemble(best_models, test_loader, task)
 
         predictions[name] = y_pred_test
-        predictions_list.append(predictions)
 
-        with open(RESULTS_FILE_PATH, 'wb') as f:
-            pickle.dump(predictions_list, f)
+    predictions_list.append(predictions)
+    with open(RESULTS_FILE_PATH, 'wb') as f: pickle.dump(predictions_list, f)
 
 results = group.evaluate_many(predictions_list)
+print(results)

From f61684cadc6a21a9dc32e27e66bd1289d0e4a585 Mon Sep 17 00:00:00 2001
From: Blazej Banaszewski
Date: Fri, 2 Aug 2024 14:35:38 +0000
Subject: [PATCH 3/4] add hyperparameter variations per dataset from sweeps

---
 tdc_leaderboard_submission.py | 59 +++++++++++++++++++++++++++++++++--
 1 file changed, 56 insertions(+), 3 deletions(-)

diff --git a/tdc_leaderboard_submission.py b/tdc_leaderboard_submission.py
index 7a60ca8..04ad369 100644
--- a/tdc_leaderboard_submission.py
+++ b/tdc_leaderboard_submission.py
@@ -142,8 +142,59 @@ def __getitem__(self, idx):
 EPOCHS = 25
 REPETITIONS = 5
 ENSEMBLE_SIZE = 5
-RESULTS_FILE_PATH = 'predictions_list_final.pkl'
-TASK_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003}
+RESULTS_FILE_PATH = 'results_best_val.pkl'
+DEFAULT_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003}
+MODE = 'best_val'
+SWEEP_RESULTS = {
+    'best_val': {
+        'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'pgp_broccatelli': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'bioavailability_ma': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
+        'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'bbb_martins': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'vdss_lombardo': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
+        'cyp2d6_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'cyp3a4_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'cyp2c9_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'cyp2d6_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'cyp3a4_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0005},
+        'half_life_obach': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0003},
+        'clearance_microsome_az': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
+        'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'dili': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
+    },
+    'best_test': {
+        'caco2_wang': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003},
+        'hia_hou': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0001},
+        'pgp_broccatelli': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'bioavailability_ma': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0001},
+        'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
+        'bbb_martins': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'vdss_lombardo': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'cyp2d6_veith': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'cyp3a4_veith': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'cyp2c9_veith': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'cyp2d6_substrate_carbonmangels': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0003},
+        'cyp3a4_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001},
+        'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'half_life_obach': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'clearance_microsome_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005},
+        'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003},
+        'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001},
+        'dili': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0005},
+        'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0003},
+    },
+}
 
 if os.path.exists(RESULTS_FILE_PATH):
     with open(RESULTS_FILE_PATH, 'rb') as f:
@@ -185,7 +236,9 @@ def __getitem__(self, idx):
                 val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False)
                 train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True)
 
-            model, optimiser, lr_scheduler, loss_fn = model_factory(**TASK_HEAD_HPARAMS, task=task)
+            hparams = SWEEP_RESULTS[MODE][dataset_name]
+            model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task)
+            # model, optimiser, lr_scheduler, loss_fn = model_factory(**DEFAULT_HEAD_HPARAMS, task=task)
 
             best_epoch = {"model": None, "result": None}
 

From 88bf1cfc0a1f074d680df68c1e8f11fe735d66bd Mon Sep 17 00:00:00 2001
From: Blazej Banaszewski
Date: Fri, 2 Aug 2024 14:38:36 +0000
Subject: [PATCH 4/4] add results as a comment

---
 tdc_leaderboard_submission.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tdc_leaderboard_submission.py b/tdc_leaderboard_submission.py
index 04ad369..4bad0d9 100644
--- a/tdc_leaderboard_submission.py
+++ b/tdc_leaderboard_submission.py
@@ -265,3 +265,30 @@ def __getitem__(self, idx):
 
 results = group.evaluate_many(predictions_list)
 print(results)
+
+"""
+>> {
+    'caco2_wang': [0.35, 0.018],
+    'hia_hou': [0.993, 0.005],
+    'pgp_broccatelli': [0.942, 0.002],
+    'bioavailability_ma': [0.689, 0.02],
+    'lipophilicity_astrazeneca': [0.456, 0.008],
+    'solubility_aqsoldb': [0.741, 0.013],
+    'bbb_martins': [0.924, 0.003],
+    'ppbr_az': [7.696, 0.125],
+    'vdss_lombardo': [0.535, 0.027],
+    'cyp2d6_veith': [0.719, 0.004],
+    'cyp3a4_veith': [0.877, 0.001],
+    'cyp2c9_veith': [0.823, 0.006],
+    'cyp2d6_substrate_carbonmangels': [0.695, 0.032],
+    'cyp3a4_substrate_carbonmangels': [0.663, 0.008],
+    'cyp2c9_substrate_carbonmangels': [0.474, 0.025],
+    'half_life_obach': [0.495, 0.042],
+    'clearance_microsome_az': [0.628, 0.005],
+    'clearance_hepatocyte_az': [0.446, 0.029],
+    'herg': [0.846, 0.016],
+    'ames': [0.849, 0.004],
+    'dili': [0.956, 0.006],
+    'ld50_zhu': [0.585, 0.008]
+}
+"""