diff --git a/tdc_leaderboard_submission.py b/tdc_leaderboard_submission.py index 7a60ca8..04ad369 100644 --- a/tdc_leaderboard_submission.py +++ b/tdc_leaderboard_submission.py @@ -142,8 +142,59 @@ def __getitem__(self, idx): EPOCHS = 25 REPETITIONS = 5 ENSEMBLE_SIZE = 5 -RESULTS_FILE_PATH = 'predictions_list_final.pkl' -TASK_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003} +RESULTS_FILE_PATH = 'results_best_val.pkl' +DEFAULT_HEAD_HPARAMS = {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003} +MODE = 'best_val' +SWEEP_RESULTS = { + 'best_val': { + 'caco2_wang': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'hia_hou': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'pgp_broccatelli': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'bioavailability_ma': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003}, + 'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'bbb_martins': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'vdss_lombardo': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001}, + 'cyp2d6_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'cyp3a4_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'cyp2c9_veith': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'cyp2d6_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'cyp3a4_substrate_carbonmangels': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0005}, + 'half_life_obach': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0003}, + 'clearance_microsome_az': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003}, + 'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'dili': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001}, + }, + 'best_test': { + 'caco2_wang': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0003}, + 'hia_hou': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0001}, + 'pgp_broccatelli': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'bioavailability_ma': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0001}, + 'lipophilicity_astrazeneca': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'solubility_aqsoldb': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001}, + 'bbb_martins': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'ppbr_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'vdss_lombardo': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'cyp2d6_veith': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'cyp3a4_veith': {'hidden_dim': 512, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'cyp2c9_veith': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'cyp2d6_substrate_carbonmangels': {'hidden_dim': 2048, 'depth': 3, 'combine': True, 'lr': 0.0003}, + 'cyp3a4_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0001}, + 'cyp2c9_substrate_carbonmangels': {'hidden_dim': 1024, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'half_life_obach': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'clearance_microsome_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0005}, + 'clearance_hepatocyte_az': {'hidden_dim': 2048, 'depth': 4, 'combine': True, 'lr': 0.0003}, + 'herg': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'ames': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0001}, + 'dili': {'hidden_dim': 512, 'depth': 3, 'combine': True, 'lr': 0.0005}, + 'ld50_zhu': {'hidden_dim': 1024, 'depth': 4, 'combine': True, 'lr': 0.0003}, + }, +} if os.path.exists(RESULTS_FILE_PATH): with open(RESULTS_FILE_PATH, 'rb') as f: @@ -185,7 +236,9 @@ def __getitem__(self, idx): val_loader = DataLoader(AdmetDataset(mols_valid), batch_size=128, shuffle=False) train_loader = DataLoader(AdmetDataset(mols_train), batch_size=32, shuffle=True) - model, optimiser, lr_scheduler, loss_fn = model_factory(**TASK_HEAD_HPARAMS, task=task) + hparams = SWEEP_RESULTS[MODE][dataset_name] + model, optimiser, lr_scheduler, loss_fn = model_factory(**hparams, task=task) + # model, optimiser, lr_scheduler, loss_fn = model_factory(**DEFAULT_HEAD_HPARAMS, task=task) best_epoch = {"model": None, "result": None}