diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..30bfc9a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,51 @@ +FROM nvcr.io/nvidia/pytorch:22.01-py3 as base + +#create a new new user +RUN useradd -ms /bin/bash shengchaol + +# #change to this user +# USER shengchaol + +#set working directory +WORKDIR /home/shengchaol + +RUN chmod -R 777 /home/shengchaol +RUN chmod -R 777 /usr/bin +RUN chmod -R 777 /bin +RUN chmod -R 777 /usr/local +RUN chmod -R 777 /opt/conda + +RUN conda install -y python=3.7 + +RUN conda install -y -c rdkit rdkit=2020.09.1.0 +RUN conda install -y -c conda-forge -c pytorch pytorch=1.9.1 + +RUN conda install -y -c pyg -c conda-forge pyg + +RUN pip install requests +RUN pip install tqdm +RUN pip install matplotlib +RUN pip install spacy + +# for SciBert +RUN conda install -y boto3 +RUN pip install transformers + +# for MoleculeNet +RUN pip install ogb + +# install pysmilesutils +RUN python -m pip install git+https://github.com/MolecularAI/pysmilesutils.git + +RUN pip install deepspeed + +# install Megatron +RUN cd /tmp && git clone https://github.com/MolecularAI/MolBART.git --branch megatron-molbart-with-zinc && cd /tmp/MolBART/megatron_molbart/Megatron-LM-v1.1.5-3D_parallelism && pip install . + +# install apex +RUN cd /tmp && git clone https://github.com/chao1224/apex.git +RUN cd /tmp/apex/ && pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + + +#expose port for Jupyter +EXPOSE 8888 \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..251e78b --- /dev/null +++ b/LICENSE @@ -0,0 +1,64 @@ +NVIDIA Source Code License for MoleculeSTM + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. + +“Software” means the original work of authorship made available under this License. + +“Work” means the Software and any additions to or derivative works of the Software that are made available under +this License. + +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under +U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include +works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. + +Works, including the Software, are “made available” under this License by including in or with the Work either +(a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. + +2. License Grant + +2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, +worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly +display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you +include a complete copy of this License with your distribution, and (c) you retain without modification any +copyright, patent, trademark, or attribution notices that are present in the Work. + +3.2 Derivative Works. 
You may specify that additional or different terms apply to the use, reproduction, and +distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use +limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works +that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution +requirements in Section 3.1) will continue to apply to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use +non-commercially. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative +works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, +cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then +your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately. + +3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, +or trademarks, except as necessary to reproduce the notices described in this License. + +3.6 Termination. If you violate any term of this License, then your rights under this License (including the +grant in Section 2.1) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING +WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU +BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING +NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, +INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR +INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR +DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN +ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
\ No newline at end of file diff --git a/MoleculeSTM/__init__.py b/MoleculeSTM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MoleculeSTM/backup/downstream_language_edit_step_00_check_reconstruction.py b/MoleculeSTM/backup/downstream_language_edit_step_00_check_reconstruction.py new file mode 100644 index 0000000..3f10668 --- /dev/null +++ b/MoleculeSTM/backup/downstream_language_edit_step_00_check_reconstruction.py @@ -0,0 +1,106 @@ +import argparse +import os +import numpy as np +from rdkit import Chem +from rdkit.Chem import Descriptors + +import torch +from torch.utils.data import DataLoader as torch_DataLoader + +from MoleculeSTM.utils import freeze_network +from MoleculeSTM.datasets import ZINC15_Datasets_Only_SMILES, PubChem_Datasets_Only_SMILES +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART + +props = [ + "qed", "MolWt", "MolLogP", "TPSA", + "HeavyAtomCount", "NumAromaticRings", "NumHAcceptors", "NumHDonors", "NumRotatableBonds" +] +props = [ + "MolWt", "MolLogP" +] +prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--dataspace_path", type=str, default="../../Datasets") + parser.add_argument("--dataset", type=str, default="ZINC15") + parser.add_argument("--molecule_type", type=str, default="MegaMolBART", choices=["MegaMolBART", "Graph"]) + + ########## for MoleculeSTM ########## + parser.add_argument("--CLIP_input_model_dir", type=str, default="../../pretrained_model") + parser.add_argument("--SSL_emb_dim", type=int, default=256) + + ########## for generation ########## + parser.add_argument("--generation_model_dir", type=str, default="../../Datasets/pretrained_MegaMolBART/checkpoints") + + ########## for optimization ########## + parser.add_argument("--batch_size", type=int, default=64) + parser.add_argument("--num_workers", type=int, default=8) + + args = parser.parse_args() + print(args) + + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(input_dir=args.generation_model_dir, output_dir=None) + molecule_model_generation = MegaMolBART_wrapper.model + print("Loading from pretrained MegaMolBART ({}).".format(args.generation_model_dir)) + molecule_dim_generation = 256 + + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + molecule_model_generation = molecule_model_generation.to(device) + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + freeze_network(molecule_model_generation) + molecule_model_generation.eval() + + if args.molecule_type == "MegaMolBART": + if args.dataset == "ZINC15": + dataset_root = os.path.join(args.dataspace_path, "ZINC15_data") + dataset = ZINC15_Datasets_Only_SMILES(dataset_root) + elif "PubChem" in args.dataset: + dataset_root = os.path.join(args.dataspace_path, "PubChem_data") + dataset = PubChem_Datasets_Only_SMILES(dataset_root) + else: + raise Exception + dataloader_class = torch_DataLoader + else: + raise Exception + + dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, 
num_workers=args.num_workers) + + for batch_idx, batch in enumerate(dataloader): + SMILES_list = batch + print("SMILES_list", SMILES_list) + + for original_SMILES in SMILES_list: + mol = Chem.MolFromSmiles(original_SMILES) + for name, func in prop_pred: + value = func(mol) + print("{}: {}".format(name, value)) + canon_original_SMILES = Chem.MolToSmiles(mol) + + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model_generation, [original_SMILES]) # [pad, B, d], [pad, B] + print("latent_code:\t", latent_code_init[0, :, :5]) + + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model_generation, [canon_original_SMILES]) # [pad, B, d], [pad, B] + print("latent_code:\t", latent_code_init[0, :, :5]) + + generated_SMILES = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + print("original SMILES: \t", original_SMILES) + print("original SMILES (canon): \t", canon_original_SMILES) + print("reconstructured SMILES: \t", generated_SMILES[0]) + print() + + if batch_idx >= 9: + break diff --git a/MoleculeSTM/backup/downstream_language_edit_step_01_molecule_representation_align.py b/MoleculeSTM/backup/downstream_language_edit_step_01_molecule_representation_align.py new file mode 100644 index 0000000..bea3fc9 --- /dev/null +++ b/MoleculeSTM/backup/downstream_language_edit_step_01_molecule_representation_align.py @@ -0,0 +1,229 @@ +import argparse +import os +import numpy as np +from tqdm import tqdm +import time + +import torch +import torch.nn as nn +from torch import optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader + +from MoleculeSTM.utils import get_molecule_repr_MoleculeSTM +from MoleculeSTM.downstream_language_edit_utils import load_molecule_models +from MoleculeSTM.utils import freeze_network +from MoleculeSTM.datasets import PubChem_Datasets_Only_SMILES + + +def cycle_index(num, shift): + arr = torch.arange(num) + shift + arr[-shift:] = torch.arange(shift) + return arr + + +def do_CL(X, Y, args): + if args.normalize: + X = F.normalize(X, dim=-1) + Y = F.normalize(Y, dim=-1) + + if args.SSL_loss == 'EBM_NCE': + criterion = nn.BCEWithLogitsLoss() + neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0) + neg_X = X.repeat((args.CL_neg_samples, 1)) + + pred_pos = torch.sum(X * Y, dim=1) / args.T + pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T + + loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device)) + loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device)) + CL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples) + + CL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \ + (len(pred_pos) + len(pred_neg)) + CL_acc = CL_acc.detach().cpu().item() + + elif args.SSL_loss == 'InfoNCE': + criterion = nn.CrossEntropyLoss() + B = X.size()[0] + logits = torch.mm(X, Y.transpose(1, 0)) # B*B + logits = torch.div(logits, args.T) + labels = torch.arange(B).long().to(logits.device) # B*1 + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B + + else: + raise Exception + + return CL_loss, CL_acc + + +def mean_pooling(token_embeddings, attention_mask): + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def get_molecule_repr_generation(molecule_data, molecule_model, molecule_type="MegaMolBART", MegaMolBART_wrapper=None): + if molecule_type == "MegaMolBART": + embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model, molecule_data) # [pad, B, d], [pad, B] + # molecule_repr = embedding[0, :, :] # [B, d] + # next we will take the mean pooling instead of the CLS token. + molecule_repr = mean_pooling(embedding, pad_mask) + else: + molecule_repr, _ = molecule_model(molecule_data) + return molecule_repr + + +def save_model(save_best, epoch=None): + if args.output_model_dir is not None: + if save_best: + global optimal_loss + print("save model with loss: {:.5f}".format(optimal_loss)) + model_file = "model.pth" + + elif epoch is None: + model_file = "model_final.pth" + + else: + model_file = "model_{}.pth".format(epoch) + + saved_file_path = os.path.join(args.output_model_dir, "generation2MoleculeSTM_{}".format(model_file)) + torch.save(generation2MoleculeSTM.state_dict(), saved_file_path) + + saved_file_path = os.path.join(args.output_model_dir, "MoleculeSTM2generation_{}".format(model_file)) + torch.save(MoleculeSTM2generation.state_dict(), saved_file_path) + return + + +def train(epoch): + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + + start_time = time.time() + accum_loss, accum_acc = 0, 0 + for batch in L: + SMILES_list = batch + + molecule_repr_generation = get_molecule_repr_generation( + SMILES_list, molecule_model=molecule_model_generation, + molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper + ) + molecule_repr_generation2MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + SMILES_list, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper + ) + molecule_repr_MoleculeSTM2generation = MoleculeSTM2generation(molecule_repr_MoleculeSTM) + + loss_01, acc_01 = do_CL(molecule_repr_generation, molecule_repr_MoleculeSTM2generation, args) + loss_02, acc_02 = do_CL(molecule_repr_MoleculeSTM, molecule_repr_generation2MoleculeSTM, args) + loss = (loss_01 + loss_02) / 2 + acc = (acc_01 + acc_02) / 2 + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + accum_loss += loss.item() + accum_acc += acc + + accum_loss /= len(L) + accum_acc /= len(L) + + global optimal_loss + temp_loss = accum_loss + if temp_loss < optimal_loss: + optimal_loss = temp_loss + save_model(save_best=True, epoch=epoch) + print("CL Loss: {:.5f}\tCL Acc: {:.5f}Time: {:.5f}".format(accum_loss, accum_acc, time.time() - start_time)) + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--dataspace_path", type=str, default="../../Datasets") + parser.add_argument("--dataset", type=str, default="PubChem") + parser.add_argument("--molecule_type", type=str, 
default="MegaMolBART", choices=["MegaMolBART", "Graph"]) + parser.add_argument("--output_model_dir", type=str, default=None) + + ########## for MoleculeSTM ########## + parser.add_argument("--MoleculeSTM_model_dir", type=str, default="../../pretrained_model_Raw") + parser.add_argument("--SSL_emb_dim", type=int, default=256) + + ########## for generation ########## + parser.add_argument("--generation_model_dir", type=str, default="../../Datasets/pretrained_MegaMolBART/checkpoints") + + ########## for optimization ########## + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--decay", type=float, default=0) + parser.add_argument("--generation_lr", type=float, default=1e-4) + parser.add_argument("--MoleculeSTM_lr", type=float, default=1e-4) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + args = parser.parse_args() + print(args) + + MegaMolBART_wrapper, molecule_model_generation, molecule_dim_generation, \ + molecule_model_MoleculeSTM, mol2latent_MoleculeSTM, molecule_dim_MoleculeSTM = load_molecule_models(args) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + molecule_model_generation = molecule_model_generation.to(device) + molecule_model_MoleculeSTM = molecule_model_MoleculeSTM.to(device) + mol2latent_MoleculeSTM = mol2latent_MoleculeSTM.to(device) + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + freeze_network(molecule_model_generation) + freeze_network(mol2latent_MoleculeSTM) + freeze_network(molecule_model_MoleculeSTM) + molecule_model_generation.eval() + mol2latent_MoleculeSTM.eval() + molecule_model_MoleculeSTM.eval() + + if args.molecule_type == "MegaMolBART": + if "PubChem" in args.dataset: + dataset_root = os.path.join(args.dataspace_path, "PubChem_data") + else: + raise Exception + dataset = PubChem_Datasets_Only_SMILES(dataset_root) + dataloader_class = torch_DataLoader + else: + raise Exception + + dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + + generation2MoleculeSTM = nn.Linear(molecule_dim_generation, molecule_dim_MoleculeSTM).to(device) + MoleculeSTM2generation = nn.Linear(molecule_dim_MoleculeSTM, molecule_dim_generation).to(device) + + model_param_group = [ + {"params": generation2MoleculeSTM.parameters(), "lr": args.generation_lr}, + {"params": MoleculeSTM2generation.parameters(), "lr": args.MoleculeSTM_lr}, + ] + optimizer = optim.Adam(model_param_group, weight_decay=args.decay) + optimal_loss = 1e10 + + for e in range(1, args.epochs+1): + print("Epoch {}".format(e)) + train(e) diff --git a/MoleculeSTM/backup/downstream_language_edit_step_02_latent_optimization.py b/MoleculeSTM/backup/downstream_language_edit_step_02_latent_optimization.py new file mode 100644 index 0000000..5bf5129 --- /dev/null +++ 
b/MoleculeSTM/backup/downstream_language_edit_step_02_latent_optimization.py @@ -0,0 +1,161 @@ +import argparse +import math +import numpy as np +from rdkit import Chem, RDLogger + +import torch +from torch import optim +import torch.nn.functional as F +from tqdm import tqdm +from downstream_language_edit_utils import load_language_molecule_and_edit_models, clip_loss_for_edit, evaluate_SMILES_list +from MoleculeSTM.utils import prepare_text_tokens + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + return initial_lr * lr_ramp + + +def mean_pooling(token_embeddings, attention_mask): + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--description", type=str) + parser.add_argument("--input_model_dir", type=str) + parser.add_argument("--mode", type=str, default="edit", choices=["edit", "free_generation"]) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--l2_lambda", type=float, default=0.008) + + ########## for ? ########## + parser.add_argument("--dataspace_path", type=str, default="../../Datasets") + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for MoleculeSTM ########## + parser.add_argument("--MoleculeSTM_model_dir", type=str, default="../../pretrained_model_Raw") + + ########## for generation ########## + parser.add_argument("--generation_model_dir", type=str, default="../../Datasets/pretrained_MegaMolBART/checkpoints") + + ########## for MoleculeSTM and generation projection ########## + parser.add_argument("--language_edit_model_dir", type=str, default="edit_temp/EBM_NCE") + + ########## for editing ########## + parser.add_argument("--lr_rampup", type=float, default=0.05) + parser.add_argument("--lr", type=float, default=0.1) + parser.add_argument("--epochs", type=int, default=100) + args = parser.parse_args() + + print(args) + + text_model, text_tokenizer, text_dim, molecule_model, MegaMolBART_wrapper, molecule_dim, \ + text2latent, mol2latent, generation2MoleculeSTM, MoleculeSTM2generation = load_language_molecule_and_edit_models(args) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + text2latent = text2latent.to(device) + mol2latent = mol2latent.to(device) + generation2MoleculeSTM.to(device) + MoleculeSTM2generation.to(device) + text_model.eval() + molecule_model.eval() + text2latent.eval() + mol2latent.eval() + generation2MoleculeSTM.eval() + MoleculeSTM2generation.eval() + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + description_list = 
[args.description] + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=description_list, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks) + text_repr = text_output["pooler_output"] + text_repr = text2latent(text_repr) + + record_SMILES_list = [] + + if args.mode == "edit": + SMILES_list = [args.input_SMILES] + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding(SMILES_list) # [pad, B, d], [pad, B] + molecule_repr_generation_init = mean_pooling(latent_code_init, pad_mask_init) # [B, d] + # record_SMILES_list.append(args.input_SMILES) + else: + padding_dim = 10 + latent_code_init = torch.randn(padding_dim, 1, molecule_dim).to(device) + pad_mask_init = torch.zeros(padding_dim, 1).bool().to(device) + print("latent_code_init", latent_code_init.size()) + print("pad_mask_init", pad_mask_init.size()) + + generated_mols = MegaMolBART_wrapper.inverse_transform( + [latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + print("initial SMILES", generated_mols[0]) + record_SMILES_list.append(generated_mols[0]) + + l2_lambda_list = [ + 1, 0.1, 0.01, 0.001, 0.0001, + 3, 0.3, 0.03, 0.003, 0.0003, + 5, 0.5, 0.05, 0.005, 0.0005, + 8, 0.8, 0.08, 0.008, 0.0008, + ] + l2_lambda_list = [ + 0.1, + ] + + for l2_lambda in l2_lambda_list: + result_SMILES_list = [record_SMILES_list[0]] + print("with lambda {} ......".format(l2_lambda)) + latent = latent_code_init.detach().clone() + latent.requires_grad = True + optimizer = optim.Adam([latent], lr=args.lr) + + if args.verbose: + L = tqdm(range(args.epochs)) + else: + L = range(args.epochs) + for i in L: + t = i / args.epochs + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + + molecule_repr_generation = mean_pooling(latent, pad_mask_init) # [B, d] + # molecule_repr_MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + + clip_loss_ = clip_loss_for_edit(molecule_repr_generation, mol2latent, text_repr) + l2_loss_ = args.l2_lambda * ((latent_code_init - latent) ** 2).sum() + + loss = clip_loss_ + l2_loss_ + print(clip_loss_.item(), l2_loss_.item()) + + optimizer.zero_grad() + loss.backward(retain_graph=True) + optimizer.step() + print("clip loss: {:.5f}\tL2 loss: {:.5f}".format(clip_loss_.item(), args.l2_lambda * l2_loss_)) + + generated_mols = MegaMolBART_wrapper.inverse_transform( + [latent], pad_mask_init.bool().cuda(), k=1, sanitize=True) + # print("generated_mols",generated_mols[0]) + # Chem.SanitizeMol(generated_mols[0]) + print("final SMILES", generated_mols[0]) + result_SMILES_list.append(generated_mols[0]) + + evaluate_SMILES_list(result_SMILES_list) + print() diff --git a/MoleculeSTM/backup/downstream_language_edit_utils.py b/MoleculeSTM/backup/downstream_language_edit_utils.py new file mode 100644 index 0000000..34c5237 --- /dev/null +++ b/MoleculeSTM/backup/downstream_language_edit_utils.py @@ -0,0 +1,130 @@ +import os +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from rdkit import Chem, RDLogger +from rdkit.Chem import AllChem, Descriptors +lg = RDLogger.logger() +lg.setLevel(RDLogger.CRITICAL) + + +def load_molecule_models(args): + """ + This function returns the two encoders, one for molecule generative model and one for CLIP. + TODO: now we adopt MegaMolBART for both. Will make this more flexible in the future. 
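+    Returned tuple (see the return statement below): the MegaMolBART wrapper, the generation
+    encoder and its embedding dim (256), and the MoleculeSTM encoder with its mol2latent
+    projection and SSL embedding dim. Illustrative unpacking, assuming an argparse namespace
+    with generation_model_dir / MoleculeSTM_model_dir / SSL_emb_dim set:
+        wrapper, gen_model, gen_dim, stm_model, mol2latent, stm_dim = load_molecule_models(args)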
+ """ + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(input_dir=args.generation_model_dir, output_dir=None) + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + print("Loading from pretrained MegaMolBART ({}).".format(args.generation_model_dir)) + molecule_dim_generation = 256 + + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + molecule_model_MoleculeSTM = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + print("Loading from {}...".format(input_model_path)) + molecule_model_MoleculeSTM.load_state_dict(state_dict) + molecule_dim_MoleculeSTM = args.SSL_emb_dim + + mol2latent_MoleculeSTM = nn.Linear(256, molecule_dim_MoleculeSTM) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent_MoleculeSTM.load_state_dict(state_dict) + return MegaMolBART_wrapper, molecule_model_generation, molecule_dim_generation, \ + molecule_model_MoleculeSTM, mol2latent_MoleculeSTM, molecule_dim_MoleculeSTM + + +def load_language_molecule_and_edit_models(args): + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + # TODO: check https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1501 + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + text_dim = 768 + + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + + """ + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + MegaMolBART_wrapper = MegaMolBART(input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + """ + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(input_dir=args.generation_model_dir, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Loading from pretrained MegaMolBART ({}).".format(args.generation_model_dir)) + molecule_dim_generation = 256 + molecule_dim_MoleculeSTM = 256 + + text2latent = nn.Linear(text_dim, args.SSL_emb_dim) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "text2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text2latent.load_state_dict(state_dict) + + mol2latent = nn.Linear(molecule_dim_generation, args.SSL_emb_dim) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent.load_state_dict(state_dict) + + generation2MoleculeSTM = nn.Linear(molecule_dim_generation, molecule_dim_MoleculeSTM) + input_model_path = os.path.join(args.language_edit_model_dir, "generation2MoleculeSTM_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = 
torch.load(input_model_path, map_location='cpu') + generation2MoleculeSTM.load_state_dict(state_dict) + + MoleculeSTM2generation = nn.Linear(molecule_dim_MoleculeSTM, molecule_dim_generation) + input_model_path = os.path.join(args.language_edit_model_dir, "MoleculeSTM2generation_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + MoleculeSTM2generation.load_state_dict(state_dict) + + return text_model, text_tokenizer, text_dim, molecule_model, MegaMolBART_wrapper, molecule_dim_generation, text2latent, mol2latent, generation2MoleculeSTM, MoleculeSTM2generation + + +def clip_loss_for_edit(molecule_repr, mol2latent, text_repr): + # molecule_repr = F.normalize(molecule_repr, dim=-1) + # molecule_repr = mol2latent(molecule_repr) + molecule_repr = F.normalize(molecule_repr, dim=-1) + + text_repr = F.normalize(text_repr, dim=-1) + + similarity = -torch.mm(molecule_repr, text_repr.transpose(0, 1))[0] + return similarity + + +def evaluate_SMILES_list(SMILES_list): + print("SMILES_list:") + print(SMILES_list) + mol_list = [] + for SMILES in SMILES_list: + mol = Chem.MolFromSmiles(SMILES) + # Chem.SanitizeMol(mol) + # print(SMILES, mol) + if mol is None: + continue + mol_list.append(mol) + print("mol_list", len(mol_list)) + + print() + props = ["MolWt", "MolLogP", "TPSA", "qed"] + props = ["MolLogP"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + for name, func in prop_pred: + print("evaluating with {}".format(name)) + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + print("====={} & {:.5f}".format(SMILES, value)) + print() + + return \ No newline at end of file diff --git a/MoleculeSTM/bart_vocab.txt b/MoleculeSTM/bart_vocab.txt new file mode 100644 index 0000000..4842bc2 --- /dev/null +++ b/MoleculeSTM/bart_vocab.txt @@ -0,0 +1,523 @@ + +? 
+^ +& + + +LogD_change_(-0.1, 0.1] +LogD_change_(0.1, 0.3] +LogD_change_(0.3, 0.5] +LogD_change_(0.5, 0.7] +LogD_change_(0.7, 0.9] +LogD_change_(0.9, 1.1] +LogD_change_(1.1, 1.3] +LogD_change_(1.3, 1.5] +LogD_change_(1.5, 1.7] +LogD_change_(1.7, 1.9] +LogD_change_(1.9, 2.1] +LogD_change_(2.1, 2.3] +LogD_change_(2.3, 2.5] +LogD_change_(2.5, 2.7] +LogD_change_(2.7, 2.9] +LogD_change_(2.9, 3.1] +LogD_change_(3.1, 3.3] +LogD_change_(3.3, 3.5] +LogD_change_(3.5, 3.7] +LogD_change_(3.7, 3.9] +LogD_change_(3.9, 4.1] +LogD_change_(4.1, 4.3] +LogD_change_(4.3, 4.5] +LogD_change_(4.5, 4.7] +LogD_change_(4.7, 4.9] +LogD_change_(4.9, 5.1] +LogD_change_(5.1, 5.3] +LogD_change_(5.3, 5.5] +LogD_change_(5.5, 5.7] +LogD_change_(5.7, 5.9] +LogD_change_(5.9, inf] +LogD_change_(-0.3, -0.1] +LogD_change_(-0.5, -0.3] +LogD_change_(-0.7, -0.5] +LogD_change_(-0.9, -0.7] +LogD_change_(-1.1, -0.9] +LogD_change_(-1.3, -1.1] +LogD_change_(-1.5, -1.3] +LogD_change_(-1.7, -1.5] +LogD_change_(-1.9, -1.7] +LogD_change_(-2.1, -1.9] +LogD_change_(-2.3, -2.1] +LogD_change_(-2.5, -2.3] +LogD_change_(-2.7, -2.5] +LogD_change_(-2.9, -2.7] +LogD_change_(-3.1, -2.9] +LogD_change_(-3.3, -3.1] +LogD_change_(-3.5, -3.3] +LogD_change_(-3.7, -3.5] +LogD_change_(-3.9, -3.7] +LogD_change_(-4.1, -3.9] +LogD_change_(-4.3, -4.1] +LogD_change_(-4.5, -4.3] +LogD_change_(-4.7, -4.5] +LogD_change_(-4.9, -4.7] +LogD_change_(-5.1, -4.9] +LogD_change_(-5.3, -5.1] +LogD_change_(-5.5, -5.3] +LogD_change_(-5.7, -5.5] +LogD_change_(-inf, -5.7] +Solubility_low->high +Solubility_high->low +Solubility_no_change +Clint_low->high +Clint_high->low +Clint_no_change + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +C +c +1 +( +- +2 +s +N += +) +n +Br +3 +O +[nH] +[C@H] +/ +\ +[C@@H] +[C@] +o +4 +[C@@] +5 +6 +7 +[n+] +. 
+[Br-] +S +I +F +Cl +[O-] +# +P +[Na+] +[N+] +[Cl-] +8 +[Si] +[N-] +[18F] +9 +[I-] +B +[S+] +[2H] +[P+] +[125I] +%10 +%11 +[n-] +[B-] +[O] +[o+] +[N@@+] +[N@+] +[PH] +[se] +[Se] +[s+] +[Li+] +[P@] +[P@@] +[3H] +[K+] +[OH-] +[te+] +[se+] +[Te] +[S@+] +[BH2-] +[11CH3] +[11C] +[Cl+3] +[N@@] +[S-] +[C+] +[P@+] +[C-] +[Zn+2] +[Ca+2] +[Mg+2] +[SeH] +[BH-] +[SH2] +[TeH2] +[SiH4] +[N@] +[14CH2] +[Ag] +[S@@+] +[I+] +[MgH2] +[125IH] +[Se+] +[As] +[SiH2] +[Ra] +[IH2] +[P-] +[Na] +[NH-] +[Cs] +[Zn] +[Li] +[As+] +[te] +[131I] +[Ag+] +p +[Al-3] +[Rb+] +[13CH] +[11CH] +[11CH2] +[14c] +[14C@@] +[127I] +[Mg] +[14C] +[123I] +[124I] +[F-] +[14CH3] +[135I] +%12 +[NH+] +[76Br] +[32PH] +[35S] +b +[73Se] +[11C@@H] +[Se-] +[c+] +[14C@H] +[18OH] +[SH] +[S@@] +[Cs+] +[He] +[O+] +[nH+] +[NH2+] +[32P] +[Zn+] +[BH3-] +[I+3] +[Si-] +[SH+] +[19F] +[Br+2] +[I+2] +[Al+3] +[123I-] +[131I-] +[127Xe] +[133Xe] +[89Sr+2] +[N] +[82Rb+] +[75Se] +[Rb] +[81Kr] +[18F-] +[13NH3] +[K] +[Cl+2] +[Zn-2] +[SeH2] +[AsH3] +[Kr] +[Xe] +[S@] +[cH-] +[NH4+] +[Al] +[Si@] +[15n] +[SH-] +[SiH] +[11c] +[OH] +[c-] +[18FH] +[123IH] +[13c] +[13cH] +[14cH] +[Cl+] +[Sr+2] +[CH-] +[Bi] +[B] +[Ba+2] +[Bi+3] +[SiH-] +[b-] +[H+] +[13C] +[OH+] +[14CH] +[15nH] +[CaH2] +[LiH] +[C] +[Ca] +[42K+] +[123Te] +[S-2] +[223Ra+2] +[S] +[Ra+2] +[22Na+] +[125I-] +[85Sr+2] +[PH2] +[SrH2] +[15OH2] +[47Ca+2] +[85SrH2] +[45Ca+2] +[B@@-] +[B@-] +[17F] +[PH2+] +[11C-] +[Mg+] +[NaH] +[P@@+] +[SiH3-] +[P-3] +[KH] +[Be+2] +[NH3+] +[Ag-4] +[18O] +[14C@@H] +[CH] +[CH2] +[O-2] +[124I-] +[As-] +[Ba] +[223Ra] +[82Rb] +[76BrH] +[AsH] +[131Cs] \ No newline at end of file diff --git a/MoleculeSTM/cuchemcommon/__init__.py b/MoleculeSTM/cuchemcommon/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MoleculeSTM/cuchemcommon/context.py b/MoleculeSTM/cuchemcommon/context.py new file mode 100644 index 0000000..f78eb19 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/context.py @@ -0,0 +1,53 @@ +import logging +import os +from configparser import RawConfigParser +from io import StringIO + +from MoleculeSTM.cuchemcommon.utils.singleton import Singleton + +logger = logging.getLogger(__name__) + +CONFIG_FILE = '.env' + + +class Context(metaclass=Singleton): + + def __init__(self): + + self.dask_client = None + self.compute_type = 'gpu' + self.is_benchmark = False + self.benchmark_file = None + self.cache_directory = None + self.n_molecule = None + self.batch_size = 10000 + + self.config = {} + if os.path.exists(CONFIG_FILE): + logger.info('Reading properties from %s...', CONFIG_FILE) + self.config = self._load_properties_file(CONFIG_FILE) + else: + logger.warn('Could not locate %s', CONFIG_FILE) + + def _load_properties_file(self, properties_file): + """ + Reads a properties file using ConfigParser. + + :param propertiesFile/configFile: + """ + config_file = open(properties_file, 'r') + config_content = StringIO('[root]\n' + config_file.read()) + config = RawConfigParser() + config.read_file(config_content) + + return config._sections['root'] + + def get_config(self, config_name, default=None): + """ + Returns values from local configuration. 
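+        Illustrative use (mirrors the call made in ChEmblData below; the
+        'data_mount_path' key is only present if defined in the local .env file):
+            db_path = Context().get_config('data_mount_path', default='/data')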
+ """ + try: + return self.config[config_name] + except KeyError: + logger.warn('%s not found, returing default.', config_name) + return default diff --git a/MoleculeSTM/cuchemcommon/data/__init__.py b/MoleculeSTM/cuchemcommon/data/__init__.py new file mode 100644 index 0000000..3a07d30 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/data/__init__.py @@ -0,0 +1,45 @@ +from typing import List + + +class ClusterWfDAO(object): + """ + Base class for all DAO for fetching data for Clustering Workflows + """ + + def meta_df(self): + """ + Returns df with dtype set for structure without any column filter. + """ + return NotImplemented + + def fetch_molecular_embedding(self, n_molecules: int, cache_directory: str = None): + """ + Fetch molecular properties from database/cache into a dask array. + """ + return NotImplemented + + def fetch_molecular_embedding_by_id(self, molecule_id: List): + """ + Fetch molecular properties from database for the given id. Id depends on + the backend databse. For chemble DB it should be molregid. + """ + return NotImplemented + + def fetch_id_from_smile(self, new_molecules: List): + """ + Fetch molecular details for a list of molecules. The values in the list + of molecules depends on database/service used. For e.g. it could be + ChemblId or molreg_id for Chemble database. + """ + return NotImplemented + + +class GenerativeWfDao(object): + + def fetch_id_from_chembl(self, id: List): + """ + Fetch molecular details for a list of molecules. The values in the list + of molecules depends on database/service used. For e.g. it could be + ChemblId or molreg_id for Chemble database. + """ + return NotImplemented diff --git a/MoleculeSTM/cuchemcommon/data/__pycache__/__init__.cpython-37.pyc b/MoleculeSTM/cuchemcommon/data/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..06d3caa Binary files /dev/null and b/MoleculeSTM/cuchemcommon/data/__pycache__/__init__.cpython-37.pyc differ diff --git a/MoleculeSTM/cuchemcommon/data/cluster_wf.py b/MoleculeSTM/cuchemcommon/data/cluster_wf.py new file mode 100644 index 0000000..6462d5f --- /dev/null +++ b/MoleculeSTM/cuchemcommon/data/cluster_wf.py @@ -0,0 +1,61 @@ +import logging +import math +import os +from typing import List + +import cudf +import dask +import dask_cudf +from cuchemcommon.context import Context +from cuchemcommon.data.helper.chembldata import BATCH_SIZE, ChEmblData +from cuchemcommon.utils.singleton import Singleton + +from . 
import ClusterWfDAO + +logger = logging.getLogger(__name__) + +FINGER_PRINT_FILES = 'filter_*.h5' + + +class ChemblClusterWfDao(ClusterWfDAO, metaclass=Singleton): + + def __init__(self, fp_type): + self.chem_data = ChEmblData(fp_type) + + def meta_df(self): + chem_data = ChEmblData() + return chem_data._meta_df() + + def fetch_molecular_embedding(self, + n_molecules: int, + cache_directory: str = None): + context = Context() + if cache_directory: + hdf_path = os.path.join(cache_directory, FINGER_PRINT_FILES) + logger.info('Reading %d rows from %s...', n_molecules, hdf_path) + mol_df = dask.dataframe.read_hdf(hdf_path, 'fingerprints') + + if n_molecules > 0: + npartitions = math.ceil(n_molecules / BATCH_SIZE) + mol_df = mol_df.head(n_molecules, compute=False, npartitions=npartitions) + else: + logger.info('Reading molecules from database...') + mol_df = self.chem_data.fetch_mol_embedding(num_recs=n_molecules, + batch_size=context.batch_size) + + return mol_df + + def fetch_molecular_embedding_by_id(self, molecule_id: List): + context = Context() + meta = self.chem_data._meta_df() + fp_df = self.chem_data._fetch_mol_embedding(molregnos=molecule_id, + batch_size=context.batch_size) \ + .astype(meta.dtypes) + + fp_df = cudf.from_pandas(fp_df) + fp_df = dask_cudf.from_cudf(fp_df, npartitions=1).reset_index() + return fp_df + + def fetch_id_from_chembl(self, new_molecules: List): + logger.debug('Fetch ChEMBL ID using molregno...') + return self.chem_data.fetch_id_from_chembl(new_molecules) diff --git a/MoleculeSTM/cuchemcommon/data/generative_wf.py b/MoleculeSTM/cuchemcommon/data/generative_wf.py new file mode 100644 index 0000000..9e16a2d --- /dev/null +++ b/MoleculeSTM/cuchemcommon/data/generative_wf.py @@ -0,0 +1,19 @@ +import logging +from typing import List + +from cuchemcommon.data.helper.chembldata import ChEmblData +from cuchemcommon.utils.singleton import Singleton + +from . 
import GenerativeWfDao + +logger = logging.getLogger(__name__) + + +class ChemblGenerativeWfDao(GenerativeWfDao, metaclass=Singleton): + + def __init__(self, fp_type): + self.chem_data = ChEmblData(fp_type) + + def fetch_id_from_chembl(self, id: List): + logger.debug('Fetch ChEMBL ID using molregno...') + return self.chem_data.fetch_id_from_chembl(id) diff --git a/MoleculeSTM/cuchemcommon/data/helper/__init__.py b/MoleculeSTM/cuchemcommon/data/helper/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MoleculeSTM/cuchemcommon/data/helper/chembldata.py b/MoleculeSTM/cuchemcommon/data/helper/chembldata.py new file mode 100644 index 0000000..7b0d272 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/data/helper/chembldata.py @@ -0,0 +1,320 @@ +import os +import warnings +import pandas +import sqlite3 +import logging + +from typing import List +from dask import delayed, dataframe + +from contextlib import closing +from cuchemcommon.utils.singleton import Singleton +from cuchemcommon.context import Context + +warnings.filterwarnings("ignore", message=r"deprecated", category=FutureWarning) +logger = logging.getLogger(__name__) + +BATCH_SIZE = 100000 +ADDITIONAL_FEILD = ['canonical_smiles', 'transformed_smiles'] +IMP_PROPS = [ + 'alogp', + 'aromatic_rings', + 'full_mwt', + 'psa', + 'rtb'] +IMP_PROPS_TYPE = [pandas.Series([], dtype='float64'), + pandas.Series([], dtype='int64'), + pandas.Series([], dtype='float64'), + pandas.Series([], dtype='float64'), + pandas.Series([], dtype='int64')] +ADDITIONAL_FEILD_TYPE = [pandas.Series([], dtype='object'), + pandas.Series([], dtype='object')] + +SQL_MOLECULAR_PROP = """ +SELECT md.molregno as molregno, md.chembl_id, cp.*, cs.* +FROM compound_properties cp, + compound_structures cs, + molecule_dictionary md +WHERE cp.molregno = md.molregno + AND md.molregno = cs.molregno + AND md.molregno in (%s) +""" + + +# DEPRECATED. Please add code to DAO classes. +class ChEmblData(object, metaclass=Singleton): + + def __init__(self, fp_type): + + context = Context() + db_file = context.get_config('data_mount_path', default='/data') + db_file = os.path.join(db_file, 'db/chembl_27.db') + + if not os.path.exists(db_file): + logger.error('%s not found', db_file) + raise Exception('{} not found'.format(db_file)) + + self.fp_type = fp_type + self.chembl_db = 'file:%s?mode=ro' % db_file + + logger.info('ChEMBL database: %s...' % self.chembl_db) + + def fetch_props_by_molregno(self, molregnos): + """ + Returns compound properties and structure filtered by ChEMBL IDs along + with a list of columns. + """ + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = SQL_MOLECULAR_PROP % " ,".join(list(map(str, molregnos))) + cur.execute(select_stmt) + + cols = list(map(lambda x: x[0], cur.description)) + return cols, cur.fetchall() + + def fetch_props_by_chemble(self, chemble_ids): + """ + Returns compound properties and structure filtered by ChEMBL IDs along + with a list of columns. 
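+        Returns a (column_names, rows) tuple from cur.fetchall(); illustrative use
+        (hypothetical ChEMBL identifier):
+            cols, rows = chem_data.fetch_props_by_chemble(['CHEMBL25'])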
+ """ + sql_stml = """ + SELECT md.molregno as molregno, md.chembl_id, cp.*, cs.* + FROM compound_properties cp, + compound_structures cs, + molecule_dictionary md + WHERE cp.molregno = md.molregno + AND md.molregno = cs.molregno + AND md.chembl_id in (%s) + """ + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = sql_stml % "'%s'" % "','".join([x.strip().upper() for x in chemble_ids]) + cur.execute(select_stmt) + + cols = list(map(lambda x: x[0], cur.description)) + return cols, cur.fetchall() + + def fetch_molregno_by_chemblId(self, chemblIds): + logger.debug('Fetch ChEMBL ID using molregno...') + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = ''' + SELECT md.molregno as molregno + FROM compound_properties cp, + compound_structures cs, + molecule_dictionary md + WHERE cp.molregno = md.molregno + AND md.molregno = cs.molregno + AND md.chembl_id in (%s) + ''' % "'%s'" % "','".join(chemblIds) + cur.execute(select_stmt) + return cur.fetchall() + + def fetch_id_from_chembl(self, new_molecules: List): + logger.debug('Fetch ChEMBL ID using molregno...') + + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = ''' + SELECT cs.molregno as molregno, md.chembl_id as chembl_id, + cs.canonical_smiles as smiles + FROM compound_structures cs, + molecule_dictionary md + WHERE md.molregno = cs.molregno + AND md.chembl_id in (%s) + ''' % "'%s'" % "','".join([x.strip().upper() for x in new_molecules]) + cur.execute(select_stmt) + + return cur.fetchall() + + def fetch_chemblId_by_molregno(self, molregnos): + logger.debug('Fetch ChEMBL ID using molregno...') + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = ''' + SELECT md.chembl_id as chembl_id + FROM molecule_dictionary md + WHERE md.molregno in (%s) + ''' % ", ".join(list(map(str, molregnos))) + cur.execute(select_stmt) + return cur.fetchall() + + def fetch_approved_drugs(self): + """Fetch approved drugs with phase >=3 as dataframe + + Args: + chembl_db_path (string): path to chembl sqlite database + Returns: + pd.DataFrame: dataframe containing SMILES strings and molecule index + """ + logger.debug('Fetching ChEMBL approved drugs...') + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = """SELECT + di.molregno, + cs.canonical_smiles, + di.max_phase_for_ind + FROM + drug_indication AS di + LEFT JOIN compound_structures AS cs ON di.molregno = cs.molregno + WHERE + di.max_phase_for_ind >= 3 + AND cs.canonical_smiles IS NOT NULL;""" + cur.execute(select_stmt) + return cur.fetchall() + + def fetch_random_samples(self, num_samples, max_len): + """Fetch random samples from ChEMBL as dataframe + + Args: + num_samples (int): number of samples to select + chembl_db_path (string): path to chembl sqlite database + Returns: + pd.DataFrame: dataframe containing SMILES strings and molecule index + """ + logger.debug('Fetching ChEMBL random samples...') + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = """SELECT + cs.molregno, + cs.canonical_smiles, + LENGTH(cs.canonical_smiles) as len + FROM + compound_structures AS cs + WHERE + cs.canonical_smiles IS NOT NULL + AND + len <= """ + f'{max_len}' + """ + ORDER BY RANDOM() + LIMIT """ + f'{num_samples};' + + 
cur.execute(select_stmt) + return cur.fetchall() + + def fetch_molecule_cnt(self): + logger.debug('Finding number of molecules...') + with closing(sqlite3.connect(self.chembl_db, uri=True)) as con, con, \ + closing(con.cursor()) as cur: + select_stmt = ''' + SELECT count(*) + FROM compound_properties cp, + molecule_dictionary md, + compound_structures cs + WHERE cp.molregno = md.molregno + AND md.molregno = cs.molregno + ''' + cur.execute(select_stmt) + + return cur.fetchone()[0] + + def _meta_df(self, **transformation_kwargs): + transformation = self.fp_type(**transformation_kwargs) + + prop_meta = {'id': pandas.Series([], dtype='int64')} + prop_meta.update(dict(zip(IMP_PROPS + ADDITIONAL_FEILD, + IMP_PROPS_TYPE + ADDITIONAL_FEILD_TYPE))) + prop_meta.update({i: pandas.Series([], dtype='float32') for i in range(len(transformation))}) + + return pandas.DataFrame(prop_meta) + + def _fetch_mol_embedding(self, + start=0, + batch_size=BATCH_SIZE, + molregnos=None, + **transformation_kwargs): + """ + Returns compound properties and structure for the first N number of + records in a dataframe. + """ + + logger.info('Fetching %d records starting %d...' % (batch_size, start)) + + imp_cols = ['cp.' + col for col in IMP_PROPS] + + if molregnos is None: + select_stmt = ''' + SELECT md.molregno, %s, cs.canonical_smiles + FROM compound_properties cp, + molecule_dictionary md, + compound_structures cs + WHERE cp.molregno = md.molregno + AND md.molregno = cs.molregno + LIMIT %d, %d + ''' % (', '.join(imp_cols), start, batch_size) + else: + select_stmt = ''' + SELECT md.molregno, %s, cs.canonical_smiles + FROM compound_properties cp, + molecule_dictionary md, + compound_structures cs + WHERE cp.molregno = md.molregno + AND md.molregno = cs.molregno + AND md.molregno in (%s) + LIMIT %d, %d + ''' % (', '.join(imp_cols), " ,".join(list(map(str, molregnos))), start, batch_size) + + df = pandas.read_sql(select_stmt, + sqlite3.connect(self.chembl_db, uri=True)) + + # Smiles -> Smiles transformation and filtering + # TODO: Discuss internally to find use or refactor this code to remove + # model specific filtering + df['transformed_smiles'] = df['canonical_smiles'] + # if smiles_transforms is not None: + # if len(smiles_transforms) > 0: + # for xf in smiles_transforms: + # df['transformed_smiles'] = df['transformed_smiles'].map(xf.transform) + # df.dropna(subset=['transformed_smiles'], axis=0, inplace=True) + + # Conversion to fingerprints or embeddings + # transformed_smiles = df['transformed_smiles'] + transformation = self.fp_type(**transformation_kwargs) + cache_data = transformation.transform(df) + return_df = pandas.DataFrame(cache_data) + + return_df = pandas.DataFrame( + return_df, + columns=pandas.RangeIndex(start=0, + stop=len(transformation))).astype('float32') + + return_df = df.merge(return_df, left_index=True, right_index=True) + return_df.rename(columns={'molregno': 'id'}, inplace=True) + return return_df + + def fetch_mol_embedding(self, + num_recs=None, + batch_size=BATCH_SIZE, + molregnos=None, + **transformation_kwargs): + """ + Returns compound properties and structure for the first N number of + records in a dataframe. 
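+        The fetch is split into batch_size chunks, each wrapped in dask.delayed, and the
+        pieces are assembled lazily with dataframe.from_delayed. Illustrative use
+        (materializes the result; assumes a configured ChEMBL database):
+            mol_df = chem_data.fetch_mol_embedding(num_recs=10000, batch_size=5000).compute()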
+ """ + logger.debug('Fetching properties for all molecules...') + + if num_recs is None or num_recs < 0: + num_recs = self.fetch_molecule_cnt() + + logger.info('num_recs %d', num_recs) + logger.info('batch_size %d', batch_size) + meta_df = self._meta_df(**transformation_kwargs) + + dls = [] + for start in range(0, num_recs, batch_size): + bsize = min(num_recs - start, batch_size) + dl_data = delayed(self._fetch_mol_embedding)(start=start, + batch_size=bsize, + molregnos=molregnos, + **transformation_kwargs) + dls.append(dl_data) + + return dataframe.from_delayed(dls, meta=meta_df) + + def save_fingerprints(self, hdf_path='data/filter_*.h5', num_recs=None, batch_size=5000): + """ + Generates fingerprints for all ChEMBL ID's in the database + """ + logger.debug('Fetching molecules from database for fingerprints...') + + mol_df = self.fetch_mol_embedding(num_recs=num_recs, batch_size=batch_size) + mol_df.to_hdf(hdf_path, 'fingerprints') diff --git a/MoleculeSTM/cuchemcommon/fingerprint.py b/MoleculeSTM/cuchemcommon/fingerprint.py new file mode 100644 index 0000000..55f2471 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/fingerprint.py @@ -0,0 +1,95 @@ +import logging +import os +from abc import ABC +from enum import Enum + +import numpy as np +import pandas as pd +from cddd.inference import InferenceModel +from cuchem.utils.data_peddler import download_cddd_models +from rdkit import Chem +from rdkit.Chem import AllChem + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +logger = logging.getLogger(__name__) + + +def calc_morgan_fingerprints(dataframe, smiles_col='canonical_smiles'): + """Calculate Morgan fingerprints on SMILES strings + + Args: + dataframe (pd.DataFrame): dataframe containing a SMILES column for calculation + + Returns: + pd.DataFrame: new dataframe containing fingerprints + """ + mf = MorganFingerprint() + fp = mf.transform(dataframe, col_name=smiles_col) + fp = pd.DataFrame(fp) + fp.index = dataframe.index + return fp + + +class TransformationDefaults(Enum): + MorganFingerprint = {'radius': 2, 'nBits': 512} + Embeddings = {} + + +class BaseTransformation(ABC): + def __init__(self, **kwargs): + self.name = None + self.kwargs = None + self.func = None + + def transform(self, data): + return NotImplemented + + def transform_many(self, data): + return list(map(self.transform, data)) + + def __len__(self): + return NotImplemented + + +class MorganFingerprint(BaseTransformation): + + def __init__(self, **kwargs): + self.name = __class__.__name__.split('.')[-1] + self.kwargs = TransformationDefaults[self.name].value + self.kwargs.update(kwargs) + self.func = AllChem.GetMorganFingerprintAsBitVect + + def transform(self, data, col_name='transformed_smiles'): + data = data[col_name] + fp_array = [] + for mol in data: + m = Chem.MolFromSmiles(mol) + fp = self.func(m, **self.kwargs) + fp_array.append(list(fp.ToBitString())) + fp_array = np.asarray(fp_array) + return fp_array + + def __len__(self): + return self.kwargs['nBits'] + + +class Embeddings(BaseTransformation): + + def __init__(self, use_gpu=True, cpu_threads=5, model_dir=None, **kwargs): + self.name = __class__.__name__.split('.')[-1] + self.kwargs = TransformationDefaults[self.name].value + self.kwargs.update(kwargs) + model_dir = download_cddd_models() + self.func = InferenceModel(model_dir, use_gpu=use_gpu, cpu_threads=cpu_threads) + + def transform(self, data): + data = data['transformed_smiles'] + return self.func.seq_to_emb(data).squeeze() + + def inverse_transform(self, embeddings): + "Embedding array -- individual compound 
embeddings are in rows" + embeddings = np.asarray(embeddings) + return self.func.emb_to_seq(embeddings) + + def __len__(self): + return self.func.hparams.emb_size diff --git a/MoleculeSTM/cuchemcommon/smiles.py b/MoleculeSTM/cuchemcommon/smiles.py new file mode 100644 index 0000000..5034fa1 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/smiles.py @@ -0,0 +1,38 @@ +# import os +# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +# import logging +# from abc import ABC +# from rdkit.Chem.SaltRemover import SaltRemover +# from cddd.preprocessing import remove_salt_stereo, filter_smiles + +# logger = logging.getLogger(__name__) + + +# class BaseTransformation(ABC): +# def __init__(self): +# pass + +# def transform(self, data): +# return NotImplemented + +# def transform_many(self, data): +# return list(map(self.transform, data)) +# #return [self.filter(x) for x in data] + + +# class RemoveSalt(BaseTransformation): +# def __init__(self, remover=SaltRemover()): +# self.name = __class__.__name__.split('.')[-1] +# self.remover = remover + +# def transform(self, data): +# return remove_salt_stereo(data, self.remover) + + +# class PreprocessSmiles(BaseTransformation): +# def __init__(self): +# self.name = __class__.__name__.split('.')[-1] + +# def transform(self, data): +# return filter_smiles(data) diff --git a/MoleculeSTM/cuchemcommon/utils/__init__.py b/MoleculeSTM/cuchemcommon/utils/__init__.py new file mode 100644 index 0000000..0de2d94 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/utils/__init__.py @@ -0,0 +1 @@ +from cuchemcommon.utils.singleton import Singleton \ No newline at end of file diff --git a/MoleculeSTM/cuchemcommon/utils/logger.py b/MoleculeSTM/cuchemcommon/utils/logger.py new file mode 100644 index 0000000..7f9e669 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/utils/logger.py @@ -0,0 +1,106 @@ +#!/opt/conda/envs/rapids/bin/python3 +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +from datetime import datetime + +from cuchemcommon.context import Context + +from .sysinfo import get_machine_config, print_machine_config + +BENCHMARK_FILE = '/data/benchmark.csv' + +logger = logging.getLogger(__name__) + + +def initialize_logfile(benchmark_file=BENCHMARK_FILE): + """Initialize benchmark file with header if needed""" + + config = get_machine_config() + config_message = print_machine_config(config) + + if not os.path.exists(benchmark_file): + with open(benchmark_file, 'w') as fh: + fh.write(f'# {config_message}\n') + fh.write('date,benchmark_type,step,time(hh:mm:ss.ms),n_molecules,n_workers,metric_name,metric_value\n') + return benchmark_file + + +class MetricsLogger(object): + + def __init__(self, + task_name, + n_molecules): + + self.task_name = task_name + self.n_molecules = n_molecules + self.start_time = None + self.metric_name = None + self.metric_value = None + + self.metric_func = None + self.metric_func_args = None + self.metric_func_kwargs = {} + + def __enter__(self): + self.start_time = datetime.now() + + return self + + def __exit__(self, type, value, traceback): + context = Context() + + runtime = datetime.now() - self.start_time + logger.info('### Runtime {} time (hh:mm:ss.ms) {}'.format(self.task_name, runtime)) + n_workers = len(context.dask_client.cluster.workers) + + if self.metric_func and context.is_benchmark: + self.metric_value = self.metric_func(*self.metric_func_args, + **self.metric_func_kwargs) + + if self.metric_value is None: + self.metric_name = '' + self.metric_value = '' + else: + logger.info('Calculated {} is {}'.format(self.metric_name, self.metric_value)) + + log_results(self.start_time, context.compute_type, self.task_name, + runtime, + n_molecules=self.n_molecules, + n_workers=n_workers, + metric_name=self.metric_name, + metric_value=self.metric_value, + benchmark_file=context.benchmark_file) + + +def log_results(date, + benchmark_type, + step, + time, + n_molecules, + n_workers, + metric_name='', + metric_value='', + benchmark_file=BENCHMARK_FILE): + """Log benchmark results to a file""" + + out_list = [date, benchmark_type, step, time, n_molecules, n_workers, metric_name, metric_value] + out_fmt = ','.join(['{}'] * len(out_list)) + '\n' + + with open(benchmark_file, 'a') as fh: + out_string = out_fmt.format(*out_list) + fh.write(out_string) diff --git a/MoleculeSTM/cuchemcommon/utils/singleton.py b/MoleculeSTM/cuchemcommon/utils/singleton.py new file mode 100644 index 0000000..fc28938 --- /dev/null +++ b/MoleculeSTM/cuchemcommon/utils/singleton.py @@ -0,0 +1,26 @@ +# singleton.py + +import logging + +""" +Metaclass for singletons. +""" + +logger = logging.getLogger(__name__) + + +class Singleton(type): + """ + Ensures single instance of a class. + + Example Usage: + class MySingleton(metaclass=Singleton) + pass + """ + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__( + *args, **kwargs) + return cls._instances[cls] diff --git a/MoleculeSTM/cuchemcommon/utils/sysinfo.py b/MoleculeSTM/cuchemcommon/utils/sysinfo.py new file mode 100644 index 0000000..1077c5b --- /dev/null +++ b/MoleculeSTM/cuchemcommon/utils/sysinfo.py @@ -0,0 +1,68 @@ +#!/opt/conda/envs/rapids/bin/python3 +# +# Copyright (c) 2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import Counter + +import psutil +import pynvml as nv + + +def get_machine_config(): + """Get machine config for CPU and GPU(s)""" + + # CPU config + physical_cores = psutil.cpu_count(logical=False) + logical_cores = psutil.cpu_count(logical=True) + + cpufreq = psutil.cpu_freq() + cpufreq_max = cpufreq.max # Mhz + cpufreq_min = cpufreq.min + cpufreq_cur = cpufreq.current + + svmem = psutil.virtual_memory() + mem_total = svmem.total / (1024.0 ** 3) # GB + mem_avail = svmem.available / (1024.0 ** 3) + + # GPU config + nv.nvmlInit() + driver_version = nv.nvmlSystemGetDriverVersion() + deviceCount = nv.nvmlDeviceGetCount() + gpu_devices, gpu_mems = [], [] + for i in range(deviceCount): + handle = nv.nvmlDeviceGetHandleByIndex(i) + gpu_devices.append(nv.nvmlDeviceGetName(handle).decode("utf-8")) + gpu_mem = nv.nvmlDeviceGetMemoryInfo(handle).total / (1024.0 ** 3) + gpu_mems.append(gpu_mem) + + return {'cpu': {'physical_cores': physical_cores, 'logical_cores': logical_cores, + 'min_freq_MHz': cpufreq_min, 'max_freq_MHz': cpufreq_max, 'cur_freq_MHz': cpufreq_cur, + 'total_mem_GB': mem_total, 'avail_mem_GB': mem_avail}, + 'gpu': {'devices': gpu_devices, 'mem_GB': gpu_mems}} + + +def print_machine_config(config): + """Printable version of config""" + cpu_cores = config['cpu']['physical_cores'] + cpu_freq = int(round(config['cpu']['max_freq_MHz'], 0)) + ram = int(round(config['cpu']['total_mem_GB'], 0)) + cpu_config_message = f'{cpu_freq} MHz CPU with {cpu_cores} cores, {ram} GB RAM' + + gpu_devices = Counter([(x, int(round(y, 0))) for x, y in zip(config['gpu']['devices'], config['gpu']['mem_GB'])]) + gpu_config_message = '' + for (handle, mem), count in gpu_devices.items(): + gpu_config_message += f'{count} x {handle} GPU(s)' + + return ', '.join([cpu_config_message, gpu_config_message]) diff --git a/MoleculeSTM/cuchemcommon/workflow.py b/MoleculeSTM/cuchemcommon/workflow.py new file mode 100644 index 0000000..5fe283d --- /dev/null +++ b/MoleculeSTM/cuchemcommon/workflow.py @@ -0,0 +1,201 @@ +import logging +# import torch +from functools import singledispatch +from typing import List + +import numpy as np +from MoleculeSTM.cuchemcommon.data import GenerativeWfDao +from rdkit.Chem import PandasTools, CanonSmiles + +logger = logging.getLogger(__name__) + + +@singledispatch +def add_jitter(embedding, radius, cnt, shape): + return NotImplemented + + +@add_jitter.register(np.ndarray) +def _(embedding, radius, cnt, shape): + + distorteds = [] + for i in range(cnt): + noise = np.random.normal(0, radius, embedding.shape) + distorted = noise + embedding + distorteds.append(distorted) + + return distorteds + + +class BaseGenerativeWorkflow: + + def __init__(self, dao: GenerativeWfDao = None) -> None: + self.dao = dao + self.min_jitter_radius = None + + def get_iteration(self): + NotImplemented + + def smiles_to_embedding(self, + smiles: str, + padding: int): + NotImplemented + + def embedding_to_smiles(self, + embedding: float, + dim: int, + pad_mask): + NotImplemented + + def interpolate_smiles(self, + smiles: List, + num_points: int = 10, + scaled_radius=None, + 
force_unique=False):
+        NotImplemented
+
+    def find_similars_smiles_list(self,
+                                  smiles: str,
+                                  num_requested: int = 10,
+                                  scaled_radius=None,
+                                  force_unique=False):
+        NotImplemented
+
+    def find_similars_smiles(self,
+                             smiles: str,
+                             num_requested: int = 10,
+                             scaled_radius=None,
+                             force_unique=False):
+        NotImplemented
+
+    def _compute_radius(self, scaled_radius):
+        if scaled_radius:
+            return float(scaled_radius * self.min_jitter_radius)
+        else:
+            return self.min_jitter_radius
+
+    def addjitter(self,
+                  embedding,
+                  radius=None,
+                  cnt=1,
+                  shape=None):
+        # Fall back to min_jitter_radius, the attribute defined in __init__
+        # (this class has no radius_scale attribute).
+        radius = radius if radius else self.min_jitter_radius
+        return add_jitter(embedding, radius, cnt, shape)
+
+    def compute_unique_smiles(self,
+                              interp_df,
+                              embedding_funct,
+                              scaled_radius=None):
+        """
+        Identify duplicate SMILES and distort their embeddings. The input df
+        must have columns 'SMILES' and 'Generated' at 0th and 1st position.
+        The 'Generated' column must contain booleans that classify SMILES into
+        input SMILES (False) and generated SMILES (True).
+
+        This function does not make any assumptions about the order of the
+        embeddings. Instead it simply orders the df by SMILES to identify
+        the duplicates.
+        """
+
+        distance = self._compute_radius(scaled_radius)
+        embeddings = interp_df['embeddings']
+        embeddings_dim = interp_df['embeddings_dim']
+        for index, row in interp_df.iterrows():
+            smile_string = row['SMILES']
+            try:
+                canonical_smile = CanonSmiles(smile_string)
+            except Exception:
+                # If a SMILES cannot be canonicalized, just use the original
+                canonical_smile = smile_string
+
+            row['SMILES'] = canonical_smile
+
+        for i in range(5):
+            smiles = interp_df['SMILES'].sort_values()
+            duplicates = set()
+            for idx in range(0, smiles.shape[0] - 1):
+                if smiles.iat[idx] == smiles.iat[idx + 1]:
+                    duplicates.add(smiles.index[idx])
+                    duplicates.add(smiles.index[idx + 1])
+
+            if len(duplicates) > 0:
+                for dup_idx in duplicates:
+                    if interp_df.iat[dup_idx, 3]:
+                        # add jitter to generated molecules only
+                        distorted = self.addjitter(embeddings[dup_idx],
+                                                   distance,
+                                                   cnt=1,
+                                                   shape=embeddings_dim[dup_idx])
+                        embeddings[dup_idx] = distorted[0]
+                interp_df['SMILES'] = embedding_funct(embeddings.to_list())
+                interp_df['embeddings'] = embeddings
+            else:
+                break
+
+        # Ensure all generated molecules are valid.
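+        # Up to 5 passes: decode SMILES from the current embeddings, flag rows
+        # whose SMILES RDKit cannot parse (null 'ROMol'), re-jitter only those
+        # embeddings, and decode again.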
+ for i in range(5): + PandasTools.AddMoleculeColumnToFrame(interp_df, 'SMILES') + invalid_mol_df = interp_df[interp_df['ROMol'].isnull()] + + if not invalid_mol_df.empty: + invalid_index = invalid_mol_df.index.to_list() + for idx in invalid_index: + embeddings[idx] = self.addjitter(embeddings[idx], + distance, + cnt=1, + shape=embeddings_dim[idx])[0] + interp_df['SMILES'] = embedding_funct(embeddings.to_list()) + interp_df['embeddings'] = embeddings + else: + break + + # Cleanup + if 'ROMol' in interp_df.columns: + interp_df = interp_df.drop('ROMol', axis=1) + + return interp_df + + def interpolate_by_id(self, + ids: List, + id_type: str = 'chembleid', + num_points=10, + force_unique=False, + scaled_radius: int = 1): + smiles = None + + if not self.min_jitter_radius: + raise Exception('Property `radius_scale` must be defined in model class.') + + if id_type.lower() == 'chembleid': + smiles = [row[2] for row in self.dao.fetch_id_from_chembl(ids)] + if len(smiles) != len(ids): + raise Exception('One of the ids is invalid %s', ids) + else: + raise Exception('id type %s not supported' % id_type) + + return self.interpolate_smiles(smiles, + num_points=num_points, + scaled_radius=scaled_radius, + force_unique=force_unique) + + def find_similars_smiles_by_id(self, + chemble_id: str, + id_type: str = 'chembleid', + num_requested=10, + force_unique=False, + scaled_radius: int = 1): + smiles = None + + if not self.min_jitter_radius: + raise Exception('Property `radius_scale` must be defined in model class.') + + if id_type.lower() == 'chembleid': + smiles = [row[2] for row in self.dao.fetch_id_from_chembl(chemble_id)] + if len(smiles) != len(chemble_id): + raise Exception('One of the ids is invalid %s' + chemble_id) + else: + raise Exception('id type %s not supported' % id_type) + + return self.find_similars_smiles(smiles[0], + num_requested=num_requested, + scaled_radius=scaled_radius, + force_unique=force_unique) diff --git a/MoleculeSTM/datasets/DrugBankGraph.py b/MoleculeSTM/datasets/DrugBankGraph.py new file mode 100644 index 0000000..c82f00e --- /dev/null +++ b/MoleculeSTM/datasets/DrugBankGraph.py @@ -0,0 +1,235 @@ +import os +from itertools import chain, repeat +import pandas as pd +import torch +from torch_geometric.data import InMemoryDataset, Data +from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple +from rdkit.Chem import AllChem + + +class DrugBank_Datasets_Graph_retrieval(InMemoryDataset): + def __init__( + self, root, train_mode, neg_sample_size, processed_dir_prefix, template="raw/SMILES_description_{}.txt", + transform=None, pre_transform=None, pre_filter=None, empty=False + ): + self.root = root + self.transform = transform + self.pre_filter = pre_filter + self.pre_transform = pre_transform + self.processed_dir_prefix = processed_dir_prefix + self.template = template + self.train_mode = train_mode + self.smiles_text_file_name = "SMILES.csv" + + super(DrugBank_Datasets_Graph_retrieval, self).__init__(root, transform, pre_transform, pre_filter) + + if not empty: + self.data, self.slices = torch.load(self.processed_paths[0]) + print('Data: {}'.format(self.data)) + + df = pd.read_csv(os.path.join(self.processed_dir, self.smiles_text_file_name)) + print(df.columns) + self.text_list = df["text"].tolist() + + # sampling + self.neg_sample_size = neg_sample_size + negative_sampled_index_file = os.path.join(self.root, "index", template.format(train_mode)) + print("Loading negative samples from {}".format(negative_sampled_index_file)) + f = open(negative_sampled_index_file, 
'r') + neg_index_list = [] + for line in f.readlines(): + line = line.strip().split(",") + line = [int(x) for x in line] + neg_index_list.append(line) + self.neg_index_list = neg_index_list + + return + + def get_graph(self, index): + data = Data() + for key in self.data.keys: + item, slices = self.data[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[index], slices[index + 1]) + data[key] = item[s] + return data + + def get(self, index): + text = self.text_list[index] + data = self.get_graph(index) + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_text = [self.text_list[idx] for idx in neg_index_list] + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_data = [self.get_graph(idx) for idx in neg_index_list] + return text, data, neg_text, neg_data + + @property + def raw_file_names(self): + file_name_list = os.listdir(self.raw_dir) + return file_name_list + + @property + def processed_dir(self): + return os.path.join(self.root, 'processed', '{}_{}'.format(self.processed_dir_prefix, self.train_mode)) + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def download(self): + return + + def process(self): + data_list, SMILES_list, text_list = [], [], [] + SMILES2description_file = os.path.join(self.root, 'raw', self.template.format(self.train_mode)) + f = open(SMILES2description_file, 'r') + + for line_id, line in enumerate(f.readlines()): + line = line.strip().split("\t", 1) + SMILES = line[0] + text = line[1] + + rdkit_mol = AllChem.MolFromSmiles(SMILES) + data = mol_to_graph_data_obj_simple(rdkit_mol) + data.id = torch.tensor([line_id]) + + data_list.append(data) + SMILES_list.append(SMILES) + text_list.append(text) + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + df = pd.DataFrame( + {"text": text_list, "smiles": SMILES_list}, + ) + saver_path = os.path.join(self.processed_dir, self.smiles_text_file_name) + print("saving to {}".format(saver_path)) + df.to_csv(saver_path, index=False) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print("saving to {}".format(self.processed_paths[0])) + print() + return + + def __len__(self): + return len(self.text_list) + + +class DrugBank_Datasets_Graph_ATC(InMemoryDataset): + def __init__( + self, root, file_name, processed_dir_prefix, neg_sample_size, prompt_template="{}.", + transform=None, pre_transform=None, pre_filter=None, empty=False + ): + self.root = root + self.transform = transform + self.pre_filter = pre_filter + self.pre_transform = pre_transform + self.file_name = file_name + self.processed_dir_prefix = processed_dir_prefix + self.smiles_text_file_name = "SMILES.csv" + self.prompt_template = prompt_template + + super(DrugBank_Datasets_Graph_ATC, self).__init__(root, transform, pre_transform, pre_filter) + + if not empty: + self.data, self.slices = torch.load(self.processed_paths[0]) + print('Data: {}'.format(self.data)) + + df = pd.read_csv(os.path.join(self.processed_dir, self.smiles_text_file_name)) + self.SMILES_list = df["smiles"].tolist() + self.ATC_code_list = df["ATC_code"].tolist() + ATC_label_list = df["ATC_label"].tolist() # This is for raw TAC label + self.ATC_label_list = [self.prompt_template.format(x) for x in ATC_label_list] + + self.neg_sample_size = neg_sample_size + 
negative_sampled_index_file = os.path.join(self.root, "index", file_name) + print("Loading negative samples from {}".format(negative_sampled_index_file)) + f = open(negative_sampled_index_file, 'r') + neg_index_list = [] + for line in f.readlines(): + line = line.strip().split(",") + line = [int(x) for x in line] + neg_index_list.append(line) + self.neg_index_list = neg_index_list + + assert len(self.SMILES_list) == len(self.neg_index_list) == len(self.ATC_code_list) == len(self.ATC_label_list) + return + + def get_graph(self, index): + data = Data() + for key in self.data.keys: + item, slices = self.data[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[index], slices[index + 1]) + data[key] = item[s] + return data + + def get(self, index): + text = self.ATC_label_list[index] + data = self.get_graph(index) + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_text = [self.ATC_label_list[idx] for idx in neg_index_list] + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_data = [self.get_graph(idx) for idx in neg_index_list] + return text, data, neg_text, neg_data + + @property + def raw_file_names(self): + file_name_list = os.listdir(self.raw_dir) + return file_name_list + + @property + def processed_dir(self): + return os.path.join(self.root, "processed", "molecule_{}".format(self.processed_dir_prefix)) + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def download(self): + return + + def process(self): + SMILES2ATC_txt_file = os.path.join(self.root, "raw", self.file_name) + + f = open(SMILES2ATC_txt_file, 'r') + data_list, SMILES_list, ATC_code_list, ATC_label_list = [], [], [], [] + for line_idx, line in enumerate(f.readlines()): + line = line.strip().split("\t") + SMILES = line[0] + ATC_code = line[1] + ATC_label = line[2] + rdkit_mol = AllChem.MolFromSmiles(SMILES) + data = mol_to_graph_data_obj_simple(rdkit_mol) + data.id = torch.tensor([line_idx]) + + data_list.append(data) + SMILES_list.append(SMILES) + ATC_code_list.append(ATC_code) + ATC_label_list.append(ATC_label) + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + df = pd.DataFrame( + {"smiles": SMILES_list, "ATC_code": ATC_code_list, "ATC_label": ATC_label_list}, + ) + saver_path = os.path.join(self.processed_dir, self.smiles_text_file_name) + print("saving to {}".format(saver_path)) + df.to_csv(saver_path, index=False) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print("saving to {}".format(self.processed_paths[0])) + return + + def __len__(self): + return len(self.SMILES_list) diff --git a/MoleculeSTM/datasets/DrugBankSMILES.py b/MoleculeSTM/datasets/DrugBankSMILES.py new file mode 100644 index 0000000..1307920 --- /dev/null +++ b/MoleculeSTM/datasets/DrugBankSMILES.py @@ -0,0 +1,94 @@ +import os +from torch.utils.data import Dataset + + +class DrugBank_Datasets_SMILES_retrieval(Dataset): + def __init__(self, root, train_mode, neg_sample_size, template="SMILES_description_{}.txt"): + self.root = root + + self.SMILES_list, self.text_list = [], [] + SMILES2description_file = os.path.join(self.root, "raw", template.format(train_mode)) + f = open(SMILES2description_file, 'r') + for line in f.readlines(): + line = line.strip().split("\t", 1) + SMILES = line[0] + text = 
line[1] + self.SMILES_list.append(SMILES) + self.text_list.append(text) + + self.neg_sample_size = neg_sample_size + negative_sampled_index_file = os.path.join(self.root, "index", template.format(train_mode)) + print("Loading negative samples from {}".format(negative_sampled_index_file)) + f = open(negative_sampled_index_file, 'r') + neg_index_list = [] + for line in f.readlines(): + line = line.strip().split(",") + line = [int(x) for x in line] + neg_index_list.append(line) + self.neg_index_list = neg_index_list + return + + def __getitem__(self, index): + description = self.text_list[index] + SMILES = self.SMILES_list[index] + + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_description = [self.text_list[idx] for idx in neg_index_list] + + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_SMILES = [self.SMILES_list[idx] for idx in neg_index_list] + + return description, SMILES, neg_description, neg_SMILES + + def __len__(self): + return len(self.SMILES_list) + + +class DrugBank_Datasets_SMILES_ATC(Dataset): + def __init__(self, root, file_name, neg_sample_size, prompt_template="{}."): + self.root = root + self.neg_sample_size = neg_sample_size + self.prompt_template = prompt_template + + SMILES2ATC_txt_file = os.path.join(self.root, 'raw', file_name) + + f = open(SMILES2ATC_txt_file, 'r') + SMILES_list, ATC_code_list, ATC_label_list = [], [], [] + for line in f.readlines(): + line = line.strip().split("\t") + SMILES_list.append(line[0]) + ATC_code_list.append(line[1]) + ATC_label_list.append(prompt_template.format(line[2])) + + self.SMILES_list = SMILES_list + self.ATC_code_list = ATC_code_list + self.ATC_label_list = ATC_label_list + + self.neg_sample_size = neg_sample_size + negative_sampled_index_file = os.path.join(self.root, "index", file_name) + print("Loading negative samples from {}".format(negative_sampled_index_file)) + f = open(negative_sampled_index_file, 'r') + neg_index_list = [] + for line in f.readlines(): + line = line.strip().split(",") + line = [int(x) for x in line] + neg_index_list.append(line) + self.neg_index_list = neg_index_list + + assert len(self.SMILES_list) == len(self.neg_index_list) == len(ATC_code_list) == len(ATC_label_list) + return + + def __getitem__(self, index): + text = self.ATC_label_list[index] + SMILES = self.SMILES_list[index] + + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_text = [self.ATC_label_list[idx] for idx in neg_index_list] + + neg_index_list = self.neg_index_list[index][:self.neg_sample_size] + neg_SMILES = [self.SMILES_list[idx] for idx in neg_index_list] + + return text, SMILES, neg_text, neg_SMILES + + def __len__(self): + return len(self.SMILES_list) \ No newline at end of file diff --git a/MoleculeSTM/datasets/MoleculeNetGraph.py b/MoleculeSTM/datasets/MoleculeNetGraph.py new file mode 100644 index 0000000..4392598 --- /dev/null +++ b/MoleculeSTM/datasets/MoleculeNetGraph.py @@ -0,0 +1,584 @@ +import os +import pickle +from itertools import chain, repeat + +import networkx as nx +import numpy as np +import pandas as pd +import torch +from ogb.utils.features import atom_to_feature_vector, bond_to_feature_vector +from rdkit import Chem +from rdkit.Chem import AllChem, Descriptors +from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect +from torch.utils import data +from torch_geometric.data import (Data, InMemoryDataset, download_url, extract_zip) + + +def mol_to_graph_data_obj_simple(mol): + """ used in MoleculeNetGraphDataset() class + 
Converts rdkit mol objects to graph data object in pytorch geometric + NB: Uses simplified atom and bond features, and represent as indices + :param mol: rdkit mol object + :return: graph data object with the attributes: x, edge_index, edge_attr """ + + # atoms + # num_atom_features = 2 # atom type, chirality tag + atom_features_list = [] + for atom in mol.GetAtoms(): + atom_feature = atom_to_feature_vector(atom) + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), dtype=torch.long) + + # bonds + if len(mol.GetBonds()) <= 0: # mol has no bonds + num_bond_features = 3 # bond type & direction + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) + else: # mol has bonds + edges_list = [] + edge_features_list = [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + edge_feature = bond_to_feature_vector(bond) + + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + + # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] + edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) + + # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] + edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + return data + + +def graph_data_obj_to_nx_simple(data): + """ torch geometric -> networkx + NB: possible issues with recapitulating relative + stereochemistry since the edges in the nx object are unordered. + :param data: pytorch geometric Data object + :return: networkx object """ + G = nx.Graph() + + # atoms + atom_features = data.x.cpu().numpy() + num_atoms = atom_features.shape[0] + for i in range(num_atoms): + temp_feature = atom_features[i] + G.add_node( + i, + x0=temp_feature[0], + x1=temp_feature[1], + x2=temp_feature[2], + x3=temp_feature[3], + x4=temp_feature[4], + x5=temp_feature[5], + x6=temp_feature[6], + x7=temp_feature[7], + x8=temp_feature[8]) + pass + + # bonds + edge_index = data.edge_index.cpu().numpy() + edge_attr = data.edge_attr.cpu().numpy() + num_bonds = edge_index.shape[1] + for j in range(0, num_bonds, 2): + begin_idx = int(edge_index[0, j]) + end_idx = int(edge_index[1, j]) + temp_feature= edge_attr[j] + if not G.has_edge(begin_idx, end_idx): + G.add_edge(begin_idx, end_idx, + e0=temp_feature[0], + e1=temp_feature[1], + e2=temp_feature[2]) + + return G + + +def nx_to_graph_data_obj_simple(G): + """ vice versa of graph_data_obj_to_nx_simple() + Assume node indices are numbered from 0 to num_nodes - 1. + NB: Uses simplified atom and bond features, and represent as indices. + NB: possible issues with recapitulating relative stereochemistry + since the edges in the nx object are unordered. 
""" + + # atoms + # num_atom_features = 2 # atom type, chirality tag + atom_features_list = [] + for _, node in G.nodes(data=True): + atom_feature = [node['x0'], node['x1'], node['x2'], node['x3'], node['x4'], node['x5'], node['x6'], node['x7'], node['x8']] + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), dtype=torch.long) + + # bonds + num_bond_features = 3 # bond type, bond direction + if len(G.edges()) > 0: # mol has bonds + edges_list = [] + edge_features_list = [] + for i, j, edge in G.edges(data=True): + edge_feature = [edge['e0'], edge['e1'], edge['e2']] + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + + # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] + edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) + + # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] + edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) + else: # mol has no bonds + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + return data + + +def create_standardized_mol_id(smiles): + """ smiles -> inchi """ + + if check_smiles_validity(smiles): + # remove stereochemistry + smiles = AllChem.MolToSmiles(AllChem.MolFromSmiles(smiles), + isomericSmiles=False) + mol = AllChem.MolFromSmiles(smiles) + if mol is not None: + # to catch weird issue with O=C1O[al]2oc(=O)c3ccc(cn3)c3ccccc3c3cccc(c3)\ + # c3ccccc3c3cc(C(F)(F)F)c(cc3o2)-c2ccccc2-c2cccc(c2)-c2ccccc2-c2cccnc21 + if '.' in smiles: # if multiple species, pick largest molecule + mol_species_list = split_rdkit_mol_obj(mol) + largest_mol = get_largest_mol(mol_species_list) + inchi = AllChem.MolToInchi(largest_mol) + else: + inchi = AllChem.MolToInchi(mol) + return inchi + return + + +class MoleculeNetGraphDataset(InMemoryDataset): + def __init__(self, root, dataset='zinc250k', transform=None, + pre_transform=None, pre_filter=None, empty=False): + + self.root = root + self.dataset = dataset + self.transform = transform + self.pre_filter = pre_filter + self.pre_transform = pre_transform + + super(MoleculeNetGraphDataset, self).__init__(root, transform, pre_transform, pre_filter) + + if not empty: + self.data, self.slices = torch.load(self.processed_paths[0]) + print('Dataset: {}\nData: {}'.format(self.dataset, self.data)) + + def get(self, idx): + data = Data() + for key in self.data.keys: + item, slices = self.data[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) + data[key] = item[s] + return data + + @property + def raw_file_names(self): + if self.dataset == 'davis': + file_name_list = ['davis'] + elif self.dataset == 'kiba': + file_name_list = ['kiba'] + else: + file_name_list = os.listdir(self.raw_dir) + return file_name_list + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def download(self): + return + + def process(self): + + def shared_extractor(smiles_list, rdkit_mol_objs, labels): + data_list, data_smiles_list, data_label_list = [], [], [] + if labels.ndim == 1: + labels = np.expand_dims(labels, axis=1) + for i in range(len(smiles_list)): + print(i) + rdkit_mol = rdkit_mol_objs[i] + if rdkit_mol is None: + continue + data = mol_to_graph_data_obj_simple(rdkit_mol) + data.id = 
torch.tensor([i]) + data.y = torch.tensor(labels[i]) + data_list.append(data) + data_smiles_list.append(smiles_list[i]) + data_label_list.append(labels[i]) + return data_list, data_smiles_list, data_label_list + + if self.dataset == 'tox21': + smiles_list, rdkit_mol_objs, labels = \ + _load_tox21_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'hiv': + smiles_list, rdkit_mol_objs, labels = \ + _load_hiv_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'bace': + smiles_list, rdkit_mol_objs, folds, labels = \ + _load_bace_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'bbbp': + smiles_list, rdkit_mol_objs, labels = \ + _load_bbbp_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'clintox': + smiles_list, rdkit_mol_objs, labels = \ + _load_clintox_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'esol': + smiles_list, rdkit_mol_objs, labels = \ + _load_esol_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'freesolv': + smiles_list, rdkit_mol_objs, labels = \ + _load_freesolv_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'lipophilicity': + smiles_list, rdkit_mol_objs, labels = \ + _load_lipophilicity_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'malaria': + smiles_list, rdkit_mol_objs, labels = \ + _load_malaria_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'cep': + smiles_list, rdkit_mol_objs, labels = \ + _load_cep_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'muv': + smiles_list, rdkit_mol_objs, labels = \ + _load_muv_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'pcba': + smiles_list, rdkit_mol_objs, labels = \ + _load_pcba_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'sider': + smiles_list, rdkit_mol_objs, labels = \ + _load_sider_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + elif self.dataset == 'toxcast': + smiles_list, rdkit_mol_objs, labels = \ + _load_toxcast_dataset(self.raw_paths[0]) + data_list, data_smiles_list, data_label_list = shared_extractor( + smiles_list, rdkit_mol_objs, labels) + + else: + raise ValueError('Dataset {} not included.'.format(self.dataset)) + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if 
self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + data_smiles_series = pd.Series(data_smiles_list) + saver_path = os.path.join(self.processed_dir, 'smiles.csv') + data_smiles_series.to_csv(saver_path, index=False, header=False) + + data_label_array = np.array(data_label_list) + saver_path = os.path.join(self.processed_dir, 'labels') + np.savez_compressed(saver_path, labels=data_label_array) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + + return + + +def _load_tox21_dataset(input_path): + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD', + 'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53'] + labels = input_df[tasks] + # convert 0 to -1 + labels = labels.replace(0, -1) + # convert nan to 0 + labels = labels.fillna(0) + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_hiv_dataset(input_path): + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['HIV_active'] + # convert 0 to -1 + labels = labels.replace(0, -1) + # there are no nans + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_bace_dataset(input_path): + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['mol'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['Class'] + # convert 0 to -1 + labels = labels.replace(0, -1) + # there are no nans + folds = input_df['Model'] + folds = folds.replace('Train', 0) # 0 -> train + folds = folds.replace('Valid', 1) # 1 -> valid + folds = folds.replace('Test', 2) # 2 -> test + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + assert len(smiles_list) == len(folds) + return smiles_list, rdkit_mol_objs_list, folds.values, labels.values + + +def _load_bbbp_dataset(input_path): + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + + preprocessed_rdkit_mol_objs_list = [m if m is not None else None + for m in rdkit_mol_objs_list] + preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m is not None else None + for m in preprocessed_rdkit_mol_objs_list] + labels = input_df['p_np'] + # convert 0 to -1 + labels = labels.replace(0, -1) + # there are no nans + assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) + assert len(smiles_list) == len(preprocessed_smiles_list) + assert len(smiles_list) == len(labels) + return preprocessed_smiles_list, \ + preprocessed_rdkit_mol_objs_list, labels.values + + +def _load_clintox_dataset(input_path): + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + + preprocessed_rdkit_mol_objs_list = [m if m is not None else None + for m in rdkit_mol_objs_list] + preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m is not None else None + for m in preprocessed_rdkit_mol_objs_list] + tasks = ['FDA_APPROVED', 'CT_TOX'] + labels = 
input_df[tasks] + # convert 0 to -1 + labels = labels.replace(0, -1) + # there are no nans + assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) + assert len(smiles_list) == len(preprocessed_smiles_list) + assert len(smiles_list) == len(labels) + return preprocessed_smiles_list, \ + preprocessed_rdkit_mol_objs_list, labels.values + + +def _load_esol_dataset(input_path): + # NB: some examples have multiple species + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['measured log solubility in mols per litre'] + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_freesolv_dataset(input_path): + + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['expt'] + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_lipophilicity_dataset(input_path): + + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['exp'] + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_malaria_dataset(input_path): + + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['activity'] + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_cep_dataset(input_path): + + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + labels = input_df['PCE'] + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_muv_dataset(input_path): + + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + tasks = ['MUV-466', 'MUV-548', 'MUV-600', 'MUV-644', 'MUV-652', 'MUV-689', + 'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810', + 'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859'] + labels = input_df[tasks] + # convert 0 to -1 + labels = labels.replace(0, -1) + # convert nan to 0 + labels = labels.fillna(0) + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_sider_dataset(input_path): + + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + tasks = ['Hepatobiliary disorders', + 'Metabolism and nutrition disorders', 'Product issues', 'Eye disorders', + 'Investigations', 'Musculoskeletal and connective tissue disorders', + 'Gastrointestinal disorders', 'Social circumstances', + 'Immune system disorders', 'Reproductive system and breast disorders', + 'Neoplasms benign, malignant and 
unspecified (incl cysts and polyps)', + 'General disorders and administration site conditions', + 'Endocrine disorders', 'Surgical and medical procedures', + 'Vascular disorders', 'Blood and lymphatic system disorders', + 'Skin and subcutaneous tissue disorders', + 'Congenital, familial and genetic disorders', + 'Infections and infestations', + 'Respiratory, thoracic and mediastinal disorders', + 'Psychiatric disorders', 'Renal and urinary disorders', + 'Pregnancy, puerperium and perinatal conditions', + 'Ear and labyrinth disorders', 'Cardiac disorders', + 'Nervous system disorders', + 'Injury, poisoning and procedural complications'] + labels = input_df[tasks] + # convert 0 to -1 + labels = labels.replace(0, -1) + assert len(smiles_list) == len(rdkit_mol_objs_list) + assert len(smiles_list) == len(labels) + return smiles_list, rdkit_mol_objs_list, labels.values + + +def _load_toxcast_dataset(input_path): + + # NB: some examples have multiple species, some example smiles are invalid + input_df = pd.read_csv(input_path, sep=',') + smiles_list = input_df['smiles'] + rdkit_mol_objs_list = [AllChem.MolFromSmiles(s) for s in smiles_list] + # Some smiles could not be successfully converted + # to rdkit mol object so them to None + preprocessed_rdkit_mol_objs_list = [m if m is not None else None + for m in rdkit_mol_objs_list] + preprocessed_smiles_list = [AllChem.MolToSmiles(m) if m is not None else None + for m in preprocessed_rdkit_mol_objs_list] + tasks = list(input_df.columns)[1:] + labels = input_df[tasks] + # convert 0 to -1 + labels = labels.replace(0, -1) + # convert nan to 0 + labels = labels.fillna(0) + assert len(smiles_list) == len(preprocessed_rdkit_mol_objs_list) + assert len(smiles_list) == len(preprocessed_smiles_list) + assert len(smiles_list) == len(labels) + return preprocessed_smiles_list, \ + preprocessed_rdkit_mol_objs_list, labels.values + + +def check_smiles_validity(smiles): + try: + m = Chem.MolFromSmiles(smiles) + if m: + return True + else: + return False + except: + return False + + +def split_rdkit_mol_obj(mol): + """ + Split rdkit mol object containing multiple species or one species into a + list of mol objects or a list containing a single object respectively """ + + smiles = AllChem.MolToSmiles(mol, isomericSmiles=True) + smiles_list = smiles.split('.') + mol_species_list = [] + for s in smiles_list: + if check_smiles_validity(s): + mol_species_list.append(AllChem.MolFromSmiles(s)) + return mol_species_list + + +def get_largest_mol(mol_list): + """ + Given a list of rdkit mol objects, returns mol object containing the + largest num of atoms. If multiple containing largest num of atoms, + picks the first one """ + + num_atoms_list = [len(m.GetAtoms()) for m in mol_list] + largest_mol_idx = num_atoms_list.index(max(num_atoms_list)) + return mol_list[largest_mol_idx] diff --git a/MoleculeSTM/datasets/MoleculeNetSMILES.py b/MoleculeSTM/datasets/MoleculeNetSMILES.py new file mode 100644 index 0000000..4bfe005 --- /dev/null +++ b/MoleculeSTM/datasets/MoleculeNetSMILES.py @@ -0,0 +1,36 @@ +import os +import numpy as np +from rdkit import Chem +from torch.utils.data import Dataset + + +class MoleculeNetSMILESDataset(Dataset): + def __init__(self, root): + ''' + This needs to be called after calling the MoleculeNetGraphDataset. 
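+        (Its `process()` step writes `processed/smiles.csv` and
+        `processed/labels.npz`, which are loaded here.)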
+ ''' + self.root = root + SMILES_file = os.path.join(root, "processed", "smiles.csv") + + self.SMILES_list = [] + with open(SMILES_file, 'r') as f: + lines = f.readlines() + for line in lines: + SMILES = line.strip() + mol = Chem.MolFromSmiles(SMILES) + canon_SMILES = Chem.MolToSmiles(mol) + self.SMILES_list.append(canon_SMILES) + + labels_file = os.path.join(root, "processed", "labels.npz") + self.labels_data = np.load(labels_file)['labels'] + + print(len(self.SMILES_list), '\t', self.labels_data.shape) + return + + def __getitem__(self, index): + SMILES = self.SMILES_list[index] + labels = self.labels_data[index] + return SMILES, labels + + def __len__(self): + return len(self.SMILES_list) diff --git a/MoleculeSTM/datasets/PubChemSTM.py b/MoleculeSTM/datasets/PubChemSTM.py new file mode 100644 index 0000000..676c677 --- /dev/null +++ b/MoleculeSTM/datasets/PubChemSTM.py @@ -0,0 +1,275 @@ +import os +from itertools import repeat +import pandas as pd +import json +from tqdm import tqdm + +import torch +from torch.utils.data import Dataset +from torch_geometric.data import Data, InMemoryDataset + +from rdkit import Chem +from rdkit import RDLogger +RDLogger.DisableLog('rdApp.*') + +from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple + + +class PubChemSTM_Datasets_Only_SMILES(Dataset): + def __init__(self, root, subset_size=None): + self.root = root + + CID2SMILES_file = os.path.join(self.root, "raw/CID2SMILES.csv") + + df = pd.read_csv(CID2SMILES_file) + SMILES_list = df["SMILES"].tolist() + SMILES_list = sorted(set(SMILES_list)) + + self.SMILES_list = SMILES_list + if subset_size is not None: + self.SMILES_list = self.SMILES_list[:subset_size] + return + + def __getitem__(self, index): + SMILES = self.SMILES_list[index] + return SMILES + + def __len__(self): + return len(self.SMILES_list) + + +class PubChemSTM_Datasets_SMILES(Dataset): + def __init__(self, root): + self.root = root + + CID2text_file = os.path.join(self.root, "raw/CID2text.json") + CID2SMILES_file = os.path.join(self.root, "raw/CID2SMILES.csv") + self.load_CID2SMILES(CID2text_file, CID2SMILES_file) + + self.text_list = [] + missing_count = 0 + for CID, value_list in self.CID2text_data.items(): + if CID not in self.CID2SMILES: + print("CID {} missing".format(CID)) + missing_count += 1 + continue + for value in value_list: + self.text_list.append([CID, value]) + print("missing", missing_count) + print("len of text_list: {}".format(len(self.text_list))) + return + + def load_CID2SMILES(self, CID2text_file, CID2SMILES_file): + with open(CID2text_file, "r") as f: + self.CID2text_data = json.load(f) + print("len of CID2text: {}".format(len(self.CID2text_data.keys()))) + + df = pd.read_csv(CID2SMILES_file) + CID_list, SMILES_list = df["CID"].tolist(), df["SMILES"].tolist() + self.CID2SMILES = {} + for CID, SMILES in zip(CID_list, SMILES_list): + CID = str(CID) + self.CID2SMILES[CID] = SMILES + print("len of CID2SMILES: {}".format(len(self.CID2SMILES.keys()))) + return + + def __getitem__(self, index): + CID, text = self.text_list[index] + SMILES = self.CID2SMILES[CID] + return text, SMILES + + def __len__(self): + return len(self.text_list) + + +class PubChemSTM_SubDatasets_SMILES(PubChemSTM_Datasets_SMILES): + def __init__(self, root, size): + self.root = root + + CID2text_file = os.path.join(self.root, "raw/CID2text.json") + CID2SMILES_file = os.path.join(self.root, "raw/CID2SMILES.csv") + self.load_CID2SMILES(CID2text_file, CID2SMILES_file) + + self.text_list = [] + for CID, value_list in 
self.CID2text_data.items(): + if CID not in self.CID2SMILES: + print("CID {} missing".format(CID)) + continue + for value in value_list: + self.text_list.append([CID, value]) + if len(self.text_list) >= size: + break + print("len of text_list: {}".format(len(self.text_list))) + return + + +class PubChemSTM_Datasets_Graph(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + # only for `process` function + self.SDF_file_path = os.path.join(self.root, "raw/molecules.sdf") + self.CID2text_file = os.path.join(self.root, "raw/CID2text.json") + # `process` result file + self.CID_text_file_path = os.path.join(self.root, "processed/CID_text_list.csv") + + super(PubChemSTM_Datasets_Graph, self).__init__(root, transform, pre_transform, pre_filter) + + self.load_Graph_CID_and_text() + return + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def process(self): + suppl = Chem.SDMolSupplier(self.SDF_file_path) + + CID2graph = {} + for mol in tqdm(suppl): + CID = mol.GetProp("PUBCHEM_COMPOUND_CID") + CID = int(CID) + graph = mol_to_graph_data_obj_simple(mol) + CID2graph[CID] = graph + print("CID2graph", len(CID2graph)) + + with open(self.CID2text_file, "r") as f: + CID2text_data = json.load(f) + print("CID2data", len(CID2text_data)) + + CID_list, graph_list, text_list = [], [], [] + for CID, value_list in CID2text_data.items(): + CID = int(CID) + if CID not in CID2graph: + print("CID {} missing".format(CID)) + continue + graph = CID2graph[CID] + for value in value_list: + text_list.append(value) + CID_list.append(CID) + graph_list.append(graph) + + CID_text_df = pd.DataFrame({"CID": CID_list, "text": text_list}) + CID_text_df.to_csv(self.CID_text_file_path, index=None) + + if self.pre_filter is not None: + graph_list = [graph for graph in graph_list if self.pre_filter(graph)] + + if self.pre_transform is not None: + graph_list = [self.pre_transform(graph) for graph in graph_list] + + graphs, slices = self.collate(graph_list) + torch.save((graphs, slices), self.processed_paths[0]) + return + + def load_Graph_CID_and_text(self): + self.graphs, self.slices = torch.load(self.processed_paths[0]) + + CID_text_df = pd.read_csv(self.CID_text_file_path) + self.CID_list = CID_text_df["CID"].tolist() + self.text_list = CID_text_df["text"].tolist() + return + + def get(self, idx): + text = self.text_list[idx] + + data = Data() + for key in self.graphs.keys: + item, slices = self.graphs[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) + data[key] = item[s] + return text, data + + def __len__(self): + return len(self.text_list) + + +class PubChemSTM_SubDatasets_Graph(PubChemSTM_Datasets_Graph): + def __init__(self, root, size, transform=None, pre_transform=None, pre_filter=None): + self.root = root + self.size = size + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self.size = size + # only for `process` function + self.SDF_file_path = os.path.join(self.root, "raw/molecules.sdf") + self.CID2text_file = os.path.join(self.root, "raw/CID2text.json") + # `process` result file + self.CID_text_file_path = os.path.join(self.root, "processed/CID_text_list.csv") + + super(PubChemSTM_Datasets_Graph, self).__init__(root, transform, pre_transform, pre_filter) + + self.load_Graph_CID_and_text() 
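+        # Note: only __len__ is capped at `size`; load_Graph_CID_and_text()
+        # above still loads the full processed graphs and texts.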
+ return + + def __len__(self): + return self.size + + +class PubChemSTM_Datasets_SMILES_and_Graph(InMemoryDataset): + def __init__(self, root, subset_size=None, transform=None, pre_transform=None, pre_filter=None): + self.root = root + + # only for `process` function + self.SDF_file_path = os.path.join(self.root, "raw/molecules.sdf") + # `process` result file + self.SMILES_file_path = os.path.join(self.root, "processed_molecule_only/SMILES.csv") + + super(PubChemSTM_Datasets_SMILES_and_Graph, self).__init__(root, transform, pre_transform, pre_filter) + + self.graphs, self.slices = torch.load(self.processed_paths[0]) + + CID_text_df = pd.read_csv(self.SMILES_file_path) + self.SMILES_list = CID_text_df["smiles"].tolist() + if subset_size is not None: + self.SMILES_list = self.SMILES_list[:subset_size] + return + + @property + def processed_dir(self): + return os.path.join(self.root, 'processed_molecule_only') + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def process(self): + suppl = Chem.SDMolSupplier(self.SDF_file_path) + + SMILES_list, graph_list = [], [] + for mol in tqdm(suppl): + SMILES = Chem.MolToSmiles(mol) + SMILES_list.append(SMILES) + graph = mol_to_graph_data_obj_simple(mol) + graph_list.append(graph) + + SMILES_df = pd.DataFrame({"smiles": SMILES_list}) + SMILES_df.to_csv(self.SMILES_file_path, index=None) + + if self.pre_filter is not None: + graph_list = [graph for graph in graph_list if self.pre_filter(graph)] + + if self.pre_transform is not None: + graph_list = [self.pre_transform(graph) for graph in graph_list] + + graphs, slices = self.collate(graph_list) + torch.save((graphs, slices), self.processed_paths[0]) + return + + def get(self, idx): + SMILES = self.SMILES_list[idx] + + data = Data() + for key in self.graphs.keys: + item, slices = self.graphs[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) + data[key] = item[s] + return SMILES, data + + def __len__(self): + return len(self.SMILES_list) diff --git a/MoleculeSTM/datasets/PubChemSTM_raw.py b/MoleculeSTM/datasets/PubChemSTM_raw.py new file mode 100644 index 0000000..a610c28 --- /dev/null +++ b/MoleculeSTM/datasets/PubChemSTM_raw.py @@ -0,0 +1,172 @@ +import os +from itertools import repeat +import pandas as pd +import json +from tqdm import tqdm + +import torch +from torch_geometric.data import Data, InMemoryDataset + +from rdkit import Chem +from rdkit import RDLogger +RDLogger.DisableLog('rdApp.*') + +from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple + +from MoleculeSTM.datasets import PubChemSTM_Datasets_SMILES + + +class PubChemSTM_Datasets_Raw_SMILES(PubChemSTM_Datasets_SMILES): + def __init__(self, root): + self.root = root + + CID2text_file = os.path.join(self.root, "raw/CID2text_raw.json") + # Both PubChemSTM and PubChemSTM_Raw share the same CID2SMILES file. 
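+        # Only the text side differs from PubChemSTM_Datasets_SMILES: this
+        # class reads CID2text_raw.json instead of CID2text.json.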
+ CID2SMILES_file = os.path.join(self.root, "raw/CID2SMILES.csv") + self.load_CID2SMILES(CID2text_file, CID2SMILES_file) + + self.text_list = [] + missing_count = 0 + for CID, value_list in self.CID2text_data.items(): + if CID not in self.CID2SMILES: + print("CID {} missing".format(CID)) + missing_count += 1 + continue + for value in value_list: + self.text_list.append([CID, value]) + print("missing", missing_count) + print("len of text_list: {}".format(len(self.text_list))) + + return + + +class PubChemSTM_SubDatasets_Raw_SMILES(PubChemSTM_Datasets_Raw_SMILES): + def __init__(self, root, size): + self.root = root + + CID2text_file = os.path.join(self.root, "raw/CID2text_raw.json") + CID2SMILES_file = os.path.join(self.root, "raw/CID2SMILES.csv") + self.load_CID2SMILES(CID2text_file, CID2SMILES_file) + + self.text_list = [] + for CID, value_list in self.CID2text_data.items(): + if CID not in self.CID2SMILES: + print("CID {} missing".format(CID)) + continue + for value in value_list: + self.text_list.append([CID, value]) + if len(self.text_list) >= size: + break + print("len of text_list: {}".format(len(self.text_list))) + return + + +class PubChemSTM_Datasets_Raw_Graph(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None): + self.root = root + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + # only for `process` function + self.SDF_file_path = os.path.join(self.root, "raw/molecules.sdf") + self.CID2text_file = os.path.join(self.root, "raw/CID2text_raw.json") + # `process` result file + self.CID_text_file_path = os.path.join(self.root, "processed_raw/CID_text_list.csv") + + super(PubChemSTM_Datasets_Raw_Graph, self).__init__(root, transform, pre_transform, pre_filter) + + self.load_Graph_CID_and_text() + return + + @property + def processed_dir(self) -> str: + return os.path.join(self.root, 'processed_raw') + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def process(self): + suppl = Chem.SDMolSupplier(self.SDF_file_path) + + CID2graph = {} + for mol in tqdm(suppl): + CID = mol.GetProp("PUBCHEM_COMPOUND_CID") + CID = int(CID) + graph = mol_to_graph_data_obj_simple(mol) + CID2graph[CID] = graph + print("CID2graph", len(CID2graph)) + + with open(self.CID2text_file, "r") as f: + CID2text_data = json.load(f) + print("CID2data", len(CID2text_data)) + + CID_list, graph_list, text_list = [], [], [] + for CID, value_list in CID2text_data.items(): + CID = int(CID) + if CID not in CID2graph: + print("CID {} missing".format(CID)) + continue + graph = CID2graph[CID] + for value in value_list: + text_list.append(value) + CID_list.append(CID) + graph_list.append(graph) + + CID_text_df = pd.DataFrame({"CID": CID_list, "text": text_list}) + CID_text_df.to_csv(self.CID_text_file_path, index=None) + + if self.pre_filter is not None: + graph_list = [graph for graph in graph_list if self.pre_filter(graph)] + + if self.pre_transform is not None: + graph_list = [self.pre_transform(graph) for graph in graph_list] + + graphs, slices = self.collate(graph_list) + torch.save((graphs, slices), self.processed_paths[0]) + return + + def load_Graph_CID_and_text(self): + self.graphs, self.slices = torch.load(self.processed_paths[0]) + + CID_text_df = pd.read_csv(self.CID_text_file_path) + self.CID_list = CID_text_df["CID"].tolist() + self.text_list = CID_text_df["text"].tolist() + return + + def get(self, idx): + text = self.text_list[idx] + + data = Data() + for key in self.graphs.keys: + 
item, slices = self.graphs[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) + data[key] = item[s] + return text, data + + def __len__(self): + return len(self.text_list) + + +class PubChemSTM_SubDatasets_Raw_Graph(PubChemSTM_Datasets_Raw_Graph): + def __init__(self, root, size, transform=None, pre_transform=None, pre_filter=None): + self.root = root + self.size = size + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self.size = size + # only for `process` function + self.SDF_file_path = os.path.join(self.root, "raw/molecules.sdf") + self.CID2text_file = os.path.join(self.root, "raw/CID2text_raw.json") + # `process` result file + self.CID_text_file_path = os.path.join(self.root, "processed_raw/CID_text_list.csv") + + super(PubChemSTM_SubDatasets_Raw_Graph, self).__init__(root, transform, pre_transform, pre_filter) + + self.load_Graph_CID_and_text() + return + + def __len__(self): + return self.size diff --git a/MoleculeSTM/datasets/ZINC250K_Graph.py b/MoleculeSTM/datasets/ZINC250K_Graph.py new file mode 100644 index 0000000..df23f93 --- /dev/null +++ b/MoleculeSTM/datasets/ZINC250K_Graph.py @@ -0,0 +1,67 @@ +import os +import pandas as pd +from tqdm import tqdm +from rdkit import Chem +from itertools import repeat + +import torch +from torch_geometric.data import Data, InMemoryDataset + +from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple + + +class ZINC250K_Dataset_Graph(InMemoryDataset): + def __init__(self, root, subset_size=None, transform=None, pre_transform=None, pre_filter=None): + self.root = root + + self.SMILES_file = os.path.join(self.root, "raw/250k_rndm_zinc_drugs_clean_3.csv") + df = pd.read_csv(self.SMILES_file) + SMILES_list = df['smiles'].tolist() + self.SMILES_list = [x.strip() for x in SMILES_list] + + super(ZINC250K_Dataset_Graph, self).__init__(root, transform, pre_transform, pre_filter) + + self.graphs, self.slices = torch.load(self.processed_paths[0]) + + if subset_size is not None: + self.SMILES_list = self.SMILES_list[:subset_size] + return + + @property + def processed_dir(self): + return os.path.join(self.root, 'processed_molecule_only') + + @property + def processed_file_names(self): + return 'geometric_data_processed.pt' + + def process(self): + graph_list = [] + for SMILES in tqdm(self.SMILES_list): + RDKit_mol = Chem.MolFromSmiles(SMILES) + graph = mol_to_graph_data_obj_simple(RDKit_mol) + graph_list.append(graph) + + if self.pre_filter is not None: + graph_list = [graph for graph in graph_list if self.pre_filter(graph)] + + if self.pre_transform is not None: + graph_list = [self.pre_transform(graph) for graph in graph_list] + + graphs, slices = self.collate(graph_list) + torch.save((graphs, slices), self.processed_paths[0]) + return + + def get(self, idx): + SMILES = self.SMILES_list[idx] + + data = Data() + for key in self.graphs.keys: + item, slices = self.graphs[key], self.slices[key] + s = list(repeat(slice(None), item.dim())) + s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) + data[key] = item[s] + return SMILES, data + + def __len__(self): + return len(self.SMILES_list) \ No newline at end of file diff --git a/MoleculeSTM/datasets/ZINC250K_SMILES.py b/MoleculeSTM/datasets/ZINC250K_SMILES.py new file mode 100644 index 0000000..7dfaf2d --- /dev/null +++ b/MoleculeSTM/datasets/ZINC250K_SMILES.py @@ -0,0 +1,30 @@ +from torch.utils.data import Dataset +import os +import pandas as pd + 
+ +class ZINC250K_Dataset_SMILES(Dataset): + def __init__(self, root, subset_size=None): + self.root = root + + SMILES_file = os.path.join(self.root, "raw/250k_rndm_zinc_drugs_clean_3.csv") + df = pd.read_csv(SMILES_file) + SMILES_list = df['smiles'].tolist() # Already canonical SMILES + self.SMILES_list = [x.strip() for x in SMILES_list] + + new_SMILES_file = os.path.join(self.root, "raw/smiles.csv") + if not os.path.exists(new_SMILES_file): + data_smiles_series = pd.Series(self.SMILES_list) + print("saving to {}".format(new_SMILES_file)) + data_smiles_series.to_csv(new_SMILES_file, index=False, header=False) + + if subset_size is not None: + self.SMILES_list = self.SMILES_list[:subset_size] + return + + def __getitem__(self, index): + SMILES = self.SMILES_list[index] + return SMILES + + def __len__(self): + return len(self.SMILES_list) diff --git a/MoleculeSTM/datasets/__init__.py b/MoleculeSTM/datasets/__init__.py new file mode 100644 index 0000000..0ae00ea --- /dev/null +++ b/MoleculeSTM/datasets/__init__.py @@ -0,0 +1,8 @@ +from MoleculeSTM.datasets.PubChemSTM import PubChemSTM_Datasets_SMILES, PubChemSTM_SubDatasets_SMILES, PubChemSTM_Datasets_Graph, PubChemSTM_SubDatasets_Graph, PubChemSTM_Datasets_Only_SMILES, PubChemSTM_Datasets_SMILES_and_Graph +from MoleculeSTM.datasets.PubChemSTM_raw import PubChemSTM_Datasets_Raw_SMILES, PubChemSTM_SubDatasets_Raw_SMILES, PubChemSTM_Datasets_Raw_Graph, PubChemSTM_SubDatasets_Raw_Graph +from MoleculeSTM.datasets.MoleculeNetGraph import MoleculeNetGraphDataset +from MoleculeSTM.datasets.MoleculeNetSMILES import MoleculeNetSMILESDataset +from MoleculeSTM.datasets.DrugBankSMILES import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_SMILES_ATC +from MoleculeSTM.datasets.DrugBankGraph import DrugBank_Datasets_Graph_retrieval, DrugBank_Datasets_Graph_ATC +from MoleculeSTM.datasets.ZINC250K_SMILES import ZINC250K_Dataset_SMILES +from MoleculeSTM.datasets.ZINC250K_Graph import ZINC250K_Dataset_Graph \ No newline at end of file diff --git a/MoleculeSTM/datasets/utils.py b/MoleculeSTM/datasets/utils.py new file mode 100644 index 0000000..38446aa --- /dev/null +++ b/MoleculeSTM/datasets/utils.py @@ -0,0 +1,182 @@ + +import networkx as nx +import numpy as np +import torch +from rdkit import Chem +from torch_geometric.data import Data + + +allowable_features = { + 'possible_atomic_num_list': list(range(1, 119)), + 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5], + 'possible_chirality_list': [ + Chem.rdchem.ChiralType.CHI_UNSPECIFIED, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, + Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, + Chem.rdchem.ChiralType.CHI_OTHER + ], + 'possible_hybridization_list': [ + Chem.rdchem.HybridizationType.S, + Chem.rdchem.HybridizationType.SP, + Chem.rdchem.HybridizationType.SP2, + Chem.rdchem.HybridizationType.SP3, + Chem.rdchem.HybridizationType.SP3D, + Chem.rdchem.HybridizationType.SP3D2, + Chem.rdchem.HybridizationType.UNSPECIFIED + ], + 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8], + 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], + 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + 'possible_bonds': [ + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC + ], + 'possible_bond_dirs': [ # only for double bond stereo information + Chem.rdchem.BondDir.NONE, + Chem.rdchem.BondDir.ENDUPRIGHT, + Chem.rdchem.BondDir.ENDDOWNRIGHT + ] +} + + +def mol_to_graph_data_obj_simple(mol): + # atoms + # num_atom_features = 2 
# atom type, chirality tag + atom_features_list = [] + for atom in mol.GetAtoms(): + atomic_num = atom.GetAtomicNum() + chiral_tag = atom.GetChiralTag() + if atomic_num == 0: + atomic_num = 118 # Only for one extreme case + atom_feature = [allowable_features['possible_atomic_num_list'].index(atomic_num)] + \ + [allowable_features['possible_chirality_list'].index(chiral_tag)] + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), dtype=torch.long) + + # bonds + if len(mol.GetBonds()) <= 0: # mol has no bonds + num_bond_features = 2 # bond type & direction + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) + else: # mol has bonds + edges_list = [] + edge_features_list = [] + for bond in mol.GetBonds(): + i = bond.GetBeginAtomIdx() + j = bond.GetEndAtomIdx() + bond_type = bond.GetBondType() + bond_dir = bond.GetBondDir() + if bond_dir not in allowable_features['possible_bond_dirs']: + bond_dir = 0 + edge_feature = [allowable_features['possible_bonds'].index(bond_type)] + \ + [allowable_features['possible_bond_dirs'].index(bond_dir)] + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + + # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] + edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) + + # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] + edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + return data + + +def graph_data_obj_to_mol_simple(data_x, data_edge_index, data_edge_attr): + mol = Chem.RWMol() + + # atoms + atom_features = data_x.cpu().numpy() + num_atoms = atom_features.shape[0] + for i in range(num_atoms): + atomic_num_idx, chirality_tag_idx = atom_features[i] + atomic_num = allowable_features['possible_atomic_num_list'][atomic_num_idx] + chirality_tag = allowable_features['possible_chirality_list'][chirality_tag_idx] + atom = Chem.Atom(atomic_num) + atom.SetChiralTag(chirality_tag) + mol.AddAtom(atom) + + # bonds + edge_index = data_edge_index.cpu().numpy() + edge_attr = data_edge_attr.cpu().numpy() + num_bonds = edge_index.shape[1] + for j in range(0, num_bonds, 2): + begin_idx = int(edge_index[0, j]) + end_idx = int(edge_index[1, j]) + bond_type_idx, bond_dir_idx = edge_attr[j] + bond_type = allowable_features['possible_bonds'][bond_type_idx] + bond_dir = allowable_features['possible_bond_dirs'][bond_dir_idx] + mol.AddBond(begin_idx, end_idx, bond_type) + # set bond direction + new_bond = mol.GetBondBetweenAtoms(begin_idx, end_idx) + new_bond.SetBondDir(bond_dir) + return mol + + +def graph_data_obj_to_nx_simple(data): + G = nx.Graph() + + # atoms + atom_features = data.x.cpu().numpy() + num_atoms = atom_features.shape[0] + for i in range(num_atoms): + atomic_num_idx, chirality_tag_idx = atom_features[i] + G.add_node(i, atom_num_idx=atomic_num_idx, + chirality_tag_idx=chirality_tag_idx) + pass + + # bonds + edge_index = data.edge_index.cpu().numpy() + edge_attr = data.edge_attr.cpu().numpy() + num_bonds = edge_index.shape[1] + for j in range(0, num_bonds, 2): + begin_idx = int(edge_index[0, j]) + end_idx = int(edge_index[1, j]) + bond_type_idx, bond_dir_idx = edge_attr[j] + if not G.has_edge(begin_idx, end_idx): + G.add_edge(begin_idx, end_idx, + bond_type_idx=bond_type_idx, + bond_dir_idx=bond_dir_idx) + + return G + + 
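A minimal usage sketch for the conversion helpers above (not part of the diff; it assumes RDKit, PyTorch, and PyTorch Geometric are installed and that this file is importable as MoleculeSTM.datasets.utils; the aspirin SMILES is only an illustrative input):

from rdkit import Chem
from MoleculeSTM.datasets.utils import mol_to_graph_data_obj_simple, graph_data_obj_to_mol_simple

# Illustrative molecule (aspirin); any valid SMILES would do.
mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")

# RDKit Mol -> torch_geometric.data.Data with 2 atom features per node
# (atomic-number index, chirality index) and 2 features per directed edge
# (bond-type index, bond-direction index), all indexed into allowable_features.
graph = mol_to_graph_data_obj_simple(mol)
print(graph.x.shape, graph.edge_index.shape, graph.edge_attr.shape)

# Feature indices -> an editable RDKit RWMol, mapped back through allowable_features.
mol_back = graph_data_obj_to_mol_simple(graph.x, graph.edge_index, graph.edge_attr)
print(mol_back.GetNumAtoms())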
+def nx_to_graph_data_obj_simple(G): + # atoms + # num_atom_features = 2 # atom type, chirality tag + atom_features_list = [] + for _, node in G.nodes(data=True): + atom_feature = [node['atom_num_idx'], node['chirality_tag_idx']] + atom_features_list.append(atom_feature) + x = torch.tensor(np.array(atom_features_list), dtype=torch.long) + + # bonds + num_bond_features = 2 # bond type, bond direction + if len(G.edges()) > 0: # mol has bonds + edges_list = [] + edge_features_list = [] + for i, j, edge in G.edges(data=True): + edge_feature = [edge['bond_type_idx'], edge['bond_dir_idx']] + edges_list.append((i, j)) + edge_features_list.append(edge_feature) + edges_list.append((j, i)) + edge_features_list.append(edge_feature) + + # data.edge_index: Graph connectivity in COO format with shape [2, num_edges] + edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long) + + # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features] + edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long) + else: # mol has no bonds + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, num_bond_features), dtype=torch.long) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + return data diff --git a/MoleculeSTM/downstream_molecule_edit_utils.py b/MoleculeSTM/downstream_molecule_edit_utils.py new file mode 100644 index 0000000..c3a7a7c --- /dev/null +++ b/MoleculeSTM/downstream_molecule_edit_utils.py @@ -0,0 +1,502 @@ +import os +import copy +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred, MLP +from rdkit import Chem, RDLogger +from rdkit.Chem import AllChem, Descriptors +from rdkit import DataStructs +lg = RDLogger.logger() +lg.setLevel(RDLogger.CRITICAL) + + +def get_SMILES_list(args): + if args.input_SMILES is not None: + SMILES_list = [args.input_SMILES] + else: + SMILES_list = [] + f = open(args.input_SMILES_file, 'r') + lines = f.readlines() + for line in lines: + SMILES = line.strip() + if len(SMILES) > 0: + SMILES_list.append(SMILES) + return SMILES_list + + +description_dict = { + 101: "This molecule is soluble in water.", + 102: "This molecule is insoluble in water.", + 103: "This molecule is like a drug.", + 104: "This molecule is not like a drug.", + 105: "This molecule has high permeability.", + 106: "This molecule has low permeability.", + 107: "This molecule has more hydrogen bond acceptors.", + 108: "This molecule has more hydrogen bond donors.", + 109: "This molecule has high bioavailability.", + 110: "This molecule has low toxicity.", + 111: "This molecule is metabolically stable.", + + 201: "This molecule is soluble in water and has more hydrogen bond acceptors.", + 202: "This molecule is insoluble in water and has more hydrogen bond acceptors.", + 203: "This molecule is soluble in water and has more hydrogen bond donors.", + 204: "This molecule is insoluble in water and has more hydrogen bond donors.", + 205: "This molecule is soluble in water and has high permeability.", + 206: "This molecule is soluble in water and has low permeability.", + + 301: "This molecule looks like Penicillin.", + 302: "This molecule looks like Aspirin.", + 303: "This molecule looks like Caffeine.", + 304: "This molecule looks like Cholesterol.", + 305: "This molecule looks like Dopamine.", + 306: "This molecule looks like Cysteine.", + 
307: "This molecule looks like Glutathione.", + + 401: "This molecule is tested positive in an assay that are inhibitors and substrates of an enzyme protein. It uses molecular oxygen inserting one oxygen atom into a substrate, and reducing the second into a water molecule.", + 402: "This molecule is tested positive in an assay for Anthrax Lethal, which acts as a protease that cleaves the N-terminal of most dual specificity mitogen-activated protein kinase kinases.", + 403: "This molecule is tested positive in an assay for Activators of ClpP, which cleaves peptides in various proteins in a process that requires ATP hydrolysis and has a limited peptidase activity in the absence of ATP-binding subunits.", + 404: "This molecule is tested positive in an assay for activators involved in the transport of proteins between the endosomes and the trans Golgi network.", + 405: "This molecule is an inhibitor of a protein that prevents the establishment of the cellular antiviral state by inhibiting ubiquitination that triggers antiviral transduction signal and inhibits post-transcriptional processing of cellular pre-mRNA.", + 406: "This molecule is tested positive in the high throughput screening assay to identify inhibitors of the SARS coronavirus 3C-like Protease, which cleaves the C-terminus of replicase polyprotein at 11 sites.", +} + + +def get_description_list(args): + if args.input_description is not None: + description_list = [args.input_description] + elif args.input_description_id is None: + raise ValueError + else: + print("Use {} descrition.".format(args.input_description_id)) + description_list = [description_dict[args.input_description_id]] + print("description_list", description_list) + return description_list + + +# https://pubchem.ncbi.nlm.nih.gov/compound/5904 +# Penicillin_SMILES = "CC1(C(N2C(S1)C(C2=O)NC(=O)CC3=CC=CC=C3)C(=O)O)C" +Penicillin_SMILES = "CC1(C)SC2C(NC(=O)Cc3ccccc3)C(=O)N2C1C(=O)O" + +# https://pubchem.ncbi.nlm.nih.gov/compound/2244 +# Aspirin_SMILES = "CC(=O)OC1=CC=CC=C1C(=O)O" +Aspirin_SMILES = "CC(=O)Oc1ccccc1C(=O)O" + +# https://pubchem.ncbi.nlm.nih.gov/compound/2519 +# Caffeine_SMILES = "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" +Caffeine_SMILES = "Cn1c(=O)c2c(ncn2C)n(C)c1=O" + +# https://pubchem.ncbi.nlm.nih.gov/compound/5997 +# Cholesterol_SMILES = "CC(C)CCCC(C)C1CCC2C1(CCC3C2CC=C4C3(CCC(C4)O)C)C" +Cholesterol_SMILES = "CC(C)CCCC(C)C1CCC2C3CC=C4CC(O)CCC4(C)C3CCC12C" + +# https://pubchem.ncbi.nlm.nih.gov/compound/681 +# Dopamine_SMILES = "C1=CC(=C(C=C1CCN)O)O" +Dopamine_SMILES = "NCCc1ccc(O)c(O)c1" + +# https://pubchem.ncbi.nlm.nih.gov/compound/5862 +# Cysteine_SMILES = "C(C(C(=O)O)N)S" +Cysteine_SMILES = "NC(CS)C(=O)O" + +# https://pubchem.ncbi.nlm.nih.gov/compound/124886 +# Glutathione_SMILES = "C(CC(=O)NC(CS)C(=O)NCC(=O)O)C(C(=O)O)N" +Glutathione_SMILES = "NC(CCC(=O)NC(CS)C(=O)NCC(=O)O)C(=O)O" + + +def load_molecule_models(args): + """ + This function returns the two encoders, one for molecule generative model and one for CLIP. 
+ """ + if args.MoleculeSTM_molecule_type == "SMILES": + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + print("Loading from pretrained MegaMolBART ({}).".format(args.MegaMolBART_generation_model_dir)) + molecule_dim_generation = 256 + + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + molecule_model_MoleculeSTM = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + print("Loading from {}...".format(input_model_path)) + molecule_model_MoleculeSTM.load_state_dict(state_dict) + molecule_dim_MoleculeSTM = args.SSL_emb_dim + + mol2latent_MoleculeSTM = nn.Linear(256, molecule_dim_MoleculeSTM) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent_MoleculeSTM.load_state_dict(state_dict) + + else: + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model_generation = copy.deepcopy(MegaMolBART_wrapper.model) + print("Loading from pretrained MegaMolBART ({}).".format(args.MegaMolBART_generation_model_dir)) + molecule_dim_generation = 256 + + # This is loading GNN from the pretrained_GNN + molecule_node_model = GNN(num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, drop_ratio=args.dropout_ratio, gnn_type=args.gnn_type) + molecule_model_MoleculeSTM = GNN_graphpred(num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, num_tasks=1, molecule_node_model=molecule_node_model) + print("Start from pretrained model (MoleculeSTM) in {}.".format(args.MoleculeSTM_model_dir)) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model_MoleculeSTM.load_state_dict(state_dict) + molecule_dim_MoleculeSTM = args.SSL_emb_dim + + mol2latent_MoleculeSTM = nn.Linear(300, molecule_dim_MoleculeSTM) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent_MoleculeSTM.load_state_dict(state_dict) + + return MegaMolBART_wrapper, molecule_model_generation, molecule_dim_generation, \ + molecule_model_MoleculeSTM, mol2latent_MoleculeSTM, molecule_dim_MoleculeSTM + + +def load_language_molecule_and_edit_models(args): + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + text_dim = 768 + + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + + """ + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + 
MegaMolBART_wrapper = MegaMolBART(input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + """ + # This is loading from the pretarined_MegaMolBART + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.MegaMolBART_generation_model_dir, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Loading from pretrained MegaMolBART ({}).".format(args.MegaMolBART_generation_model_dir)) + molecule_dim_generation = 256 + if args.MoleculeSTM_molecule_type == "SMILES": # For MegaMolBART + molecule_dim_MoleculeSTM = 256 + else: # For GIN + molecule_dim_MoleculeSTM = 300 + + text2latent = nn.Linear(text_dim, args.SSL_emb_dim) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "text2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text2latent.load_state_dict(state_dict) + + mol2latent = nn.Linear(molecule_dim_MoleculeSTM, args.SSL_emb_dim) + input_model_path = os.path.join(args.MoleculeSTM_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent.load_state_dict(state_dict) + + # generation2MoleculeSTM = nn.Linear(molecule_dim_generation, args.SSL_emb_dim) + generation2MoleculeSTM = MLP(molecule_dim_generation, [args.SSL_emb_dim, args.SSL_emb_dim]) + input_model_path = os.path.join(args.language_edit_model_dir, "generation2foundation_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + generation2MoleculeSTM.load_state_dict(state_dict) + + # MoleculeSTM2generation = nn.Linear(args.SSL_emb_dim, molecule_dim_generation) + MoleculeSTM2generation = MLP(args.SSL_emb_dim, [molecule_dim_generation, molecule_dim_generation]) + input_model_path = os.path.join(args.language_edit_model_dir, "foundation2generation_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + MoleculeSTM2generation.load_state_dict(state_dict) + + return text_model, text_tokenizer, text_dim, molecule_model, MegaMolBART_wrapper, molecule_dim_generation, text2latent, mol2latent, generation2MoleculeSTM, MoleculeSTM2generation + + +def clip_loss_for_edit(molecule_repr, text_repr): + molecule_repr = F.normalize(molecule_repr, dim=-1) + text_repr = F.normalize(text_repr, dim=-1) + + similarity = -torch.mm(molecule_repr, text_repr.transpose(0, 1))[0] + return similarity + + +def get_molecule_similarity(mol_a, mol_b): + fp_a = AllChem.GetMorganFingerprintAsBitVect(mol_a, 2, nBits=1024) + fp_b = AllChem.GetMorganFingerprintAsBitVect(mol_b, 2, nBits=1024) + sim = DataStructs.TanimotoSimilarity(fp_a, fp_b) + return sim + + +def evaluate_SMILES_list(SMILES_list, description): + print("SMILES_list:", SMILES_list) + mol_list = [] + for SMILES in SMILES_list: + mol = Chem.MolFromSmiles(SMILES) + # Chem.SanitizeMol(mol) + # print(SMILES, mol) + if mol is None: + continue + mol_list.append(mol) + print("valid mol list:", len(mol_list)) + + if len(mol_list) < 3: + return [False] + + if "soluble" in description and "insoluble" not in description: + props = ["MolLogP"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in 
zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif "insoluble" in description: + props = ["MolLogP"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule is more like a drug.", "This molecule is like a drug."]: + props = ["qed"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule is less like a drug.", "This molecule is not like a drug."]: + props = ["qed"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has higher permeability.", "This molecule has high permeability."]: + props = ["TPSA"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has lower permeability.", "This molecule has low permeability."]: + props = ["TPSA"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has higher molecular weight.", "This molecule has high molecular weight."]: + props = ["MolWt"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has lower molecular weight.", "This molecule has low molecular weight."]: + props = ["MolWt"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] > value_list[2]: + answer = [True] + else: + answer = [False] + + 
elif description in ["This molecule has more hydrogen bond acceptors."]: + props = ["NumHAcceptors"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif description in ["This molecule has more hydrogen bond donors."]: + props = ["NumHDonors"] + prop_pred = [(n, func) for n, func in Descriptors.descList if n.split("_")[-1] in props] + value_list = [] + for name, func in prop_pred: + for SMILES, mol in zip(SMILES_list, mol_list): + value = func(mol) + value_list.append(value) + print("{} & {:.5f}".format(SMILES, value)) + if value_list[0] < value_list[2]: + answer = [True] + else: + answer = [False] + + elif "penicillin" in description or "Penicillin" in description: + target_mol = Chem.MolFromSmiles(Penicillin_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between penicillin and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between penicillin and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: + answer = [True] + else: + answer = [False] + + elif "aspirin" in description or "Aspirin" in description: + target_mol = Chem.MolFromSmiles(Aspirin_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between aspirin and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between aspirin and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + elif "caffeine" in description or "Caffeine" in description: + target_mol = Chem.MolFromSmiles(Caffeine_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between caffeine and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between caffeine and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: + answer = [True] + else: + answer = [False] + + elif "cholesterol" in description or "Cholesterol" in description: + target_mol = Chem.MolFromSmiles(Cholesterol_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between cholesterol and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + 
edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between cholesterol and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + elif "dopamine" in description or "Dopamine" in description: + target_mol = Chem.MolFromSmiles(Dopamine_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between dopamine and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between dopamine and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: + answer = [True] + else: + answer = [False] + + elif "cysteine" in description or "Cysteine" in description: + target_mol = Chem.MolFromSmiles(Cysteine_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between cysteine and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between cysteine and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + elif "glutathione" in description or "Glutathione" in description: + target_mol = Chem.MolFromSmiles(Glutathione_SMILES) + original_SMILES = SMILES_list[0] + original_mol = mol_list[0] + original_similarity = get_molecule_similarity(target_mol, original_mol) + print("similarity between glutathione and original molecules\n{} & {:.5f}".format(original_SMILES, original_similarity)) + + edited_SMILES = SMILES_list[2] + edited_mol = mol_list[2] + edited_similarity = get_molecule_similarity(target_mol, edited_mol) + print("similarity between glutathione and edited molecules\n{} & {:.5f}".format(edited_SMILES, edited_similarity)) + if edited_similarity > original_similarity: # check original_similarity >< 0.8 + answer = [True] + else: + answer = [False] + + else: + print("Not implemented.") + answer = [False] + + return answer \ No newline at end of file diff --git a/MoleculeSTM/models/GA/ZINC_first_1000.smi b/MoleculeSTM/models/GA/ZINC_first_1000.smi new file mode 100644 index 0000000..3d4698d --- /dev/null +++ b/MoleculeSTM/models/GA/ZINC_first_1000.smi @@ -0,0 +1,1000 @@ +CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1 +C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1 +N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)cc2)cc1 +CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c2CCCCC3)C1 +N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#N)C12CCCCC2 +CC[NH+](CC)[C@](C)(CC)[C@H](O)c1cscc1Br +COc1ccc(C(=O)N(C)[C@@H](C)C/C(N)=N/O)cc1O +O=C(Nc1nc[nH]n1)c1cccnc1Nc1cccc(F)c1 +Cc1c(/C=N/c2cc(Br)ccn2)c(O)n2c(nc3ccccc32)c1C#N +C[C@@H]1CN(C(=O)c2cc(Br)cn2C)CC[C@H]1[NH3+] +CCOc1ccc(OCC)c([C@H]2C(C#N)=C(N)N(c3ccccc3C(F)(F)F)C3=C2C(=O)CCC3)c1 +Cc1ccc2nc(S[C@H](C)C(=O)NC3CCC(C)CC3)n(C)c(=O)c2c1 +O=C(N1CCc2c(F)ccc(F)c2C1)C1(O)Cc2ccccc2C1 +Cc1ccccc1C(=O)N1CCC2(CC1)C[C@H](c1ccccc1)C(=O)N2C 
+CCCc1cc(NC(=O)CN2C(=O)NC3(CCC(C)CC3)C2=O)n(C)n1 +CC(C)Cc1nc(SCC(=O)NC[C@@H]2CCCO2)c2c(=O)n(C)c(=O)n(C)c2n1 +Cc1ccc(CNC(=O)c2ccccc2NC(=O)[C@@H]2CC(=O)N(c3ccc(C)cc3)C2)cc1 +CCCCC(=O)NC(=S)Nc1ccccc1C(=O)N1CCOCC1 +Cc1c(NC(=O)CSc2nc3sc4c(c3c(=O)[nH]2)CCCC4)c(=O)n(-c2ccccc2)n1C +CC(C)[C@@H](Oc1cccc(Cl)c1)C(=O)N1CCC(n2cccn2)CC1 +CCN(CC)C(=O)C[C@@H](C)[NH2+][C@H](C)c1cccc(F)c1F +Cc1nc2c(c(Nc3ncc(C)s3)n1)CCN(C(=O)CCc1ccccc1)C2 +O=C(NCCNC(=O)N1C[C@H]2CC=CC[C@@H]2C1)c1cccnc1 +O=c1n(CCO)c2ccccc2n1CCO +COC(=O)Cc1csc(NC(=O)Cc2coc3cc(C)ccc23)n1 +Cc1ccc(N2CC[C@@H](NS(=O)(=O)c3ccccc3C)C2=O)cc1C +CC[C@H](C)C[C@@H](C)NC(=O)N1CCN(CC(=O)NC2CC2)CC1 +CC(=O)Nc1c2n(c3ccccc13)C[C@](C)(C(=O)NC1CCCCC1)N(C1CCCCC1)C2=O +N#Cc1ccncc1NC[C@@H]1C[C@@]12CCc1ccccc12 +Cc1cccn2c(=O)c(C(=O)NC[C@H]3CCO[C@@H]3C(C)C)cnc12 +CNC(=O)c1ccc(/C=C/C(=O)Nc2c(C)cc(C)nc2Cl)cc1 +CC1=C(CNC(=O)c2cc(-c3ccccc3)nc3c2CNN3C(C)C)CN=N1 +C[C@@H](NC(=O)COC(=O)/C=C/c1ccc(Cl)cc1)c1ccccc1 +CCc1ccc(N(Cc2ccc(C)s2)C(=O)c2ccc(=O)n(C)n2)cc1 +CCOC(=O)c1nnc2ccccc2c1N1CC[C@@H]([NH+](CC)CC)C1 +Cc1ccc(C#N)cc1S(=O)(=O)NCc1ccnc(OC(C)(C)C)c1 +O=C(O[C@H]1CCOC1)C1(c2ccc(Cl)c(Cl)c2)CCC1 +CCC[NH2+][C@@H]1COC[C@H]1C(=O)NCc1cscc1C +O=C(NCc1nccc2ccccc12)c1ccc[nH]c1=O +CC(=O)c1ccc(S(=O)(=O)N2CCCC[C@H]2C)cc1 +O=[N+]([O-])c1c(Nc2cccc3ncccc23)ncnc1N1CCN(c2cccc(Cl)c2)CC1 +O=C(CCCO)Nc1ccc(F)cc1F +NC(=O)CCOc1ccc(NC(=O)C[C@H]2CCc3ccccc32)cc1 +COc1cc(C)ccc1OCC(=O)Nc1nnc(C)s1 +CC(=O)c1c(O)cccc1COc1ccccc1 +CCn1cc(S(=O)(=O)N2CCCCC[C@@H]2c2cc(-c3ccc(F)cc3)no2)cn1 +COC(=O)[C@](NC(=O)c1cccc(Cl)c1)(Nc1ccc(Br)c[nH+]1)C(F)(F)F +Cc1[nH]c2ccc(C(=O)Nc3cc(C(C)(C)C)nn3-c3ncccn3)cc2c1C +Cc1noc(C)c1C[C@H](C)C(=O)N[C@@H](C)C1CCCCC1 +CCn1cc(C(=O)N[C@H]2CC(=O)N(C)C2)c(C(C)C)n1 +COc1cccc(-c2cncc3ccccc23)c1C(=O)N(C(C)C)C(C)C +COc1ccc([C@@H](C)NC(=O)Cc2cccc3ccccc23)cc1 +O=C1C[C@H](c2nc(-c3cccnc3)no2)CN1c1cccc(Cl)c1 +C[C@H]1CCCN(c2ccc(C(=O)Nc3ccc(N4CCOCC4)cc3)cc2[N+](=O)[O-])C1 +C=CCN(C(=O)C/C=C/c1ccc(C)cc1)[C@@H]1CCS(=O)(=O)C1 +O=C(CSc1nnc(-c2cccc([N+](=O)[O-])c2)o1)Nc1nncs1 +CN(CCc1ccc(F)cc1)c1cc(Br)cc(F)c1C(N)=O +COc1ccccc1NC(=O)CSc1ccc(-c2ccccc2OC)nn1 +Cc1occc1C(=O)/C(C#N)=C\c1cccc(C(F)(F)F)c1 +COc1ccc2c(c1)N(C(=O)CCSc1ccccn1)C[C@@H](C)O2 +CC[C@@H](NC(=O)[C@H](C)n1cccn1)c1ccc(C)c(F)c1 +CCC[C@]1(C(=O)N[C@@H]2CONC2=O)CC[NH2+]C1 +O=C(c1cc2cc([N+](=O)[O-])ccc2oc1=O)N1CCN(Cc2ccccc2)CC1 +CCn1c(CC2CC[NH2+]CC2)nn(CCO)c1=O +C=CCn1c(S[C@H](C)C(=O)N2CCC(C)CC2)nnc1-c1ccc(Cl)cc1 +CCO[C@H]1C(=O)O[C@H]([C@@H](O)CO)C1=O +Cc1ccc(-c2nnc(C[NH+](CCO)[C@H]3CCc4ccccc43)o2)cc1 +Cc1cc(-n2c(C)cc(C[NH2+][C@H](C)c3ccc(F)c(F)c3)c2C)no1 +C[C@@H](NC(=O)Nc1ccn(-c2ncccc2Cl)n1)[C@@H]1CCCO1 +COc1cc(S(=O)(=O)N2CCN=C2Cc2ccccc2)ccc1Cl +COc1ccc(OC)c(/C=C2\Oc3cc(OC(=O)c4ccncc4)cc(C)c3C2=O)c1 +COc1ccc([C@@H](NC(=O)Nc2cc(C)ccc2Cl)C2CCOCC2)cc1 +C[C@H](Cc1cccs1)N(C)C[C@@H]1CCCC[C@@H]1[NH3+] +C[C@H]([NH3+])c1nc2cc(C(F)(F)F)ccc2n1C +COc1cccc(CN2CCc3nnc(CCc4ccccc4)n3CC2)c1 +O=C(N[C@H]1CCS(=O)(=O)C1)C1CC[NH2+]CC1 +COCC[C@H](C)C(=O)N[C@@H](C)COC +Cc1cc(N(C)C)ccc1NC(=O)c1ccc(CN2CC[NH+](C)CC2)cc1 +C[C@H](CNC(=O)[C@H]1CC[NH2+][C@@H]1C)C[NH+]1CCCC1 +CN(C)c1ccc([C@H](CNC(=O)C(=O)Nc2ccccc2C#N)N2CC[NH+](C)CC2)cc1 +CCOc1ncnc(S(=O)(=O)CC)c1N +CC[C@@H](NC(=O)N(C)Cc1ccc(-c2ccccc2)cc1)c1ccncc1 +O=C(Nc1ccc(-c2nc3ccccc3o2)cc1)[C@H]1CCCN1S(=O)(=O)c1ccc(F)cc1 +CC[C@@H](C)CNc1nc2ccc(Cl)cc2s1 +Cc1cc(C)c2nc(N3CCN(C(=O)[C@@H]4CCCCN4S(C)(=O)=O)CC3)sc2c1 +CCc1nnc(-c2cc3ccccc3n2CC(=O)NC(C)(C)C)o1 +CCc1ccc(NC(=O)c2nn(-c3ccc(CC)cc3)ccc2=O)cc1 +Cc1ccc(C(=O)N[C@H]2CCC[NH2+][C@H]2C)cc1F +C[C@H](OC(=O)c1nc(C2CC2)n2ccccc12)c1cnc2ccccc2c1 +CCCCOc1ccccc1C[C@@H]([NH3+])C(=O)[O-] 
+CCC[C@@H]1CN(C(=O)C(=O)Nc2ccc(C)nc2Cl)CCO1 +C[C@H]1C(=O)N(c2ccc3c(c2)CCC3)CCN1C(=O)c1ccc(Cl)c(Cl)c1 +COC(=O)C1(NC(=O)[C@H]2C[C@H]2c2c(F)cccc2F)CCSCC1 +N#CC1(NC(=O)COc2cccc(Cl)c2)CCCC1 +COC1CC[NH+](CCNc2nccn(C)c2=O)CC1 +C=CCN(Cc1cccc([N+](=O)[O-])c1)C(=O)Nc1cc(OC)ccc1Cl +Cc1cc(Cl)ccc1OCC(=O)N/N=C/c1ccccn1 +O=C1NC(=S)NC(=O)C1=CNc1ccc([N+](=O)[O-])cc1O +Cc1c(C(=O)N2CCOCC2)oc2c1-c1nn(CC(=O)NCc3ccco3)cc1CC2 +CCc1ccc(CNC(=O)c2ccc(-c3nccnc3N3CCCCC3)cc2)cc1 +COc1ccc([C@H]2C[C@@H](C(F)(F)F)n3nc(C(=O)NC4CCCCC4)cc3N2)cc1OC +CCCc1cc(C(=O)NNC(=O)c2cccc(Br)c2)[nH]n1 +O=[N+]([O-])c1c(Nc2ccc(F)c(F)c2)ncnc1Oc1cccc2cccnc12 +CC(C)(C)n1ncc2c1CCC[C@H]2NC(=O)CSc1nc2ccccc2o1 +Cc1ccc([C@@H](C)NC(=O)N[C@@H](CCO)c2cccs2)cc1 +CCN(CC)S(=O)(=O)c1ccc2nc(-c3ccncc3)cc(C(=O)[O-])c2c1 +CCCN(CC)c1cc[nH+]c(C(=O)[O-])c1 +Cc1ccccc1N1C(=O)C[C@H]([NH+](C2CCCCC2)C2CCCCC2)C1=O +CS(=O)(=O)[C@H]1O[C@H]1c1ccc(Cl)cc1Cl +C[C@H](CSc1ccc(C(=O)N(C)C)cn1)C(=O)[O-] +CCOC(=O)[C@H]1C=C(C#N)O[C@@H](c2ccc(C)cc2)C1 +CCC[NH2+]C1CCC(O)(Cc2nc(C)cs2)CC1 +O=C1c2ccccc2N[C@H](CSC2=NC=NC3=NC=N[C@@H]32)N1c1ccc(Cl)cc1 +CCc1nc2ccccc2c(C(=O)NCc2ccc(OC)c(C(=O)OC)c2)c1C +Cc1nn(-c2nncc(-c3ccc(Cl)cc3)n2)c2c1[C@H](c1ccccc1)CC(=O)N2 +CC(C)NC(=O)Nc1cccc(C(=O)N(C)Cc2nnc(C3CC3)n2C)c1 +Cc1nnccc1C(=O)N[C@H](C)c1ccc(Cl)s1 +O=C(NCCCc1nc(-c2ccc(Br)o2)no1)C1CC1 +CCOc1cccc(NC(=O)CCc2ccc(N)cc2)c1 +COc1ccc(/C=C2/SC(=O)N(CC(=O)Nc3ccc(F)cc3)C2=O)cc1OC +CC(C)N(C)C(=O)[C@H]1CSCN1C(=O)/C=C/SCc1ccco1 +CCC1(CC)[C@@H](NC(=O)Nc2ccc(C(=O)NC)cc2)[C@H](C)[C@@H]1OC +Cc1ccc2ncc(C(=O)Nc3ncccc3OCc3ccncc3)n2c1 +N#Cc1ccnc(N2CCC([NH2+]C[C@@H]3CCCO3)CC2)c1 +O=C(Cn1nnn(-c2cccs2)c1=O)OC1CCCCC1 +COc1cccc(COc2ccc(OC)cc2CCl)c1 +CC[C@H](NC(=O)NCc1c(C)noc1C)c1ccc(OC)cc1 +CCc1c(C(=O)Nc2ccc3c(c2)NC(=O)CS3)[nH]c(C)c1C(C)=O +CSc1ccc(/C=c2\sc3ncnn3c2=O)cc1 +COc1ccc(C(=O)/C=C(\C)Nc2ccc(F)cc2F)cc1 +O=C(C=C1CCSCC1)N[C@@H]1CCC[C@H]1Cc1ccccc1 +C[C@@H](Sc1nc(/C=C/c2cccs2)n[nH]1)C(=O)N1CCOCC1 +Cc1cccc(C(=O)N[C@@H](C(=O)N2CCC[C@@H](C)C2)C(C)C)c1 +COC(=O)[C@@H](C)Sc1nnc(Nc2cccc(Br)c2)s1 +Cc1cnc(C(=O)Nc2ccc(N(C)C3CC[NH+](C)CC3)cc2)cn1 +CC(=O)Nc1nc2ccc(NC(=O)NCc3ccccc3)cc2s1 +C=CCN1C(=O)/C(=C/c2ccccc2F)S/C1=N\S(=O)(=O)c1cccs1 +CC[C@H]1CC(=O)N(Cc2ccccc2C#CCCO)C1 +COc1ccc(F)cc1C(=O)Nc1nccs1 +CC[C@@H](CC(=O)NC1(C(=O)OC)CCSCC1)c1ccccc1 +Cc1ccc(C)c(-n2c(SCCCCCO)nc3ccccc3c2=O)c1 +CC[C@@H](C)[C@@H]([NH3+])c1ccc(Cl)s1 +COc1cccc([C@@H](C)[NH2+]CCOc2ncccc2Cl)c1 +CC[C@@H]1CCCCN1C(=O)NC1CCN(C(=O)OC(C)(C)C)CC1 +CCOc1cccc(NC(=O)NCc2ccc(N3CCSCC3)cc2)c1 +Cc1cccc(NC(=O)CN2CCN(c3ccc4c(c3)OCCO4)C2=O)n1 +C=C(C)C(=O)N[C@H](C)c1nc2ccccc2n1CCC(=O)N1CCCCCC1 +CCOC[C@H]1CC[NH+](Cc2ccc(-c3nc4ccccc4s3)o2)C1 +CCOC(=O)[C@]1(Cc2cccc(Cl)c2)CCCN(C(=O)c2ccnn2C)C1 +Cc1ccc([N+](=O)[O-])cc1NC(=O)C(=O)N1CC[C@H]([NH+]2CCCC2)C1 +CCOCCCNC(=O)N[C@@H]1CCC[C@@H](CC)C1 +O=C(Cc1cccc(F)c1F)Nc1cccc(Br)n1 +COc1ccccc1NC(=O)[C@@H]1CCCN(C(=O)Nc2cccs2)C1 +C[C@H]1CCC[C@](C#N)([C@]2(O)CCCCC2(C)C)C1 +O=C(NCc1ccc([N+]2=CCCC2)cc1)NC1(c2ccc(Cl)cc2)CC1 +CCCC(=O)N[C@@H]1CCC[NH+](Cc2ncccc2C)C1 +O=C(NCc1cccs1)C1(c2cccc(Cl)c2)CCC1 +C[C@H]1CCC[C@H](NC(=O)[C@@H](C)Sc2ncn[nH]2)[C@@H]1C +COc1ccc([C@@H]([NH2+]Cc2ccc(Cl)nc2)c2ccc(F)cc2)cc1 +COc1cc(NC(=O)[C@H](C)Sc2ccccc2Cl)cc(OC)c1 +CCN1CCC(=NNC(=O)c2ccccc2)CC1 +CCCOc1ccc(Br)cc1C[NH+]1CCC([C@@H](C)O)CC1 +Cc1cc2n(C[C@H](O)CO[C@H](c3ccccc3)c3ccccc3C)c(=O)c3ccccc3n2n1 +CCc1ncc(CN(C)C(=O)Nc2c(C)ccc([N+](=O)[O-])c2C)s1 +c1ccc2nc(NCCCc3nc4ccccc4[nH]3)cnc2c1 +Cc1c([C@H](C)[NH2+]Cc2cccn2C)cnn1C +CC(=O)N[C@@H](C(=O)NC1COC1)C(C)C +O=C(Nc1ccccc1F)c1cc2ccccc2c2cccnc12 +Cc1ccccc1N1C(=O)/C(=C/c2cccn2-c2cccc([N+](=O)[O-])c2)C([O-])=NC1=S 
+COCCN1C[C@H](C(=O)N(Cc2cccc(Cl)c2)C(C)C)CC1=O +COc1ccc(NC(=O)N2CCN(C(=O)Cc3csc4ccccc34)CC2)cc1OC +C#CCN(C[C@H]1CCCO1)C(=O)N[C@@H](C)c1cccc([N+](=O)[O-])c1 +Cc1cccc(C2=CCN(C(=O)Nc3ccc(C(N)=O)c(Cl)c3)CC2)c1 +COCCCN1C(=O)c2ccc(C(=O)Nc3nc(-c4ccc(C)cc4)cs3)cc2C1=O +O=S(=O)(Nc1ccc(N2CCCS2(=O)=O)cc1)c1ccc(F)c(Cl)c1 +O=C1/C(=C/c2ccccc2)Oc2c1ccc1c2CN(Cc2cccs2)CO1 +CC[C@@H](C)[C@@H](NC(=O)c1cccc(F)c1)C(=O)N=c1[nH]c2ccccc2[nH]1 +Cc1c(F)cc(N)cc1S(=O)(=O)N[C@@H](C)C1CC1 +Cc1ccc(Cn2ncc3c(N)ncnc32)cc1 +CCOC(=O)C(C)(C)c1nc(-c2ccccc2)no1 +CCOC(=O)c1sc(/C=C/c2nc3c(s2)CCC3)nc1C +C[C@@H]1CC[C@@H]([NH2+]C2CCC(NS(C)(=O)=O)CC2)c2ccccc21 +CN(C)S(=O)(=O)c1ccc(C(=O)N(C(=O)N2CCCCC2)c2ccccc2)cc1 +CC[NH2+][C@@H](CC)c1ccccc1OCc1cccc(F)c1 +C=CCOC(=O)C1=C(C)N=C2S[C@H](C)C(=O)N2[C@H]1c1ccc(F)cc1 +CCC[NH2+][C@]1(C(=O)OCC)CC[C@H](n2cc(Cl)c(C)n2)C1 +Cc1cc(NC(=O)CSc2nnc3c4ccccc4n(C)c3n2)ccc1Br +CCOC(=O)c1sc(NC(=O)c2ccc(-n3c(C)nc4ccccc4c3=O)cc2)cc1C +CC#CCC(=O)C1([NH+](CC)CC)CCCC1 +O=C(NCCCc1ccccc1)C1CCN(C(=O)[C@@H]2CC(=O)N(c3ccccc3)C2)CC1 +O=Cc1ccc(OCc2ccn(-c3cccc(F)c3)n2)cc1 +Clc1ccccc1Cn1ccnc1 +Cc1[nH]c(=O)c(C(=O)N2CCN(c3ccccc3)C(=O)C2)c(C)c1C +C[C@@H](NC(=O)NC[C@H]1CCCN(c2ncccn2)C1)C1CCOCC1 +O=C(NCCS(=O)(=O)c1ccccc1)N1CCC[C@@H]2CCC[C@@H]21 +Cc1ccc(-c2cc(NC(=O)C(C)C)c(=O)n(CC(=O)Nc3cccc(C)c3)n2)cc1 +Cc1ccc(S(=O)(=O)N2CCN(C(=O)[C@H]3CCCC[C@@H]3C(=O)[O-])CC2)cc1C +CCc1ccc(-c2nc(C(=O)N3CCO[C@H](CC)C3)cs2)cc1 +Cc1cc(C)cc(OCC(C)(C)C[NH2+]C2CC2)c1 +CCOC(=O)[C@H](F)[C@@]1(O)CCC[NH+](C(C)C)CC1 +CCn1cc(/C=C/C(=O)c2ccc3ccccc3c2)cn1 +CCC(CC)C(=O)Nc1cnn(-c2ccccc2F)c1 +O=C1O[C@H](C(=O)Nc2ccnc3ccnn23)Cc2ccccc21 +COCCn1nnnc1[C@@H](C(C)C)N1CCSCC1 +O=C(NCCNS(=O)(=O)c1cccc(Cl)c1F)c1cccnc1 +Cc1ccc(CNC(=O)NCc2nnc3n2CCC3)s1 +C/C=C/C[C@]1(C(=O)[O-])CCN(C(=O)OC(C)(C)C)C1 +O=C(N[C@@H](CO)c1ccco1)c1cc(Cl)ccc1OC1CCCC1 +Cc1cc(N(C)C)ccc1NC(=O)c1c[nH]c2nccc(Cl)c12 +CCOC(=O)c1cccc(NC(=O)c2cn[nH]c2C)c1 +CCC(CC)[S@](=O)CCC(=O)[O-] +COCc1ccc(C(=O)N(C)Cc2ccc(Cl)s2)cc1 +O=C(CCc1nc2ccccc2c(=O)[nH]1)Nc1cc(Cl)c(Cl)cc1Cl +CC[C@@H](C)NC(=O)c1ccc2c(c1)CCCN2S(C)(=O)=O +COc1ccc(C)cc1-n1nnnc1SCC(=O)Nc1cc(C)cc(C)c1 +C[C@H](NC(=O)Cc1ccc[nH]1)C(=O)N1CCCC[C@H]1C +CCOC(=O)[C@H]1CCCN(C(=O)c2cc(C(C)C)n(C)n2)C1 +C[C@H]([NH2+]Cc1nc(Cc2ccccc2)no1)[C@@H](C)n1cccn1 +c1cc(CN2CC[NH+](Cc3ccc4c(c3)OCCO4)CC2)no1 +CCCCNC(=O)CCc1c(C)nc2c3ccccc3nn2c1C +O=C(C[C@@H]1C[NH2+]CCO1)N[C@H]1C=CS(=O)(=O)C1 +C[C@@H]1CCO[C@@H]1C(=O)N1CC[C@H](C(N)=O)c2ccccc21 +CCC[NH2+][C@H](Cc1nn(C)c2ccccc12)c1ncc[nH]1 +CC(C)c1noc(-c2cc[nH+]c(N3CCN(C(=O)[C@H]4C[C@H]4C)CC3)c2)n1 +O=C(NCc1cccnc1)NCc1ccnc(OCC(F)F)c1 +C[C@@H](NC(=O)N1CCCCCCC1)[C@@H]1CCCO1 +COc1ccc([C@H]2C(C(=O)NCc3ccccc3)=C(C)Nc3ncnn32)cc1OC +Cc1cc(=O)n2c(n1)SC[C@@H]2CC(=O)NCCC(C)C +CC(C)C1=CN=N[C@H]1[C@H]1CCC[NH+](C[C@@H](C)Cc2ccc3c(c2)OCO3)C1 +COc1cc(CNC(=O)c2occc2Br)ccn1 +Cc1cccc(COc2ccc(Br)cc2/C=C2\SC(=S)NC2=O)c1 +O=C(N[C@@H]1CCO[C@@H]1c1ccc(Cl)c(F)c1)[C@H]1Cc2ccccc2O1 +Cc1ccc([C@]2([NH3+])CC[C@@H]2C)cc1 +O=C(CNC(=O)c1ccco1)OCC(=O)c1ccc2ccccc2c1 +O=C(c1cc2c(F)cccc2[nH]1)N(C[C@@H]1CCCO1)c1ccncc1 +Cc1cc(C)c(NC(=O)c2cc3ccccc3n2C)c(C)c1 +CSc1ccccc1NC(=O)N[C@@H](CO)c1ccc(Cl)cc1 +Cc1cc(Br)ccc1CNC(=O)C1CC=CC1 +COc1ccc(CN/C(C)=C2/C(=O)N(c3ccc(OC)cc3)N=C2C)cc1 +O=C(COC(=O)C1(c2ccccc2F)CCCC1)N1CCOCC1 +CC1(C)C(=O)NCC[NH+]1Cc1ccc(OCC(F)F)cc1 +Cc1ccc([N+](=O)[O-])cc1C(=O)Nc1ccc(C(=O)NC(C)C)cc1 +C[C@@H]([C@@H](O)c1ccc2ncnn2c1)[N+](=O)[O-] +C[C@H](Oc1ccc(Cl)c(Cl)c1)C(=O)NC[C@H]1CCC[C@@H]1O +COC[C@H]1CCC[NH+](Cc2cc(C)n(Cc3ccco3)c2C)C1 +Cc1ccc([C@H](C)NC(=O)CN(C)C(=O)OC(C)(C)C)cc1F +CC(=O)NCCC(=O)N1CCC[C@@H](C)C1 +COc1ccc(N2CCn3c2nn(CC(N)=O)c(=O)c3=O)cc1 
+ClC(Cl)(Cl)c1nonc1C(Cl)(Cl)Cl +CCc1sc(C(=O)N2CCN([C@@H](C(N)=O)c3ccccc3)CC2)cc1C +C[NH+](C)Cc1cc(NC(=O)CCCC(=O)N2CCCCCC2)[nH]n1 +C[C@@H](Nc1ccc(COC(C)(C)C)cc1)c1ccc(C#N)cc1 +NC(=O)c1ccc(NC(=O)c2cccn(Cc3ccc(F)cc3)c2=O)cc1 +Cc1cc(C(=O)N2CC[C@H](C)C[C@H]2C)c2c(C)nn(C)c2n1 +Cc1c(C(=O)N2CCCC2)oc2c1-c1nn(CC(=O)N3C[C@@H](C)C[C@@H](C)C3)cc1CC2 +Cc1cc(C)cc(NC(=O)CC(C)C)c1 +COc1cc(OC)cc([C@H]2CC[NH+](CCC(F)(F)F)C2)c1 +CCn1ccnc(N2CCCC[C@@H](N3CC[NH+](C)CC3)C2)c1=O +C[C@@H]1CCC[C@](O)(c2ccc(Cl)s2)CC1 +CCOC(=O)C[C@H](C)CNC(=O)C(=O)N1CCc2ccc(C)cc21 +COc1ccccc1C[NH+]1CCC[C@H](N2CCCC2=O)C1 +CCNc1ccc2c(OC)ccc(F)c2n1 +Oc1ccc(C2[NH+](Cc3ccccc3)CC[NH+]2Cc2ccccc2)c(O)c1 +Fc1ccc([C@@H]2C[C@@H](c3ccc(Br)cc3)Nc3ncnn32)cc1Br +COc1ccc(O)c(CNC2CC[NH+](Cc3ccccc3Cl)CC2)c1 +Cc1ccc(NCc2cccc(C(=O)NCc3ccco3)c2)c(F)c1 +Cc1c(C(=O)N(C)[C@@H]2CCN(c3ccccc3Cl)C2=O)cnn1C +Cc1nc(C(C)(C)C)[nH]c(=O)c1C(=O)Nc1cccc(Cl)c1C +CCOC(=O)[C@H]1CCCN(C(=O)c2cn(CCc3ccccc3)nn2)C1 +C[C@@H]1[C@H](C(=O)[O-])CCN1S(=O)(=O)c1ccc(F)c(Cl)c1 +CC1CCC(OC(=O)C2=NC3=C(C(=O)C[C@@H](c4ccccc4)C3)[C@H]2C)CC1 +CN(Cc1ccno1)Cc1c(C(=O)N2CC[NH+](C3CCCCC3)CC2)nc2ccccn12 +Cc1cc(F)c([C@@H]([NH3+])[C@H]2Cc3ccccc3O2)cc1F +COc1ccc(OCC(=O)N/N=C2\CCCc3ccccc32)cc1 +CC(C)[NH+]1CCC(N2CC[NH+](Cc3c(F)ccc(F)c3F)C[C@@H]2CCO)CC1 +CC[NH2+][C@@H](C)c1cc(F)c(C)cc1N1C[C@H](C)S[C@H](C)C1 +CC(C)c1nsc(NC[C@H](C2CC2)[NH+](C)C)n1 +Cc1ccc(-c2nc(C[NH+]3CCCC[C@H]3c3cccnc3)c(C)o2)s1 +COc1cccc(C2=C[C@H](C(=O)N3CCCCC3)N=N2)c1 +O=C(Nc1ccc2[nH]c(=O)[nH]c2c1)c1cc(S(=O)(=O)NC2CC2)ccc1Br +c1ccc2c(c1)CC[C@H]([C@H]1CCCc3cccnc31)N2 +Cc1ccc(NC(=S)NC(C)C)cc1C +C[C@@H](Nc1ccc2c(c1)CCC2)C(=O)N1CCCC1 +CCOc1ccc(Nc2ccc(C#N)c([N+](=O)[O-])c2)cc1 +CC(C)OC(=O)CCNC(=O)c1cnn(-c2ccc(F)cc2)n1 +CCCCOc1ccccc1NC(=O)c1scnc1C1CC1 +O=C1N=C(N2CCCCC2)S/C1=C1/C(=O)Nc2ccccc21 +Cc1c([C@H](C)NC(=O)c2[nH]c3ccccc3c2Cl)cnn1C +CSc1cccc2sc(N3CCN(C(=O)c4ccn(C(C)C)n4)CC3)nc12 +Cc1ccc(C)c2nc3sc(C(=O)Nc4ccc5c(c4)OCO5)c(N)c3cc12 +COC(=O)CCN(Cc1cnc2ncccn12)C1CCOCC1 +O=C(c1c[nH]c2ccc(F)cc12)N(CC1CC1)CC(F)(F)F +O=C(Cc1ccc(Cl)cc1)/N=C1/S[C@@H]2CS(=O)(=O)C[C@H]2N1c1cc(Cl)ccc1Cl +Cc1cccc(C)c1NC(=O)C[NH+]1CCC(OCc2ccc(F)cc2)CC1 +C[C@H]1CCC[C@@H](C)N1C(=O)[C@@H]1COCCO1 +CC(C)[C@@H](C)CC(=O)NNC(=S)NC1CCCCC1 +Cc1cccc(NC(=O)CN2C(=O)/C(=C3\SC(=S)N(Cc4ccco4)C3=O)c3ccccc32)c1 +CC(=O)N[C@@H]1C(=O)C[C@@H]2[C@H]3CCC4=CC(=O)CC[C@@]4(C)[C@@H]3CC[C@]12C +CCCNC(=O)[C@H]1CS[C@H](c2ccccc2O)N1C(C)=O +Cc1ccc(N2CC[C@H](C(=O)NC[C@@H](CC(C)C)N3CCOCC3)C2=O)cc1 +COc1ccc(NC(=O)CCc2ccc3c(c2)OCCO3)cc1OC +Cc1ccc(C(F)(F)F)cc1/C=C/C(=O)[O-] +Cc1ccccc1C(=O)Nc1ccc(N2CC[NH+](Cc3ccccc3)CC2)cc1 +c1cc(C[NH2+]Cc2ccco2)cc(OC2CCCC2)c1 +O=C(c1cccs1)N(Cc1ccc(F)cc1)Cc1cc(-c2ccccc2)cn2nnnc12 +COc1ccccc1CC(=O)N[C@@H]1CS(=O)(=O)C[C@H]1Cl +c1cc(-c2nc3c4cn[nH]c4ncn3n2)ccc1COc1ccc2c(c1)CCC2 +CN(C(=O)c1ccc(Cl)cc1O)C1CCC(=O)CC1 +CNC(=O)CNS(=O)(=O)c1cccc(C(F)(F)F)c1 +Cc1ccc(S(=O)(=O)N2C[C@@H](CC(=O)[O-])c3ccccc32)c(C)c1 +Cc1cc(N)nc(SCC(=O)NC[C@@H](c2ccccc2)C(C)C)n1 +Cn1cc(C(N)=O)c(NC(=O)c2ccc3sccc3c2)n1 +COc1cc2c(cc1O)[C@H](c1cnc(-c3cccc(C)c3)nc1)CC(=O)N2 +CC(=O)Nc1ccc(O[C@H](C)c2nc(C(C)(C)C)no2)cc1 +O=C(c1c(-c2ccccc2)nc2sc3c(n12)CCCC3)C(F)(F)F +CC(C)CNC(=O)NC(=O)[C@@H](C)Nc1ccc(OC(C)C)cc1 +C=C[C@](C)(O)CC[C@H]1C(C)=CC(=O)[C@H]2C(C)(C)CCC[C@@]21C +CC[C@H](C)Sc1nncn1-c1ccccc1C +COCC(=O)N1CCCc2ccc(NC(=O)c3cccc(Br)c3)cc21 +C[C@H](Oc1cccc(Cl)c1)C(=O)Nc1ccc2ccccc2c1 +Cc1cc(Br)ccc1SCC(=O)N1CCC(C(=O)c2ccc3c(c2)OCCO3)CC1 +Nc1ccc2c(c1)CN(C(=O)c1ccc(Cl)cn1)CC2 +C[C@H](CN(C)C(=O)c1ccc(F)c(F)c1F)C(=O)[O-] +NC(=O)c1ccc(SCC(=O)Nc2ccc3c(c2)Cc2ccccc2-3)c([N+](=O)[O-])c1 
+CN(C(=O)CSCC(F)(F)F)c1cccc([N+](=O)[O-])c1 +C=CCn1c(SCC(=O)Nc2cc(C)on2)nnc1[C@H]1COc2ccccc2O1 +CC(=O)N[C@@H](CC(=O)Nc1ccnn1Cc1ccc(C)o1)c1ccccc1 +CNC(=O)Cc1nc(C[NH+](C)C2CCC(c3ccccc3)CC2)cs1 +COCCOc1c(Cl)cccc1NC(=O)Cc1c[nH]c2ccccc12 +COCCN1[C@@H](C)CN(C(=O)C[NH+](C)C2CC2)C[C@H]1C +CSc1nncn1/N=C\c1cc(Cl)ccc1F +CC(C)(C)OC(=O)N1CC[C@H]2CC(=O)[C@H]2C1 +C[C@@H](CCO)SCc1ccccc1OC(F)F +C[C@@H]1C[C@@H]1C(=O)Nc1ccc(F)cc1C(=O)NC1CCC(O)CC1 +COc1ccccc1CNC(=O)COc1ncnc2oc(C)c(C)c12 +Fc1ccc(F)c(C[NH+]2CCC(n3cc(-c4cccnc4)nn3)CC2)c1F +O=C(C[NH+]1CCC(C(=O)c2ccc(Cl)cc2)CC1)NC[C@H]1COCCO1 +CCC[NH+](C)C[C@H]1CCN(C(=O)Nc2cc(NC(C)=O)ccc2C)C1 +Cc1nc2ncnn2c(NCCOC2CCCCCC2)c1C +C[C@H]1CN(Cc2cnn(-c3ccccc3)n2)C[C@H](C)S1 +COc1cc(OC)cc([C@@H](N[C@@H](C)c2ccc(F)cn2)c2[nH+]ccn2C)c1 +CC(C)CCc1noc(C[NH+](C)[C@H]2CCC[C@@H]2S(C)(=O)=O)n1 +CCc1nc2n(n1)CCC[C@H]2NC(=O)c1ccc(-n2cc(C)cn2)cc1 +C[C@@H](NC(=O)c1ccccc1CSc1nc2ccccc2[nH]1)C1CC1 +O=C([C@H]1CCCN1S(=O)(=O)N1CCCCC1)N1CCSCC1 +Cn1c(=O)c(=O)n(CC(=O)N2CCC3(CC2)OCCO3)c2cccnc21 +COc1ccc(C(=O)N(CC2=CC=C[C@@H]3N=CC=C23)C2CC2)cc1 +CCc1ccc(-c2nc(N)ccc2[N+](=O)[O-])cc1 +c1csc([C@@H]2CN(Cc3cnc(C4CCC4)s3)CCO2)c1 +O=S(=O)(/N=C(\[O-])c1ccsc1)N1CCCC1 +Cc1nccn1CC(=O)N1CCCC[C@@H]1CCNC(=O)c1ccccc1 +C[C@H](Sc1nnc(-c2ccc(Cl)cc2)n1C[C@H]1CCCO1)C(=O)Nc1ccc2c(c1)OCO2 +COc1ccc(-n2nnc(-c3nc(-c4ccc5c(c4)OCO5)cs3)c2C)cc1OC +CN(c1ccccc1)S(=O)(=O)c1ccc2c(c1)C(C)(C)C(=O)N2 +Cc1cc(S(=O)(=O)N2CCN(C(=O)[C@H]3C[C@H]3c3ccc(Cl)cc3)CC2)c(C)s1 +O=C(Nc1ccc(Oc2ccc(Cl)nn2)cc1)[C@@H](O)c1ccccc1 +CCCc1cc(=O)n2c(n1)SC[C@@H]2CC(=O)Nc1cccc(Cl)c1Cl +Cc1ccc(NC(=O)CSc2nnc([C@@H]3CCCN3C(=O)c3cccc(C)c3)n2C)cc1 +Cc1noc(C)c1CCCNC(=O)N[C@H]1CC(=O)N(C2CC2)C1 +CC(C)(C)[C@@H]1CCC(=O)[C@@H](CN2CCOCC2)C1 +COc1cc(/C=C2\SC([N-]c3cccc(C(=O)[O-])c3)=NC2=O)cc(OC)c1O +C=CCO[C@H](C)C(=O)Nc1ccc(F)cc1Br +O=C(CN1CCN(Cc2ccc(F)c(F)c2)CC1)c1cccs1 +CC[NH+]1C[C@H](c2ccccc2)CC2(CCN(C(=O)c3ccon3)CC2)C1 +C[C@H]1C[C@H]([NH+]2CC[C@H](S(=O)(=O)NC3CC3)C2)CC(C)(C)C1 +CC[C@H](Sc1cc(C)c2cccc(C)c2n1)C(=O)Nc1nc2ccc(S(N)(=O)=O)cc2s1 +COC(=O)[C@]1(NC2CC2)CC[C@H](Sc2ncc(C)cn2)C1 +COc1ccccc1Nc1nn(CN(C)OC)c(=S)s1 +Cc1cc(C)c(C)c(S(=O)(=O)/N=C(\[O-])c2cc(C3CC3)n(C(C)(C)C)n2)c1C +CC(C)(O)C#Cc1ccc(C[NH2+][C@H]2CCCN(c3nc4ccccc4s3)C2)s1 +COc1ccc(OC)c(S(=O)(=O)n2cc3c(=O)n(C)c(=O)n(C)c3n2)c1 +C/[NH+]=C(/NCc1noc(C(C)(C)C)n1)N[C@@H](C)c1ccc(F)cc1F +Cc1nnsc1C(=O)Nc1nnc(-c2ccc(Br)cc2)o1 +CCN(Cc1ccc(OC)c(OC)c1)C(=O)C[NH+]1CC[C@@H](C)[C@H](O)C1 +O=C([O-])[C@H]1CCCCN1C(=O)CSCc1ccccc1 +Cc1cccc([C@H](O)C[C@@H]2CCCCC[NH2+]2)c1 +O=C(C1CCCCC1)N1CCN(Cn2cc(Br)cn2)CC1 +Cc1ccccc1COc1ccc([C@@H]2C3=C(CCCC3=O)Nc3nnnn32)cc1 +CCO[C@@H]1C[C@@H]([NH+](C)C[C@@H]2CCCN(S(C)(=O)=O)C2)C12CCCCC2 +CCn1c(=O)c(=O)[nH]c2cc(C(=O)NN3C(=O)N[C@](C)(c4ccccc4)C3=O)ccc21 +COc1cc(OC)cc([C@@H](NC(=O)N(C)C2CCCCC2)c2nccn2C)c1 +O=C1[C@H]2[C@@H]3C=C[C@@H](C3)[C@H]2C(=O)N1CN(C(=O)C(F)(F)F)c1cccc(C(F)(F)F)c1 +CCNc1ncc(COc2cccc3ccccc23)s1 +Cc1noc(C)c1CCCNC(=O)c1c(C)nn(Cc2ccccc2)c1C +C[C@H]1CC([NH2+][C@@H](C)c2c[nH]c3cc(F)ccc23)C[C@H](C)O1 +C/C(=C1/SC(=O)N(c2ccc(Cl)cc2)C1=O)c1ccc(Br)cc1 +Cc1ccc(NC(=O)C[C@H]2SC([N-]c3ccc(N(C)C)cc3)=NC2=O)c(C)c1 +CO[C@H]1CCCC[C@H]1NC(=O)NC[C@H](c1cccc(F)c1)[NH+](C)C +Cc1nn(C)c(C)c1CN[C@H]1CCC[NH2+]C1 +Cc1cc(=O)[nH]c(SCC(=O)N2C[C@]3(C)C[C@H]2CC(C)(C)C3)n1 +COC1CC[NH+](Cn2nc(-c3ccc(C)cc3)n(C)c2=S)CC1 +C[NH2+][C@]1(C(=O)[O-])CCC[C@@H](OCC2CCCCC2)C1 +Cc1nnc(SCC(=O)c2cc(C)n(CC(F)(F)F)c2C)s1 +CC(C)CN(CCC#N)C(=O)NC[C@@H]1CC[C@H](C(=O)[O-])O1 +Cc1cc(CN2CCN3C(=O)NC[C@@H]3C2)cc(C)c1OC(F)F +CCNC(=O)c1cccc(NC(=O)NCCCSC)c1 +Cn1c(-c2cccc3ccccc23)nn(CN2CCOCC2)c1=S 
+CCCCCN1C(=O)/C(=C/c2ccc(O)c(OCC)c2)SC1=S +N#CCCN(Cc1ccccc1)C(=S)NC(=O)c1cccc(Cl)c1 +CCCS(=O)(=O)c1ccccc1C(=O)Nc1nnc(CC)s1 +CNC(=O)[C@H](C)CN(C)Cc1cc(=O)n2cccc(C)c2[nH+]1 +COCCNC(=O)/C(C#N)=C/c1cccc(O)c1 +CNC(=O)CN1c2ccccc2C(=O)N(C)[C@H]1c1ccccc1O +CC(=O)N[C@H](C)C(=O)Nc1ccc(Sc2nncs2)c(Cl)c1 +CC[n+]1c(N)n(CCOc2ccc(Cl)cc2Cl)c2ccccc21 +COC(=O)CCCc1nnc(NC(=O)N2CCC[C@@H]3CCC[C@@H]32)s1 +O=C(N[C@H]1CCCC[C@H]1OC1CCCC1)c1ccc([N+](=O)[O-])cc1 +O=C(CCOc1ccccc1)NNC(=O)CC1(O)CCCC1 +C=CCN(CC(=O)[O-])C(=O)[C@@H](C[NH3+])C(C)C +O=C(CCn1ccccc1=O)NCC1(c2ccccc2)CC1 +COc1cc(C)c([C@@H](C)NC(=O)CSC2CCCC2)cc1OC +C=CCNC(=O)Nc1ccc(F)c(NC(=O)OC)c1 +Oc1ccccc1/C=[NH+]/CCC/[NH+]=C/c1ccccc1O +Cn1cccc1Cc1nnc(SCC(=O)Nc2ccc3c(c2)OCCO3)n1C +C[C@@H](C#N)CNC(=O)c1cccc(Oc2cccc(C(F)(F)F)c2)c1 +Cc1oc(-c2ccccc2)nc1CCNC(=O)c1ccc([S@@](C)=O)cc1 +CCc1noc(C)c1C[NH+](C[C@@H]1CCCCO1)C(C)C +Cc1ccccc1Oc1cc(Br)ccc1C[NH3+] +CCc1cccc(CC)c1NC(=O)NC1CC1 +CC[NH+]1CCN(C2(CNC(=O)c3ccccc3Br)CCCCC2)CC1 +CCCCn1nc(C)c(C[NH2+]C[C@@H](C)O)c1Cl +C/C=C/C=C/C(=O)N1C[C@@H](C(=O)OC)[C@@H](C)C1 +COc1cccc(OC)c1OC1CC[NH+](Cc2ccc([C@H]3C[C@@H]3C)o2)CC1 +COC(=O)c1cc(CSc2nnc(-c3cccnc3)n2-c2ccccc2F)oc1C +CCOc1ccc2cc(C(=O)NCc3ccccc3)c(=[NH2+])oc2c1 +O[C@H]1CCN(c2ccnc(N3CCc4[nH]c5ccc(Cl)cc5c4C3)n2)C1 +O=S(=O)(NC[C@H](O)c1ccc(C(F)(F)F)cc1)c1cc(F)ccc1F +O=C(CSc1ccc2c(c1)OCCCO2)NC(=O)c1cccs1 +C[NH+](C)CCSc1ccc(NC(=O)C2CC2)nn1 +Cc1ccc(C)c(NC(=S)NCCc2cccs2)c1 +CNc1ncc(F)c(-c2cccc(Cl)c2)n1 +Cc1ccc(-c2nnc(SCC(=O)Nc3ccc(CC#N)cc3)n2N)cc1 +CCN(CC)C(=O)c1ccccc1OC(C)=O +Cc1ccc(-c2nc(-c3ccc(OCC(F)(F)F)nc3)no2)cc1 +COc1ccc([C@@H](CNc2nc3ccccc3o2)N2CCOCC2)cc1 +CNC(=O)[C@H]1CCCC[C@H]1[NH2+][C@H](C)c1cc(C)cc(C)c1 +CCN[C@H](c1cccnc1)C1([NH+]2CCCCC2)CCCC1 +Cn1ncc2c(NCc3ccco3)nc(CCc3ccccc3)nc21 +Cn1cc[nH+]c1C[C@H]1CCC[NH+](Cc2ncc(-c3ccccc3Cl)o2)C1 +Cc1cc(N2CC[C@H](C)[C@H](O)C2)nc(C)[nH+]1 +Cc1nnc(CCC[NH+]2CCC(CC[NH+]3CCCC[C@@H]3C)CC2)o1 +CCc1nn(C)cc1CNC(=O)C1(CC)CCC1 +COc1ccc(N2/C(=N/C(=O)CCCC(=O)[O-])S[C@@H]3CS(=O)(=O)C[C@H]32)cc1Cl +CC(=O)C1=C([O-])C(=O)N(CCC2=c3ccccc3=[NH+][C@H]2C)[C@H]1c1ccccc1F +COc1cc(C)c(C(=O)N[C@H]2C[C@H](C)N(c3ccccc3)C2)cc1OC +Cc1ccc([C@H]2C[C@@H]2NC(=O)N2CCC(C(N)=O)CC2)cc1C +Cc1nc2n(n1)CCC[C@@H]2N[C@@H]1CCc2c(Cl)cc(Cl)cc21 +CC(C)Oc1ccc(-c2nc(C(=O)O[C@@H](C)[C@@H]3CCOC3)cs2)cc1 +CN(C[C@@H]1CCCN(C(=O)NCCc2ccc(F)cc2)C1)C(=O)OC(C)(C)C +COc1cc(Cl)c(C)cc1NC(=O)[C@H](C)N1CCN(S(=O)(=O)c2c(C)noc2C)CC1 +Cc1nc(-c2ccc(Cl)s2)sc1C(=O)N[C@H]1C[C@H]1C +COc1cc(F)c([C@H]([NH3+])c2ccc(SC)cc2)cc1OC +Cc1ccc(-n2ccnc2SCC(=O)N(CC(N)=O)C(C)C)cc1C +Cc1cccc(/C=C2\SC(=S)N(c3c(C)cccc3C)C2=O)c1 +CC1=C(C(=O)OC(C)C)[C@H](C)N=C1C(=O)Nc1ncc(C)s1 +COC(=O)c1cccc(C(=O)N2C[C@@H](c3ccc(F)cc3)C[C@H]2C)c1 +O=C(NCCCn1ncccc1=O)[C@@H]1CC(=O)N(c2ccccc2)C1 +CCC(=O)N1CCCN(C(=O)N[C@@H]2CCc3ccccc32)CC1 +Cc1cccc([C@@H](C)[NH2+]CCS(=O)(=O)C(C)(C)C)c1C +Cc1ccc(-c2cnc(CCC(=O)NCC(C)(C)c3ccncc3)o2)cc1 +Cc1cc(NN)c2cccc(OC(F)(F)F)c2[nH+]1 +C[C@H]1CCC[C@@H](C(=O)Nc2cccc(OCCc3ccccc3)c2)[NH2+]1 +Cn1ncnc1CCNC(=O)[C@H]1C[C@@H]1c1cc(Cl)cc(Cl)c1 +CC[C@H](Sc1nnc2cc(C)c3cc(C)cc(C)c3n12)C(=O)Nc1nnc(COC)s1 +CCOc1ccc(S(=O)(=O)N2CCC(c3nnc(C4CC4)o3)CC2)cc1 +Cc1cc2nc(C)c(CCC(=O)NC[C@@H](c3ccccc3)N3CCOCC3)c(C)n2n1 +COc1ccc(N2C(=O)CS[C@H]2c2ccc(Cl)cc2)cc1Cl +C/C=C(\C)[C@@H]1C=C[C@@H]2C[C@H](C)C[C@H](C)[C@@H]2[C@@H]1C(=O)C1=C([O-])[C@H](C[C@](C)(O)C(=O)[O-])NC1=O +CS(=O)(=O)c1ccc(NC(=O)N2CCC[C@H]2CC2CCCCC2)cc1 +COc1ccc(C(=O)OCc2nc3ccccc3s2)cn1 +COc1ccc(S(=O)(=O)Nc2ccccc2-n2nc(C)cc2C)cc1NC(C)=O +Cn1c(=O)n(CC(=O)N[C@@H]2CCCc3ccc(F)cc32)c2ccccc21 +N#C/C(C(=O)NC1CCCC1)=C(/[O-])Cc1cnn(-c2ccccc2)c1 
+NC(=O)[C@H](Nc1cccc(Oc2ccccc2)c1)c1ccc(F)cc1 +CCN(C[C@@H]1CCOC1)C(=O)Nc1cc2c(cc1Cl)OCCO2 +CCCCc1nnc(NC(=O)C2CCN(S(=O)(=O)c3ccc(C)cc3)CC2)s1 +CC[C@@H](NC(=O)c1ccc(Br)o1)C(=O)N1CCOCC1 +CC[C@]1(c2ccccc2)NC(=O)N(CCOc2ccc(Cl)cc2Cl)C1=O +COc1ccc(OC)c(NC(=O)c2ccc3c(c2)C(=O)N(c2cc(C)on2)C3=O)c1 +Cc1ccc(OCC(=O)NC(=S)NC[C@H]2CCCO2)cc1 +CN(C(=O)CCOc1cccc(C(N)=O)c1)[C@H]1CCC[NH+](C)C1 +COC(=O)c1ccc(NC(=O)c2c(C)sc3ncnc(N4CCC[C@H](C)C4)c23)cc1 +Cn1ncc2c1CC/C(=C\c1ccc(-n3cncn3)c(F)c1)C2=O +O=C(Cc1ccc([N+](=O)[O-])cc1)NCC1(O)CCOCC1 +C=CCOc1ccc(CNc2ccc(OC)cc2)cc1 +C[C@H]1N=C(CCNC(=O)CCC2=c3ccccc3=[NH+]C2)CS1 +COC(=O)c1cc(NC(=O)[C@@H]2CCO[C@H]2C)ccc1C +COc1ccccc1[C@@H](C)NC(=O)c1cnc2c(C)cccn2c1=O +CC[C@@H](C)NS(=O)(=O)c1cc(N2C(=O)[C@@H](C)CS2(=O)=O)ccc1OC +CC(C)(C)NS(=O)(=O)c1ccc(OCC(=O)N2CCOCC2)cc1 +COCCCNC(=O)[C@H]1CN(C(=O)c2cccs2)CC12CCCCC2 +Cc1nc(-n2cccc2)sc1C(=O)Nc1cccc(-c2cn3ccsc3n2)c1 +CC[S@](=O)[C@H]1CCCC[C@@H]1NC(=O)NC[C@H](O)c1ccco1 +Cc1cc2ncn(C[C@H]3CC3(Cl)Cl)c2cc1C +Cc1ccccc1OCC(=O)O[C@@H](C)c1nccs1 +C[C@H](C(=O)N1CCOCC1)[NH+]1CCN(Cc2ccc3c(c2)OCO3)CC1 +CCN(C(=O)[C@@H]1Cc2ccccc2S1)[C@H]1CCC[C@@H]1C[NH3+] +Cc1cc(Br)ccc1NC(=O)[C@@H](C)[NH+](C)Cc1cccs1 +COC(=O)[C@H]1CCC[C@H]1NC(=O)Nc1ccc(C)cc1C +CC(C)NS(=O)(=O)c1ccc(C(=O)N[C@H](C)c2ccccc2Br)cc1 +CC(C)OCCN1CCN(C(=O)Nc2ccccc2C(F)(F)F)CC1 +CCn1cc(-c2nc(-c3cccc(Cl)c3)no2)c(=O)c2ccccc21 +COC[C@@H](C)NC(=O)C(=O)Nc1cc(-c2ccccc2)nn1C(C)C +C#CCOc1ccccc1CN1CCN(C2=[NH+]C[C@@H](C)S2)CC1 +COCCNC(=O)[C@H]1CCCN1S(=O)(=O)c1ccc(Br)cc1 +Cc1ccc(C[C@H](O)c2c(F)cc(Br)cc2F)cc1 +COC(=O)[C@@]1(NC2CCCC2)CCCS[C@H]1C +CCC[NH2+][C@H](Cc1ccccc1)[C@@H]1CN(CC)CCO1 +C=CCn1c(=O)c2c(nc3n2[C@H](C)C(C)=NN3C)n(C)c1=O +C[C@H]1CCCC[C@H]1NC(=O)c1cc(S(=O)(=O)N2CCOCC2)ccc1Cl +COCCn1ccc2ccc(NC(=O)NCCc3ccccn3)cc21 +Cc1nc(CNC(=O)C(=O)Nc2cc(Cl)ccc2Cl)no1 +CS[C@@H]1CC[C@H](NC(=O)CCC(=O)c2ccc(C)s2)C1 +Cc1nc(/C=N/Nc2ccc(Cl)nn2)c[nH]1 +O=C(CSc1nnnn1C1CC1)N[C@H](CO)C(=O)[O-] +O=C(COc1ccc(Br)cc1)N[C@H]1CCS(=O)(=O)C1 +CCc1nnc(NC(=O)c2ccccc2N)s1 +O=C([O-])c1ccccc1-c1ccc(/C=C2\C(=O)N(c3cccc(Br)c3)C(=O)N=C2[O-])o1 +COc1ccc(-c2csc(NC(=O)c3ccc(S(C)(=O)=O)cc3)n2)cc1OC +COc1ccc(Cn2ccc3nc(N4CCN(c5ccccc5)CC4)ncc3c2=O)cc1 +Cc1nn(C)cc1/C=N/NC(=O)c1ccncc1 +N#Cc1ccc(OC2CCC(NC(=O)c3ccc[nH]3)CC2)nc1 +CC(C)c1ccc2oc(-c3ccc(C[NH3+])cc3)nc2c1 +COc1ccc([C@@H]2NC(=O)N[C@@](O)(C(F)(F)F)[C@H]2C(=O)c2ccc(F)cc2)cc1OC +Cc1cc(NC(=O)c2cccc([N+](=O)[O-])c2C)n(-c2ccccc2F)n1 +CCC(=O)N1CCCC[C@@H]1C(=O)NCCc1ccc(F)cc1C +Cc1ccc(S(=O)(=O)N2C(N)=C(C#N)[C@H](c3ccc(Cl)cc3)[C@H]2C(=O)c2ccccc2)cc1 +C[NH+](C)[C@@H]1CC[C@H](NC(=O)[C@@H]2CCCc3[nH]ncc32)C1 +CCOCCS(=O)(=O)[N-]c1cc(Br)ccc1O +NC1=NC(=O)[C@H](CC(=O)N2CC[C@H](c3ccccc3)C2)S1 +Cc1noc(-c2cccnc2N2CC[C@H](NC(=O)COc3ccc(F)cc3)C2)n1 +Cc1ccc(OC(=O)c2cccc(C(=O)Oc3ccc(C)cn3)n2)nc1 +CN(CC1CCCC1)C(=O)C(=O)Nc1cccc(SC(F)F)c1 +COc1cccnc1N(C)C(=O)C[C@H](C)Cc1ccc(Cl)cc1 +O=C(c1cc(=O)[nH]c2ccccc12)N1CCC([C@H](O)c2ccccc2)CC1 +CCO[C@@H]1C[C@@H]([NH3+])[C@@H]1Nc1ncc(Cl)cc1F +CC(=O)Nc1ccc(NC(=O)c2ccc3c(c2)Cc2ccccc2-3)cc1 +COCCOC[C@H]1CC[NH+](C2C[C@@H](C)O[C@H](C)C2)C1 +O=C1OC(c2ccccc2OC(F)F)=N/C1=C\c1cccc(F)c1 +COc1ccccc1CNC(=O)c1cc2sccc2n1Cc1cccc(F)c1 +CCOC(=O)C1(NCc2nnc(-c3cc(C)oc3C)o2)CCCC1 +N#Cc1cccc(NC(=O)N2CCC(NC(=O)CC3CCCC3)CC2)c1 +CCOC(=O)C1=C(c2ccccc2)Nc2ncnn2[C@H]1c1ccc(SC)cc1 +Cc1cc(NC(=O)N[C@@H](Cc2ccccc2)c2ccccc2F)n(C)n1 +C[C@H](Sc1cccc[n+]1[O-])C(=O)NC[C@H]1COc2ccccc2O1 +O=C(NCCOc1ccc2c(c1)OCO2)c1cc(C2CC2)on1 +N#CCC[NH2+]C1(C(=O)[O-])CC1 +CC[C@@H](Oc1ccccc1OC)C(=O)NCc1ccccn1 +Cc1nc(C[C@@H](N)[C@]2([NH+](C)C)CCC[C@H](C)C2)cs1 +CCOc1ccccc1/C=C1\Oc2c(ccc([O-])c2C[NH+]2CCN(C)CC2)C1=O 
+CC(C)Cn1cc[nH+]c1CN[C@@H](c1ccccc1)C(C)C +Cc1c(C[NH+]2CCC[C@H]2c2ccc3c(c2)OCO3)cc(C#N)n1C +O=C(CBr)c1cnc2ccc(Cl)cn12 +COc1ccc(F)cc1NC(=O)c1sccc1S(=O)(=O)N(C)C +COc1ccc(Br)cc1/C=C1/C(=O)NN(c2ccc(C)c(C)c2)C1=O +CC(C)(C)c1ccc(C(=O)N[C@H]2CCN3CCCc4cccc2c43)cc1 +CC[NH+]1CCC[C@@]2(CC1)C[NH+]=C(N)N2c1ccc(C)cc1 +CN(C)c1ccc(/C=C(/C#N)C(=O)c2cccc(C#N)c2)cc1 +CCCC[C@@H](NC(N)=O)C(=O)Nc1cc(OC)ccc1F +COC1=CC2=NC(SCc3cc(-c4ccccc4)on3)=NC2=CC1 +Cc1ccc([C@@](C)(O)CNC(=O)NC[C@@H](c2ccco2)[NH+]2CCCCC2)o1 +C[C@@H](NC(=O)CSc1ccc2c(c1)OCCCO2)c1ccc2ccccc2c1 +CCN1CCN(S(=O)(=O)c2cc(-c3csc(C)n3)ccc2C)CC1 +Cc1nccn1Cc1ccc(NC(=O)c2cccc(-n3cccc3)c2)cc1 +Cc1cccc([S@@](=O)Cc2ccc(N)c(F)c2)c1 +CC(C)C[C@@H](C[NH3+])c1nc(C2CCOCC2)no1 +CC1=NC(SCC(=O)Nc2ccccc2C(F)(F)F)=NC(=O)[C@H]1Cc1ccccc1 +Cc1ccc(F)cc1NC(=O)C(=O)NCCCn1cc[nH+]c1 +COc1cccc(C[NH2+]Cc2cccc(Br)c2OC)c1OC +Nc1cc(=O)[nH]c(SCC(=O)Nc2nc(-c3ccc(Br)cc3)cs2)n1 +CCc1noc(CN(CC)C(=O)C(C)(C)NC(=O)c2cccs2)n1 +CC(C)C[C@](C)(O)CNC(=O)CC[C@H](C)O +C=C[C@@H](C)NC(=O)c1c(C)cc(C)c([N+](=O)[O-])c1C +O=C(Nc1ccc(S(=O)(=O)NC[C@@H]2CCCO2)cc1)[C@H]1CC(=O)N(c2ccc(F)c(Cl)c2)C1 +CC1(C)CCC[C@H]1N1C(=O)c2cccc(N)c2C1=O +CCCNC(=O)NC(=O)CN1C(=O)N[C@](Cc2ccccc2)(c2ccccc2)C1=O +O=C(COc1cc(Cl)c(Cl)cc1Cl)N1CCN(C(=O)Nc2ccccc2)CC1 +CCc1nc2n(n1)C[C@H]([NH2+]CCNS(=O)(=O)c1ccccc1)CC2 +CCCCN(C(=O)c1oc2ccccc2c1C)c1c(N)n(CCC)c(=O)[nH]c1=O +CCCn1c(S[C@H](C(N)=O)c2ccccc2)nc2sc(CC)cc2c1=O +COc1cccc(CCNC(=O)CN2CCN(c3ccccc3O)CC2)c1 +c1ccc2c(NCc3nnc(C4CCC4)o3)cccc2c1 +CC1=C(C(=O)C2=C([O-])C(=O)N(CC[NH+](C)C)[C@H]2c2cccc(Cl)c2)[C@H](C)N=N1 +CCOc1ccc([C@@H]2Nc3ccc(C(=O)N(C)C)cc3[C@H]3C=CC[C@H]32)cc1 +COc1ccc(Cc2sc(NC(=O)[C@H]3COc4ccccc4O3)nc2C)cc1 +Cc1cnc2nc(C[C@H](O)C(F)(F)F)[nH]c2c1 +CCN(CC(C)(C)O)C(=O)COCc1ccccc1Cl +O=C(Nc1cccc2ccccc12)NC1CC[NH+](CC(F)F)CC1 +O=C1C(=O)N(CC[NH+]2CCOCC2)[C@@H](c2cccc([N+](=O)[O-])c2)/C1=C(\O)c1cccs1 +CC[C@@H](C)N(CC)C(=O)c1ccccc1N +Cc1cc(C)c(C(=O)Cn2nc(N)n(Nc3cccs3)c2=S)c(C)c1 +C[C@H]1CN(C(=O)CC[C@H](C)c2ccccc2)C(C)(C)CO1 +CN(C)c1nc2c(c(-c3ccc(S(C)(=O)=O)cc3)n1)CCCC2 +CC(C)c1ccc(CNC(N)=[NH2+])cc1 +CN(C)N1C(N)=C(C#N)[C@@H](c2cccs2)C2=C1CCCC2=O +O=C(N[C@H]1C=C[C@H](C(=O)[O-])C1)c1cc(F)c(Cl)cc1Cl +CC(C)n1nnnc1COc1cccc(C(=O)NC2CCCCCC2)c1 +O=S(=O)(NCc1ccc(Cl)cc1Cl)c1ccccc1Br +COC[C@H](NC(=O)c1cc(-c2ccccc2)c(C)[nH]c1=O)C(N)=O +O[C@H]1CCCCC[C@H]1n1cc(-c2ccccc2Cl)cn1 +CC1=C[C@@H](C)[C@H]2C(=O)N([C@H](Cc3ccccc3)C(=O)[O-])C(=O)[C@H]2C1 +CCCCNS(=O)(=O)Cc1ccc([N+](=O)[O-])cc1 +COc1ccc(CC[C@H]2C[C@@H](C(C)(C)C)CCC2=O)cc1 +C[C@H]1CN(CC(=O)Nc2nc(-c3ccccc3Cl)cs2)CCO1 +Cc1cc(Cl)cc(Cl)c1CNC(=O)c1cccs1 +CC(C)(C)OC(=O)NC1CCN(CC(=O)c2nccs2)CC1 +COc1ccc(C)cc1[C@@H](C)NC[C@@H]1CN(C2CC2)CCO1 +CCOC(=O)c1ccn(-c2cccc(NC(=O)C3CCCC3)c2)n1 +O=C1N(C[NH+]2CCN(c3ccccc3)CC2)c2ccccc2C12O[C@@H]1CCCC[C@H]1O2 +CC(C)c1nc(C(=O)[O-])nn1-c1ccccc1F +C1=C[C@H]2C[C@@H]1C[C@H]2CN1CC[NH+](C2CCCCCC2)CC1 +CC(C)C[C@H](C[NH+](C)C)Nc1ncncc1N +O=C1C=C(c2cccs2)C[C@H](c2cccs2)[C@@H]1n1cnc([N+](=O)[O-])n1 +C[NH2+]C1CCC([NH+](C)CC(=O)N[C@H](C)c2ccco2)CC1 +CC(C)c1nc2n(n1)CCC[C@H]2[NH2+]C[C@@H]1CCC[C@H](C)C1 +CCCOC(=O)c1ccc(NC(=O)c2ccc[n+]([O-])c2)cc1 +CCOc1cc(CO)cc(Br)c1OCc1ccccc1F +O=C(Cn1nnn(-c2cccs2)c1=O)NC[C@@H]1CN(Cc2ccccc2)CCO1 +CC(C)[C@H](O)CCNC(=O)C(=O)Nc1ccn(-c2ncccc2Cl)n1 +Cc1cccc(C(=O)NCCS(=O)(=O)NCC2CCC2)c1C +CCCC(=O)N1CCC(C(=O)NN=C(c2ccccc2)c2ccccc2)CC1 +O=C(N[C@H](NC(=S)Nc1ccccc1)C(Cl)(Cl)Cl)c1cccc(Br)c1 +CCN(Cc1ccccc1)C(=O)c1cc(NC(=O)Cc2ccccc2)n(C)n1 +CC(C)CC(=O)N1CCN(C(=O)c2cnc3c(c2)NC(=O)CO3)CC1 +O=C(Cc1csc(NC(=O)Nc2ccc(Cl)cc2)n1)NCCc1ccc(Cl)cc1 +Cc1ccc2c(c1)C[C@@H](C[C@@H](C[NH3+])c1ccc(F)cc1)O2 
+CC1=C(C(=O)OC(C)C)[C@@H](c2ccc(C)s2)NC(=O)N1C +COc1cc(F)c([N+](=O)[O-])c(NC[C@@H](O)c2cnn(C)c2)c1 +COc1ccc(C(=O)N2CCC([C@@]3(C)NC(=O)N(C4Cc5ccccc5C4)C3=O)CC2)cc1 +CCc1cnc(NC(=O)c2cc(C)n(C(C)C)c2C)s1 +COc1ccccc1-c1nc(C[NH+](C)Cc2ccc(C#N)cc2)cs1 +COc1ccc(-c2noc(-c3cc(-c4ccc(Cl)cc4)n[nH]3)n2)cc1OC +CC[C@H](NC(=O)CN1C(=O)c2ccccc2N2C(=O)CC[C@]12C)c1ccc(C)cc1 +Cc1ccc(CNC(=O)NCc2ccnc(OC(C)(C)C)c2)cn1 +CCN(CCO)C(=O)Nc1cccc(C(=O)Nc2cccc(C#N)c2)c1 +Cc1cccc(-c2nn(C[NH+]3CCCCC3)c(=S)n2-c2ccccc2)c1 +C[NH2+][C@@H](C1CCCC1)[C@@H]1CCc2cccnc21 +Cc1ccc(-c2cccc(F)c2C(=O)[O-])c(C)c1 +CN(C)C(=O)[C@@H](Sc1nnc2n(C)c3ccccc3n12)c1ccccc1 +Cc1ccc(C(=O)NNC(=O)c2ccc(SC[C@H]3CCCO3)c([N+](=O)[O-])c2)cc1 +CCCn1/c(=N/C(=O)[C@@H](CCSC)NC(N)=O)[nH]c2ccccc21 +Cn1/c(=N/C(=O)c2sccc2S(=O)(=O)N2CCOCC2)sc2ccccc21 +COc1ccccc1N1C[C@@H](C(=O)NN2C(=O)NC3(CCCCC3)C2=O)CC1=O +COc1cc(OC)cc(C(=O)Nc2ccccc2C(=O)NC(C)(C)C)c1 +C[C@H](Nc1nc(-c2ccncc2)nc2ccccc12)c1ccccn1 +CCOc1ccc(C[NH+]2CCC[C@H]([C@H](O)c3nccn3C)C2)cc1OC +O=C([O-])c1ccc([S@@](=O)Cc2ccc(O)cc2)cc1 +CCc1nn(C)cc1CNC(=O)NCC(C)(C)Cc1ccccc1 +COC(=C(C#N)C#N)c1cccs1 +O=C(NCc1ccnc(OCC(F)F)c1)NCc1cscn1 +C[C@@H](C(=O)N1CCCC1)N1CCN(C(=O)NCc2ccco2)CC1 +CC(C)c1ccc(CN(C)C(=O)NCCCn2cccn2)cc1 +CNC(=O)[C@@H]1CCCN(C(=O)Nc2nn(-c3ccccc3Cl)cc2C)C1 +COc1ccccc1N1CC[NH+]([C@@H](C)C(=O)Nc2ccc(F)cc2)CC1 +CN1C[NH+](C)CC2=C1NCNS2(=O)=O +CNC(=O)c1ccc(O[C@@H]2CCC[C@H]([NH3+])C2)nn1 +COc1ccc(C(=O)O[C@@H](C)[C@@H]2CCCO2)cc1OC(F)F +Cc1ccc(/C=C2/SC(=S)N(CCC(=O)N3CCCc4ccccc43)C2=O)cc1 +CCc1cccc(S(=O)(=O)Nc2cccc(-c3nnnn3C)c2)c1 +Cc1ccc2c(c1)CCN2C(=O)c1ccc(C)nc1C +CCCNC(=O)CN1CCN(C(=O)Cc2c(C)nn(-c3ccccc3)c2C)CC1 +Cc1ccc(C(=O)NC2CC[NH+](Cc3nc(-c4ccccc4)cs3)CC2)s1 +CCOc1ccc(F)c(C(=O)OC[C@H]2CCCCO2)c1F +C[NH+]1CCC(NC(=O)c2ncoc2-c2ccccc2)CC1 +O=C(NC[C@@H](O)CN1CCCC1=O)Nc1cccc(F)c1 +CC[NH+]1CCC[C@H](NC(=O)c2ccc(OC)c(O)c2)C1 +O=C(CSCC(F)(F)F)N1CCN(c2ccc(Cl)cn2)CC1 +COc1ccc(Cl)cc1S(=O)(=O)N[C@H](C)C(=O)NCc1ccc2c(c1)OCO2 +Cn1cc(C(=O)Nc2ccc(-n3ccnn3)cc2)c(C(C)(C)C)n1 +Cc1cccc(NC(=O)C[C@@H]2CCCCO2)c1C(=O)[O-] +Clc1ccc([C@H](NCCc2nnc3ccccn23)C2CC2)cc1Cl +Clc1ccc(OCCCCSc2ncccn2)cc1Cl +Cn1nnc2cc(C(=O)N[C@@H](C#N)c3ccc(Cl)c(Cl)c3)ccc21 +Cc1cc(C)cc(-c2nnc(Sc3nc(C(C)C)ns3)o2)c1 +C/C(=C/c1ccc(F)cc1)C(=O)NCc1cccc(OCC(F)F)n1 +CC(C)Nc1cccc(CNC(=O)N[C@@H]2CC[NH+](CC3CC3)C2)c1 +O=C(COc1ccc(F)cc1F)NC[C@H](O)c1ccccc1Cl +Cc1ccc(-n2nc3c(c2NC(=O)C(C)C)C[S@@](=O)C3)cc1 +COc1nc(Oc2ccc3ccccc3c2)ccc1N +O=C(c1ccccc1)c1ccc2nc(Nc3ccccc3)c3nncn3c2c1 +O=C(C[C@@H](O)c1cccc(F)c1)Nc1cc(F)ccc1O +CC(C)c1nc(CSCc2ccnn2C)no1 +O=C(C1CCC1)N1CCC[C@H]1c1nc2cc(-c3ccccc3)ccc2o1 +O=C(CC[C@@H]1NC(=O)NC1=O)NC1CCN(c2ccccc2F)CC1 +Cc1ccc(-n2nc3c(c2NC(=O)c2ccc(Br)o2)CSC3)cc1C +C=CCn1c(SCc2nnc([S-])n2-c2ccccc2)nnc1-c1ccccc1 +c1nnn(C23C[C@H]4C[C@H](CC(c5nc6c7cn[nH]c7ncn6n5)(C4)C2)C3)n1 +CCCn1ncnc1COc1ccc(C)nc1C[C@H](C)[NH3+] +CCSc1ccc(C(=O)N2CC[C@H](C)[C@H](O)C2)cn1 +C[C@H]1[C@H](C(=O)[O-])CCN1S(=O)(=O)[C@@H](C)C#N +COc1ccccc1NC(=O)CSc1nnc(C)c(=O)n1N +Cc1ccc(S(=O)(=O)OCCc2coc3ccccc23)cc1 +COc1ccccc1[C@@H](C)NC(=O)[C@@H](C)Oc1cccc(F)c1 +CC[C@@H](Oc1ccccc1/C=C1\S/C(=N\c2cccc(O)c2)N(CC)C1=O)C(=O)[O-] +C[C@@H]([NH2+]C[C@H]1CC[C@H](C(N)=O)O1)c1ccc2c(c1)OCCCO2 +Cc1ccccc1Nc1nc(N)nc(COc2ccc(F)c(Cl)c2)n1 +CC(=O)Nc1ccc(OC(=O)/C=C/c2ccc(C(N)=O)cc2)cc1 +O=C([O-])c1ccc(-c2ccncc2)cn1 +O[C@H](c1c(F)c(F)c(F)c(F)c1F)C(Cl)(Cl)Cl +COc1ccc(/C=N/NC(=O)CNc2ccc3ccccc3c2)cc1OC +Cc1ccc(NC(=O)/C(C#N)=C/c2cc(C)n(-c3ccc(O)cc3)c2C)cc1Cl +Cc1cccn2c(=O)c(C(=O)N[C@H]3CCN(C(=O)C(C)C)C3)cnc12 +O=C(NCCCS(=O)(=O)c1ccccc1)c1n[nH]c2ccccc12 +C1=C(CC[NH2+]Cc2ccco2)CCCC1 
+CCN(CC(=O)NCc1ccc(F)cc1)C(=O)c1cnc(-c2cccnc2)s1 +FC(F)(F)c1cccc2c1CCCC2 +CN(C)C(=O)CCCNC(=O)c1ccnc(OC(C)(C)C)c1 +Cc1ccc(NC(=O)[C@H](C)[NH+](C)Cc2nnc(C3CC3)n2C)c(C)c1 +CN(C)c1cccc(C(=O)OCC(=O)C(C)(C)C)c1 +Cc1nsnc1Cn1nnc(C(=O)NC(C)C)c1C +CC(=O)Cc1nsc(N[C@@H](C)c2ccccc2)n1 +COc1ccc(CNC(=O)[C@H]2Oc3ccccc3O[C@@H]2C)cn1 +[NH3+]CC1CCC(c2nc3ccc(Cl)cc3s2)CC1 +CCN(C(=O)NC1CC[NH+](C[C@@H](O)COC)CC1)C1CCCC1 +CCSc1nc2ccccc2c(=O)n1CCc1ccccc1 +Cc1cc(C(=O)COC(=O)c2cc(Cl)c3c(c2)OCCCO3)c(C)n1C1CC1 +COc1ccc([C@@H](CNC(=O)c2ccc([N+](=O)[O-])o2)[NH+](C)C)cc1 +CC[C@H](NC(=O)c1ccc(C#N)cn1)C(=O)N1CCOCC1 +CCOC(=O)NC(=O)c1c(NC(=O)Cc2ccc(F)cc2)sc2c1CC[C@H](C)C2 +CC(C)n1cnnc1SCC(=O)Nc1ccc2c(c1)nc(C1CC1)n2C +CCc1ccc([C@H](O)C2(C[NH3+])CCCC2)cc1 +COc1ccc(CNC(=O)N2CCc3c([nH]c4ccccc34)[C@H]2C)cc1OC +COc1ccc([C@@H]2C(C#N)=C(N)Oc3cc(C)n(CCN4CCOCC4)c(=O)c32)cc1OC +O=c1[nH]nc([O-])n1/N=C/c1ccco1 +C[C@H](Oc1cccc(Cl)c1)C(=O)N1CCC(Cc2ccccc2)CC1 +COc1cc([C@@H]2C(C(=O)Nc3ccc(F)cc3)=C(C)Nc3nc(C)nn32)cc(OC)c1OC +CC[C@@H](C)[C@@H](O)C[NH2+][C@@H](c1cccs1)C1CC1 +CSc1cc(-c2cccs2)oc(=O)c1C#N +CC(C)[C@@H](NC(=O)c1ccc(NS(C)(=O)=O)cc1)C(=O)[O-] +[NH3+][C@H](CO)c1ccc(N2CCOCC2)c(Cl)c1Cl +CCCn1cc(NC(=O)c2cc3nc(-c4ccccc4)cc(-c4ccccc4)n3n2)cn1 +Cc1ccc(S(=O)(=O)N2CCC(C(=O)N3CCCc4ccccc43)CC2)cc1C +C=CCN(CC=C)C(=O)C1CCN(C(=O)C(C)(C)C)CC1 +Cc1nc(CSc2nncc3ccccc23)nc2ccccc12 +C[C@H]1CCCC[NH+]1C[C@@H]1CCC(C)(C)[C@@H]1[NH3+] +COc1cc(C(=O)Nc2ccccc2Oc2ccccc2)on1 +COc1ccc(S(=O)(=O)N2CCOCC2)cc1NC(=O)/C=C/c1ccc(F)c(Cl)c1 +Cc1ccc(F)c(C[NH+]2CCC(C(=O)NC(C)C)CC2)c1 +CCn1nc(C)c(CNC(=O)[C@H]2[NH+]=c3ccccc3=C2NC(=O)c2cccc(C)c2)c1C +COc1ccc(C(=O)N2CCC[C@H](C(=O)Nc3cc(Cl)ccc3F)C2)c2ccccc12 +CCc1ccc(/C=C(\C#N)C(N)=O)s1 +COc1cccc(N2C(=O)Nc3ccccc3[C@]2(O)C(=O)NCc2ccccc2)c1 +COC(=O)c1sccc1NC(=O)[C@@H]1CC[NH2+][C@@H]1C +C/[NH+]=C(/NCc1ccc([N+]2=CCCC2)cc1)N[C@H]1CC[C@@H](SC)C1 +N#Cc1csc(C(=O)N2CC[C@H]3CCCC[C@@H]32)c1 +Cc1cccc(NC(=O)[C@H](C)[S@@](=O)Cc2ccc(F)c(F)c2)c1C +CNS(=O)(=O)c1cccc([C@H](C)NC(=O)c2ccc(Cn3cccn3)cc2)c1 +CC[NH2+][C@@]1(C(=O)OC)CCC[C@@H](Oc2ccccc2)C1 +COCCCn1c(C)c(C)c(C#N)c1NC(=O)C[NH+]1CC(C)(C)C1(C)C +C[C@H]1CCC[C@@H](C)N1C(=O)[C@H]1C[C@H]1c1ccccc1Cl +COCc1ccc(C[NH+](C)Cc2ccccc2O)o1 +Cc1c(F)cc(N)cc1S(=O)(=O)NCC(N)=O +CCNS(=O)(=O)[C@@H]1CC[NH+](C[C@@H]2CCCc3ccccc32)C1 +CC1(C)[C@@H]2CC[C@@]1(CS(=O)(=O)NCCCO)C(=O)C2 +COc1ccc(-n2ccc(CNC(=O)c3cc(Cl)ccc3[N+](=O)[O-])n2)cc1 +CC[C@@H](NC(=O)NC1CCC(C(=O)OC(C)(C)C)CC1)[C@H]1CCCO1 +CN(CC[NH+](C)C)C(=O)C[C@H]1COCCN1C(=O)c1ccc2[nH]nnc2c1 +CC[C@H](c1ccc(F)cc1)N(C)C(=O)Cn1nnc(-c2ccccc2)n1 +O=C(Cn1cccc1-c1nc(-c2ccc(OC(F)(F)F)cc2)no1)Nc1nccs1 +CCc1nsc(Nc2ccc(CC(=O)N3CC[NH+](CC)CC3)cc2)n1 +CS[C@@H]1CC[C@H](NC(=O)/C=C(/C)c2ccccc2)C1 +Cc1ccc([N+](=O)[O-])cc1NCC(=O)N[C@](C)(C#N)C1CC1 +CC1(C)[C@H]2OCC[C@@H]2[C@H]1NC(=O)CCNC(=O)C12CC3CC(CC(C3)C1)C2 +c1ccc(COC2CC[NH+](Cc3cccnc3)CC2)cc1 +O=C(C1CCCC1)N1CCC[C@@H]([NH+]2CCC(CO)CC2)C1 +COC(=O)CNC(=O)c1sc2ncn(CC(=O)N3CCCCC3)c(=O)c2c1C +CC(C)CCNC(=O)[C@@H](C)Oc1ccc(N)cc1C(=O)[O-] +CCN(CC)C(=O)[C@@H]1C[C@@H]([NH3+])CN1C(=O)Cc1cccc(O)c1 +COc1ccc(F)cc1NC(=O)N1CCO[C@H](c2ccc(C)o2)C1 +COc1ccc(-c2nnc(SCC(=O)c3ccc(Br)cc3)o2)cc1OC +CC(C)=CC(=O)NCCC1CCN(c2cc[nH+]cc2)CC1 +C[C@]1(O)[C@](C)(O)[C@@H](CO)O[C@](C)(Oc2c[nH]c3ccc(Br)c(Cl)c23)[C@]1(C)O +O=C(CCCc1nc(-c2cccnc2)no1)N1CCC[C@@H](Cc2ccccc2)C1 +Cn1nc(NC(=O)c2cccc(F)c2)c2c1NC(=O)C[C@@H]2c1ccccc1 +O=C(CSc1ccncc1)NCCN1Cc2ccccc2O[C@@H](c2ccccc2)C1 +N#Cc1ccc(OCCn2cc(Cl)cn2)cc1 +C[C@@H]1CCN(C(=O)Nc2ccc(O[C@@H]3CCOC3)cc2)[C@H](C)C1 +CC[C@H](C)[C@H](C)[NH2+]Cc1ncccc1F +C#CC(C)(C)NC(=O)c1ccc(OC)c(O)c1 +COc1cc([N+](=O)[O-])ccc1OCc1nc(-c2cccs2)no1 
+CC1CCN(C(=O)C[NH+]2CCC[C@@H](c3nc4ccccc4o3)C2)CC1 +C[C@H]1CCC[C@H](NC(=O)Cc2c[nH]c3ccccc23)[C@@H]1C +Cn1nc(CNC(=O)Nc2ccccc2C(F)(F)F)cc1-c1ccncc1 +CC[NH2+]C[C@H](Cc1cscn1)c1cccc(F)c1 +CCCCS(=O)(=O)N1CCN(c2ccc(-n3ccnc3C)nn2)CC1 +O=C(c1cc2ccccc2o1)N(C[C@H]1CCCO1)c1nc2c(F)cccc2s1 +C[C@H](CC#N)Sc1ccccc1NC(=O)c1ccc(Cl)nc1Cl +O=C(CCCc1nc2ccccc2s1)N[C@H]1CCOC1=O +COC(=O)c1cc(S(=O)(=O)N[C@H](C)c2ccccc2C)cn1C +Cc1cnn(CC(=O)[C@@H](C#N)c2nc([O-])c3ccc(Cl)cc3n2)c1 +COc1ccc(C)cc1NC(=O)[C@H]1CCCN1c1cc(C)ccc1[N+](=O)[O-] +CC[NH+]1CCC2(CC1)OC[C@H](C(=O)[O-])N2C(=O)c1ccc(F)cc1 +Cc1ccc2c(c1)N(C(=O)C[C@H](O)c1ccc(Cl)cc1)CC2 +O=c1[nH]cnc2c1[nH]c(=S)n2[C@@H]1O[C@H](CO)[C@@H](O)[C@H]1O +Cc1cn2c([nH+]1)CC[C@H](NC(=O)C[C@@H]1CCCc3ccccc31)C2 +CC(=O)c1cc(CN2CCC3=NN=C(c4ccccc4F)[C@@H]3C2)cs1 +C[C@H](Nc1ccc(S(=O)(=O)N2CCCCC2)cn1)[C@@H](C)CO +COCCn1nc(C)c(NC(=O)N2CCC[C@H]2c2cccc(C)c2)c1C +CCCCS(=O)(=O)[N-]c1ccc(NC(=O)[C@H]2CCC[NH+](C)C2)cc1 +Fc1ccc(Oc2ccnc(Sc3nnc(-c4cccs4)o3)n2)cc1 +N#Cc1ccc(NC(=O)[C@@H]2CSCN2C(=O)c2cn(Cc3ccccc3)c3ccccc23)cc1 +CC[C@@H](C)n1nccc1NC(=O)C(=O)N1CCc2cc(F)ccc2C1 +Cc1ccc(-c2cnc(CCC(=O)N(C)C3CCOCC3)o2)cc1 +C[C@H](CNC(=O)c1ccc(-c2ccccc2)[nH]c1=O)Oc1ccc(F)cc1 +CCOC(=O)c1c(NC(=O)[C@H]2CCCN2S(C)(=O)=O)sc2ccccc12 +Cc1ccc(C(=O)Cc2cccc(O)c2)cc1 +Cc1ccccc1-c1nn(CN2CCCc3ccc(S(C)(=O)=O)cc32)c(=S)o1 +CCOc1cccc([C@H](C)NC[C@@](C)(O)c2ccc(F)cc2F)c1 +CC[C@@H]1CCCCCN1C(=O)c1cnc2sc(C)cn2c1=O +N#Cc1ccc(OCC(=O)NCc2cccc(CO)c2)cc1 +O=C(Nc1ccccc1)NC1CCN(C(=O)[C@H]2CCCC[C@H]2C(F)(F)F)CC1 +CC1(C)CCC(O)(C[NH2+][C@@H]2CCOC3(CCC3)C2)CC1 +[NH3+][C@H]1CCC[C@H]1CCN1C(=O)c2cccc3cccc1c23 +C#CCN(Cc1cc(Br)cc(OC)c1O)[C@@H]1CCS(=O)(=O)C1 +Cc1ccc(-c2nc3nc(CN4CC[NH+](C)CC4)cc([O-])n3n2)cc1 +C[C@@](O)(CNC(=O)C1CCCC1)c1cccs1 +O[C@H](CSc1nnc(-c2c[nH]c3ccccc23)n1C1CC1)CN1CCOCC1 +C[C@@H]1CCC/C(=N/[NH+]=C(/[S-])NCc2ccccc2)C1 +O=C(c1cc2ccc(Cl)cc2[nH]1)N1CCC[C@@H]1Cn1nnc(-c2cccs2)n1 +COc1ccc(S(=O)(=O)Oc2ccc(C(C)=O)cc2OC)cc1 +CC(C)c1ccccc1NC(=O)C[NH+](C(C)C)[C@@H]1CCCC[C@@H]1O +O=Cc1ccn(-c2ccc(Br)cc2)c1 +O=C(C1CC1)N1CCC[C@H](Cn2cc[nH+]c2-c2cc3n(n2)CC[NH2+]C3)C1 +O=C(CCNc1ccccc1[N+](=O)[O-])N1CCC[C@@H]([NH+]2CCCC2)C1 +Cc1ccc2c(c1)-c1onc(C(=O)N3C[C@@H](C)C[C@H](C)C3)c1CO2 +O=C(COC(=O)c1ccc(Cl)nc1)NC(=O)Nc1ccc2c(c1)OCCO2 +CCC[C@@H]1C[C@H]1NC(=O)C1(c2ccc(F)cc2F)CCOCC1 +CCOC(=O)C1CCC(NC(=O)[C@@](C)([NH3+])CC)CC1 +CC[C@@H](O)C(=O)NCc1cccnc1Oc1ccccc1OC +C[C@@H](Sc1nnc(-c2cccs2)n1-c1ccccc1)C(=O)N1CC(=O)Nc2ccccc21 +Cc1cc(F)ccc1NC(=O)COc1ccc2c(c1)CCC2 +C[C@H](NC(=O)NCCC[S@](C)=O)c1ccc(Cl)s1 +O=C(Cn1c(=O)c(=O)n(Cc2ccncc2)c2ncccc21)NCCc1ccccc1 +CS(=O)(=O)c1ccc(C(=O)Nc2ccc(F)c(F)c2F)cc1 +CCCCOc1ccccc1/C=C1\SC(N2CCC(C)CC2)=NC1=O +CC(C)[C@@H](CNC(=O)N1CCc2ccc(Cl)cc2C1)c1cccnc1 +CCCOc1ncnc(Nc2cc(Cl)cc(Cl)c2)c1N +CC(=O)Nc1ccc(NC(=O)c2nnn(-c3ccc(C)c(C)c3)c2C)cc1 +CCOC1CC[NH+](CC[C@@H](O)c2ccc(C)c(F)c2)CC1 +IC[C@@H]1Cn2c(nnc2-c2ccncc2)S1 +CCOc1ccc([C@H]2CCCN2C(=O)c2[nH]c(C)c(C(C)=O)c2C)cc1 +CC(C)CONc1ncnc2sc3c(c12)CCC3 +CC(C)[C@@H](ON1C(=O)c2ccccc2C1=O)C(=O)[O-] +COC[C@H](O)C[NH+]1CCC(C)(C)C1 +COC(=O)[C@@H](c1ccccc1Cl)N1CCCSCC1 +O=C(Nc1nc2ccc(F)cc2s1)c1cc(-c2ccccc2O)[nH]n1 +C/C(=N\Nc1ncnc2sc(C)c(C)c12)c1cccc(OC(F)F)c1 +CC(C)CN1CCO[C@@H](CNC(=O)/C=C/c2ccnc(Cl)c2)C1 +Cc1cc(F)ccc1CCNC(=O)Cc1c[nH]c2c(C)cccc12 +NC(=O)[C@@H]1CCCN(C(=O)Cn2nc(-c3cccs3)oc2=O)C1 +CCCCCn1c(SCC(=O)[O-])nc2ccccc2c1=O +Cc1ccc2nc(NC(=O)c3ccc(OCc4nc(-c5ccco5)cs4)cc3)sc2c1 +COc1ccccc1N1CCN(c2ccc(=O)n(CC(=O)NC3CC3)n2)CC1 +C/C(=C/C(=O)N[C@@H](C)c1c(C)noc1C)c1ccccc1OC(F)F +COCCn1nc(C)c(NC(=O)N2CC[C@H](Cc3ccccc3)C2)c1C +Cc1ccsc1C[NH+](Cc1nc2ccccc2n1C(C)C)C[C@H](C)O 
+Cc1cccc(CNC(=O)C[C@H]2Oc3ccc(C)cc3NC2=O)c1 +C[C@@H](c1ccc([S@](C)=O)cc1)N(C)C(=O)c1cc2cccc(F)c2o1 +COc1cc(OC)c(C(C)=O)cc1CSc1nnnn1-c1ccccc1 +Cc1sc(=O)n(CCC(=O)NC2CC(C)(C)[NH2+]C(C)(C)C2)c1C +CC[C@@H]1CCCCN1C(=S)NC(=O)c1ccc(C)cc1 +CCC[C@H](C)C(=O)N[C@H](C)c1cccc(Br)c1 +COc1cccc([C@H]2CCCN2C(=O)c2ccccc2I)c1 +C[C@H](NC(=O)N1CCCC[C@@H]1C1OCCO1)c1cccc(-n2ccnc2)c1 +CS(=O)(=O)N1CCC[C@@H](C[NH+]2CCC[C@H](CO)C2)C1 +CCN(C(=O)Cn1nc2n(c1=O)CCCCC2)[C@H]1CCS(=O)(=O)C1 +COc1ccccc1NC[C@H]1CCCN(S(C)(=O)=O)C1 +COc1cc([C@@H]2CC(=O)Nc3c2cnn3Cc2cccnc2)cc2c1OCO2 +O=C(CNc1ccc(Cl)cc1NC(=O)c1ccco1)Nc1ccc(F)c(Cl)c1 +Cc1nn(C)cc1[C@@H](C)NC(=O)C(=O)Nc1ccc(OCC2CCCCC2)cc1 +CC(C)(C)[S@](=O)CCNC(=O)c1cccc(F)c1Cl +C[C@@H](O)c1ccc(F)cc1OCc1nc(C(C)(C)C)cs1 +COC[C@H](NC(=O)Nc1cn[nH]c1)c1ccc(F)c(F)c1 +C[C@@H]1CS(=O)(=O)N(c2ccc(S(=O)(=O)Nc3ccccc3C(F)(F)F)cc2)C1=O +CC(C)c1ccc2c(c1)[C@]1(CC(O)=Nc3c1cnn3Cc1ccccc1Cl)C(=O)N2C +CC[C@@H](C)C(=O)NCC(=O)N(C)[C@@H](C)c1cc(F)ccc1F +CO/N=C\C(C#N)=C/c1cccnc1 +CO[C@H](c1ccc(Cl)cc1)[C@@H](C)NC(=O)C(=O)Nc1ccccc1C +CC(C)CNC(=O)[C@](C)(N)C(F)(F)F +C[C@@H](C(=O)C1=c2ccccc2=[NH+]C1)[NH+]1CCC[C@@H]1[C@@H]1CC=CS1 +Cc1nc(Br)ccc1NC(=O)NCc1cnn(C)c1 +COc1c(C)cnc(CNC(=O)Nc2ccc(N(C)C)cc2)c1C +COc1ccc2cc(COC(=O)COc3ccccc3C#N)ccc2c1 +CCS(=O)(=O)CCN(C)Cc1c[nH]nc1-c1ccc(C)cc1 +COc1ccc([C@@](C)([NH3+])Cc2[nH+]ccn2C)cc1 +C[C@H]1CCN(C(=O)NCCc2nnc3n2CCCCC3)[C@@H](C)C1 +O=C(NC[C@@H]1CCC[NH+](Cc2ccccc2F)C1)c1nc[nH]n1 +O=C1CC[C@@H](NC(=O)COc2ccc(Cl)c(Cl)c2)CN1 +Cc1noc(C)c1COc1ccc(C[NH2+]C[C@H]2CCCO2)cc1 +N#Cc1cnn2c1N[C@@H](c1ccccc1)C[C@@H]2C(F)F +C[C@@H]1Cc2ccccc2N1C(=O)[C@H]1CCCN(C(=O)NC2CC2)C1 +COCc1nc(C(=O)OCC2=CC[C@H]3C[C@@H]2C3(C)C)cs1 +CCN(Cc1ccc(Br)s1)C(=O)C[NH+](C)CC(=O)[O-] +O=C([O-])[C@H]1CCCN(c2ccc([O-])nn2)C1 +COc1ccc(CCCC(=O)Nc2cccc(S(N)(=O)=O)c2)cc1F +Fc1ccccc1[C@@H](c1nnnn1C1CCCCC1)[NH+]1CCN(c2ccccc2)CC1 +O=C(C/C(=N\Nc1nc(-c2ccccc2)cs1)c1ccccc1)C(F)(F)F +C[C@H]([NH2+]CC(=O)N(C)C)c1ccc(Cl)s1 +CCOC(=O)COc1ccccc1/C=C1/C(=O)NC(=O)N(c2ccc3c(c2)OCO3)C1=O +Cc1cc(C(=O)NNC(=O)c2cccc3ccccc23)c(C)o1 +COc1ccc(-c2csc(NC(=O)Nc3ccc(F)cc3)n2)cc1OC +Cc1ccc(C)c(S(=O)(=O)N2CCN([C@H](C)c3nc(N)nc(Nc4ccccc4)n3)CC2)c1 +Cc1cccc(C)c1-n1nnnc1CSCc1nnc(C)n1C +COc1cccc(C(=O)N[C@@](C)(C(N)=O)c2cccc(Cl)c2)c1 +C[C@H](CCO)[NH2+][C@H]1CCc2c(Br)cccc21 +CCc1ccsc1-c1cnc(C[NH3+])o1 +NC(=O)C1(N2CCCC2)CC[NH2+]CC1 +CC(C)C[C@@H]([NH3+])C(=O)N1CC[C@H](C(=O)[O-])[C@@H]1C +CC[C@H](C)Cn1c(CCCl)nc2c(C)nn(C)c21 +Cc1csc([C@H](C)NC(=O)CCC[NH+]2CCCCC2)n1 +CCC(=O)NN/C(C)=C/C(=O)NCC(C)(C)C +CCC(=O)N1CCC([NH+](C)Cc2ccc(SC)c(OC)c2)CC1 +Clc1ccc(CNc2ncccc2Cl)cn1 +C[C@H](c1nc(C(C)(C)C)no1)[S@](=O)Cc1ncn(-c2ccccc2)n1 +Cn1cnn(C[NH+](Cc2ccc(F)cc2)C2CC2)c1=S +C#CC(C)(C)NC[C@H]1CN(C)CCO1 +C[C@H](NC(=O)[C@H]1CCCN1S(C)(=O)=O)c1ccc2c(c1)OCO2 +O=C(Nc1ccc2ncccc2c1)C(=O)NC1CCC(O)CC1 +CC(C)[C@H]([NH2+]CC1CCN(C(=O)OC(C)(C)C)CC1)c1cccnc1 +CCC[C@H](C)NC(=O)C[NH2+]Cc1cscc1C +CS(=O)(=O)N1CCC(C(=O)Nc2sc3c(c2C#N)CCCC3)CC1 +O=C(CN(C(=O)Cn1nnc2ccccc21)c1ccccc1)NC[C@H]1CCCO1 +O=C(Nc1cccc(N2CCCNC2=O)c1)C(=O)N1CCc2cc(F)ccc2C1 +C[C@@H]1CC[NH+](CCCN2C(=O)CNC2=O)C[C@@H]1O +O=C(COc1ccc(Br)cc1)NOCc1ccccc1 +CC(=O)C[C@]1(O)C(=O)N(Cc2ccc(C)cc2)c2c(C)cccc21 +C[C@H](OC(=O)c1ccc2ccccc2n1)C(=O)NCC1CCCCC1 +C[C@@H](Sc1cc(Cl)ccc1Cl)C(=O)N1CCC[C@H](CCC(N)=O)C1 +COc1cc(-c2ccno2)ccc1S(=O)(=O)NCc1ccco1 +CCC[C@@H](CC)Nc1c(F)c(F)nc(F)c1F +Oc1cccc([C@@H]2CN(c3nccc(Oc4ccc(F)cc4)n3)CCO2)c1 +COc1ccc([C@H](CNC(=O)c2cccc3ccccc23)[NH+]2CCCC2)cc1 +CC1(C)[C@@H]2CC[C@@]1(C)[C@H](NC(=O)COc1ccc(C3SCCCS3)cc1)C2 +Cc1c(C)n(-c2ccccc2)c2nc(C(=O)Nc3ccc(F)cc3)nc(N3CCCCC3)c12 
+CC(C)C(=O)Nc1cccc(NC(=O)C(=O)NCC[C@H]2C[C@H]3CC[C@@H]2C3)c1 +COc1ccc(CNC(=O)c2cc(N3C(=O)C(C)(C)CS3(=O)=O)ccc2Cl)cc1OC +CC(=O)Nc1cccc(NC(=O)CCc2c(C)[nH]c(=S)[nH]c2=O)c1 +COc1ccccc1[C@H]1CCCN1C(=O)[C@@H](C)CCOc1ccccc1 +CCNc1ncc(COCC2CCCCC2)s1 +CC[C@H](C)C(=O)Nc1ccccn1 +O=C([O-])CC1=C(C(=O)[O-])CCCC1 +Fc1ccc(C[NH2+]C[C@@H]([C@H]2CCOC2)N2CCOCC2)c(F)c1 +CCc1cc(Cn2cc(N)nn2)n(C)n1 +N#Cc1cccnc1Oc1ccccc1NCc1cccc2ccccc12 +O=C(NCc1ccc([N+](=O)[O-])cc1)N[C@@H]1CCCC[C@H]1CO +C[C@H](NC(=O)NC[C@H](C)C[C@@H](C)O)c1ccc(S(C)(=O)=O)cc1 +CC(C)Oc1ccc(NC(=O)NC[C@H](C)N2CCOCC2)c(F)c1 +COc1ccccc1COC1CCN(C(=O)[C@H]2CCC[C@@H](C)C2)CC1 +CC(=O)O[C@H]1CC[C@H]2[C@H]3C[C@H](OC(C)=O)[C@]45C[C@H]4CC[C@]5(C)[C@@H]3CC[C@]12C +CCCn1nnnc1CN1CC[C@]2(C1)NC(=O)N(C(C)C)C2=O +CC1=C(C(=O)OCC(C)C)[C@H](c2cccc(F)c2)c2c(n(C)c(=O)n(C)c2=O)N1 +Cc1cc(C)c2c(-c3ccccc3)nc(SCC(=O)NC3CC3)n2n1 +COc1cc2c(cc1OC)[C@H](C(=O)[O-])[C@H](c1cccc(Cl)c1)N(C)C2=O +COCCN1C(=O)CC[C@@H]2C[NH+](Cc3cc(C)ccc3C)CC[C@@H]21 +Cc1cnc([C@H](C)NC(=O)NNC(=O)Nc2ccccc2)s1 +CNC(=O)c1ccc(NC(=O)c2csc(-c3ccccc3)n2)cc1 +CNc1nc(C2CCN(C(=O)Cc3ccccn3)CC2)[nH+]c2c1CN(C(C)=O)CC2 +Cc1cc(C)n(C[C@@H](C)CNC(=O)NCc2cc3ccccc3o2)n1 +Cc1nc2n(n1)C[C@H]([NH2+]C[C@@H](O)CN(C)Cc1ccccc1)CC2 +c1ccc(Cn2c(SCc3ncon3)nnc2-c2cccs2)cc1 +COc1ccc(-n2nc(C)c3c2C[C@H](c2cc(OC)c(OC)c(OC)c2)CC3=O)cc1 +COc1ccc(OC)c(/C=C/C(=O)OCC(C)C)c1 +C[C@H](NC(=O)[C@H]1CC[C@H](C[NH3+])O1)C(=O)N(C)C +Cc1cccc(C(C)C)c1NC(=O)[C@@H](C)Sc1nnc(-c2cccs2)n1N +Cn1cc(C(=O)Nc2ccccc2C(=O)NCCc2ccccc2)c(=O)c2cccn21 +Cc1ccccc1[C@@H]1C[C@H](C)N(C(=O)[C@@H](C)Sc2ccccn2)C1 +Cc1occc1C(=O)/C(C#N)=C/c1ccc([C@@H]2C[C@H]2C)o1 +CC[C@](C)(C[NH3+])[C@H](O)c1ccc2c(c1)OCO2 +CCc1nn(CC)c(C[C@@]2(C3CC3)CCC[NH2+]2)c1Br +CC(=O)Nc1ccc(CN2CC[NH+](C3CCCC3)[C@H](CCO)C2)cc1 +COc1ccc([C@H](O)[C@@H](C)NC(=O)[C@@H](C)SC)cc1 +CCC(=O)N1CCC[C@@H]1c1cc(C(F)(F)F)c2c(=O)n(C)c(=O)n(C)c2n1 +C[C@H]([NH3+])[C@@H](CC(=O)[O-])c1ccccc1 +Cc1ccc([C@H]2C3=C(NC(=O)N2C)c2ccccc2C3=O)cc1 +C[C@@]1(Cc2ccc3c(c2)OCO3)CCC(=O)N(CCc2ccc(O)cc2)C1 +Cc1cc(C)cc(O[C@H]2CCCC(C)(C)[C@@H]2O)c1 +Cc1c(-c2nc(-c3cccs3)no2)sc2nc[nH]c(=O)c12 +COc1cc([C@H]2C(C(=O)Nc3ccccn3)=C(C)NC3=C2C(=O)CCC3)ccc1OCc1ccccc1 +COc1cc(C)c([C@@H](C)NC2CC[NH+]([C@H]3CCCC[C@@H]3O)CC2)cc1OC +C[C@@H]1CCC[C@@H](NS(=O)(=O)Cc2cccc(N)c2)C1 +CCCN1C(N)=[NH+]C[C@@H]1c1cc(Cl)c2c(c1)OCO2 +CC(C)C[C@@H](NC(=O)[C@@H]1C[C@@H]1c1cccc(Cl)c1Cl)C(=O)Nc1cc[nH]n1 +Cc1nc(CCC[NH+]2CCC[C@H]2C(N)=O)cs1 +NC(=O)COc1cccc(CNC(=O)c2cc3cc(Cl)ccc3[nH]2)c1 +COc1ccc(CNC(=O)c2cc(=O)c3ccc(Br)cc3o2)cc1 +CC[C@@H](C)c1ccccc1N1C[C@H](C(=O)N2CCN(C)CC2)CC1=O +CC[NH2+][C@H](Cc1ccccc1Cl)[C@H]1C[NH+](C)CCN1C +C[NH+]1CCC(N[C@@H]2CC(=O)N(CCc3cccc(Cl)c3)C2)CC1 +CCN1CC(=O)Nc2cc(C(=O)NC3CC[NH+](C4CCCC4)CC3)ccc21 +COc1ccc(Br)cc1/C=C/C(=O)N1CCN(C(=O)c2ccccc2)CC1 +NC(=O)COc1ccc(C(=O)N[C@H]2CCCc3ccccc32)cc1 +O=C(COc1ncnc2ccc(Br)cc12)Nc1ccccc1Cl +C[C@@H](c1ccc(Cl)cc1Cl)N(C)C(=O)c1ccc(NC(N)=O)cc1 +Cc1ccc(N2C(=O)[C@@H](Cc3cccc(C)c3)S/C2=C(/C#N)C(N)=O)cc1 +CC(C)CN(C(=O)NCc1ccc(C(F)(F)F)cc1)C1CC1 +ClCCc1nc2cccnc2n1CCn1cccn1 +CC[C@@](C)([C@@H]([NH3+])c1cc(Br)ccc1F)N1CCOCC1 +Cc1ccc(NC(=O)c2cc3ccccc3oc2=O)c([N+](=O)[O-])c1 +CC1(C)OC[C@H]([C@H]2O[C@@H]3OC(C)(C)O[C@@H]3[C@H]2OS(C)(=O)=O)O1 +Cc1cccc([C@H](CCl)CCC[C@@H]2CCCO2)c1 diff --git a/MoleculeSTM/models/GA/__init__.py b/MoleculeSTM/models/GA/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/MoleculeSTM/models/GA/crossover.py b/MoleculeSTM/models/GA/crossover.py new file mode 100644 index 0000000..ec6d6b1 --- /dev/null +++ b/MoleculeSTM/models/GA/crossover.py @@ -0,0 +1,194 @@ +''' +Written by Jan H. 
Jensen 2018 +''' +from rdkit import Chem +from rdkit.Chem import AllChem + +import random +import numpy as np + +from rdkit import rdBase +rdBase.DisableLog('rdApp.error') + +average_size = 39.15 +size_stdev = 3.50 + + +def cut(mol): + if not mol.HasSubstructMatch(Chem.MolFromSmarts('[*]-;!@[*]')): + return None + bis = random.choice(mol.GetSubstructMatches(Chem.MolFromSmarts('[*]-;!@[*]'))) #single bond not in ring + #print bis,bis[0],bis[1] + bs = [mol.GetBondBetweenAtoms(bis[0],bis[1]).GetIdx()] + + fragments_mol = Chem.FragmentOnBonds(mol,bs,addDummies=True,dummyLabels=[(1, 1)]) + + try: + fragments = Chem.GetMolFrags(fragments_mol,asMols=True) + return fragments + except: + return None + + +def cut_ring(mol): + for i in range(10): + if random.random() < 0.5: + if not mol.HasSubstructMatch(Chem.MolFromSmarts('[R]@[R]@[R]@[R]')): + return None + bis = random.choice(mol.GetSubstructMatches(Chem.MolFromSmarts('[R]@[R]@[R]@[R]'))) + bis = ((bis[0],bis[1]),(bis[2],bis[3]),) + else: + if not mol.HasSubstructMatch(Chem.MolFromSmarts('[R]@[R;!D2]@[R]')): + return None + bis = random.choice(mol.GetSubstructMatches(Chem.MolFromSmarts('[R]@[R;!D2]@[R]'))) + bis = ((bis[0],bis[1]),(bis[1],bis[2]),) + + #print bis + bs = [mol.GetBondBetweenAtoms(x,y).GetIdx() for x,y in bis] + + fragments_mol = Chem.FragmentOnBonds(mol,bs,addDummies=True,dummyLabels=[(1, 1),(1,1)]) + + try: + fragments = Chem.GetMolFrags(fragments_mol,asMols=True) + except: + return None + + if len(fragments) == 2: + return fragments + + return None + +def ring_OK(mol): + if not mol.HasSubstructMatch(Chem.MolFromSmarts('[R]')): + return True + + ring_allene = mol.HasSubstructMatch(Chem.MolFromSmarts('[R]=[R]=[R]')) + + cycle_list = mol.GetRingInfo().AtomRings() + max_cycle_length = max([ len(j) for j in cycle_list ]) + macro_cycle = max_cycle_length > 6 + + double_bond_in_small_ring = mol.HasSubstructMatch(Chem.MolFromSmarts('[r3,r4]=[r3,r4]')) + + return not ring_allene and not macro_cycle and not double_bond_in_small_ring + +def mol_OK(mol): + try: + Chem.SanitizeMol(mol) + test_mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol)) + if test_mol == None: + return None + target_size = size_stdev*np.random.randn() + average_size #parameters set in GA_mol + if mol.GetNumAtoms() > 5 and mol.GetNumAtoms() < target_size: + return True + else: + return False + except: + return False + + +def crossover_ring(parent_A,parent_B): + ring_smarts = Chem.MolFromSmarts('[R]') + if not parent_A.HasSubstructMatch(ring_smarts) and not parent_B.HasSubstructMatch(ring_smarts): + return None + + rxn_smarts1 = ['[*:1]~[1*].[1*]~[*:2]>>[*:1]-[*:2]','[*:1]~[1*].[1*]~[*:2]>>[*:1]=[*:2]'] + rxn_smarts2 = ['([*:1]~[1*].[1*]~[*:2])>>[*:1]-[*:2]','([*:1]~[1*].[1*]~[*:2])>>[*:1]=[*:2]'] + for i in range(10): + fragments_A = cut_ring(parent_A) + fragments_B = cut_ring(parent_B) + #print [Chem.MolToSmiles(x) for x in list(fragments_A)+list(fragments_B)] + if fragments_A == None or fragments_B == None: + return None + + new_mol_trial = [] + for rs in rxn_smarts1: + rxn1 = AllChem.ReactionFromSmarts(rs) + new_mol_trial = [] + for fa in fragments_A: + for fb in fragments_B: + new_mol_trial.append(rxn1.RunReactants((fa,fb))[0]) + + new_mols = [] + for rs in rxn_smarts2: + rxn2 = AllChem.ReactionFromSmarts(rs) + for m in new_mol_trial: + m = m[0] + if mol_OK(m): + new_mols += list(rxn2.RunReactants((m,))) + + new_mols2 = [] + for m in new_mols: + m = m[0] + if mol_OK(m) and ring_OK(m): + new_mols2.append(m) + + if len(new_mols2) > 0: + return random.choice(new_mols2) + + 
return None + +def crossover_non_ring(parent_A,parent_B): + for i in range(10): + fragments_A = cut(parent_A) + fragments_B = cut(parent_B) + if fragments_A == None or fragments_B == None: + return None + rxn = AllChem.ReactionFromSmarts('[*:1]-[1*].[1*]-[*:2]>>[*:1]-[*:2]') + new_mol_trial = [] + for fa in fragments_A: + for fb in fragments_B: + new_mol_trial.append(rxn.RunReactants((fa,fb))[0]) + + new_mols = [] + for mol in new_mol_trial: + mol = mol[0] + if mol_OK(mol): + new_mols.append(mol) + + if len(new_mols) > 0: + return random.choice(new_mols) + + return None + +def crossover(parent_A,parent_B): + parent_smiles = [Chem.MolToSmiles(parent_A),Chem.MolToSmiles(parent_B)] + try: + Chem.Kekulize(parent_A,clearAromaticFlags=True) + Chem.Kekulize(parent_B,clearAromaticFlags=True) + except: + pass + for i in range(10): + if random.random() <= 0.5: + #print 'non-ring crossover' + new_mol = crossover_non_ring(parent_A,parent_B) + if new_mol != None: + new_smiles = Chem.MolToSmiles(new_mol) + if new_mol != None and new_smiles not in parent_smiles: + return new_mol + else: + #print 'ring crossover' + new_mol = crossover_ring(parent_A,parent_B) + if new_mol != None: + new_smiles = Chem.MolToSmiles(new_mol) + if new_mol != None and new_smiles not in parent_smiles: + return new_mol + + return None + +if __name__ == "__main__": + smiles1 = 'CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1' + smiles2 = 'C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1' + + smiles1 = 'Cc1ccc(S(=O)(=O)N2C(N)=C(C#N)C(c3ccc(Cl)cc3)C2C(=O)c2ccccc2)cc1' + smiles2 = 'CC(C#N)CNC(=O)c1cccc(Oc2cccc(C(F)(F)F)c2)c1' + + mol1 = Chem.MolFromSmiles(smiles1) + mol2 = Chem.MolFromSmiles(smiles2) + + child = crossover(mol1,mol2) + mutation_rate = 1.0 + #mutated_child = mutate(child,mutation_rate) + + for i in range(100): + child = crossover(mol1,mol2) diff --git a/MoleculeSTM/models/GA/mutate.py b/MoleculeSTM/models/GA/mutate.py new file mode 100644 index 0000000..f52905d --- /dev/null +++ b/MoleculeSTM/models/GA/mutate.py @@ -0,0 +1,132 @@ +''' +Written by Jan H. 
Jensen 2018 +''' +from rdkit import Chem +from rdkit.Chem import AllChem + +import random +import numpy as np +import MoleculeSTM.models.GA.crossover as co + +from rdkit import rdBase +rdBase.DisableLog('rdApp.error') + +def delete_atom(): + choices = ['[*:1]~[D1]>>[*:1]', '[*:1]~[D2]~[*:2]>>[*:1]-[*:2]', + '[*:1]~[D3](~[*;!H0:2])~[*:3]>>[*:1]-[*:2]-[*:3]', + '[*:1]~[D4](~[*;!H0:2])(~[*;!H0:3])~[*:4]>>[*:1]-[*:2]-[*:3]-[*:4]', + '[*:1]~[D4](~[*;!H0;!H1:2])(~[*:3])~[*:4]>>[*:1]-[*:2](-[*:3])-[*:4]'] + p = [0.25,0.25,0.25,0.1875,0.0625] + + return np.random.choice(choices, p=p) + +def append_atom(): + choices = [['single',['C','N','O','F','S','Cl','Br'],7*[1.0/7.0]], + ['double',['C','N','O'],3*[1.0/3.0]], + ['triple',['C','N'],2*[1.0/2.0]] ] + p_BO = [0.60,0.35,0.05] + + index = np.random.choice(list(range(3)), p=p_BO) + + BO, atom_list, p = choices[index] + new_atom = np.random.choice(atom_list, p=p) + + if BO == 'single': + rxn_smarts = '[*;!H0:1]>>[*:1]X'.replace('X','-'+new_atom) + if BO == 'double': + rxn_smarts = '[*;!H0;!H1:1]>>[*:1]X'.replace('X','='+new_atom) + if BO == 'triple': + rxn_smarts = '[*;H3:1]>>[*:1]X'.replace('X','#'+new_atom) + + return rxn_smarts + +def insert_atom(): + choices = [['single',['C','N','O','S'],4*[1.0/4.0]], + ['double',['C','N'],2*[1.0/2.0]], + ['triple',['C'],[1.0]] ] + p_BO = [0.60,0.35,0.05] + + index = np.random.choice(list(range(3)), p=p_BO) + + BO, atom_list, p = choices[index] + new_atom = np.random.choice(atom_list, p=p) + + if BO == 'single': + rxn_smarts = '[*:1]~[*:2]>>[*:1]X[*:2]'.replace('X',new_atom) + if BO == 'double': + rxn_smarts = '[*;!H0:1]~[*:2]>>[*:1]=X-[*:2]'.replace('X',new_atom) + if BO == 'triple': + rxn_smarts = '[*;!R;!H1;!H0:1]~[*:2]>>[*:1]#X-[*:2]'.replace('X',new_atom) + + return rxn_smarts + +def change_bond_order(): + choices = ['[*:1]!-[*:2]>>[*:1]-[*:2]','[*;!H0:1]-[*;!H0:2]>>[*:1]=[*:2]', + '[*:1]#[*:2]>>[*:1]=[*:2]','[*;!R;!H1;!H0:1]~[*:2]>>[*:1]#[*:2]'] + p = [0.45,0.45,0.05,0.05] + + return np.random.choice(choices, p=p) + +def delete_cyclic_bond(): + return '[*:1]@[*:2]>>([*:1].[*:2])' + +def add_ring(): + choices = ['[*;!r;!H0:1]~[*;!r:2]~[*;!r;!H0:3]>>[*:1]1~[*:2]~[*:3]1', + '[*;!r;!H0:1]~[*!r:2]~[*!r:3]~[*;!r;!H0:4]>>[*:1]1~[*:2]~[*:3]~[*:4]1', + '[*;!r;!H0:1]~[*!r:2]~[*:3]~[*:4]~[*;!r;!H0:5]>>[*:1]1~[*:2]~[*:3]~[*:4]~[*:5]1', + '[*;!r;!H0:1]~[*!r:2]~[*:3]~[*:4]~[*!r:5]~[*;!r;!H0:6]>>[*:1]1~[*:2]~[*:3]~[*:4]~[*:5]~[*:6]1'] + p = [0.05,0.05,0.45,0.45] + + return np.random.choice(choices, p=p) + +def change_atom(mol): + choices = ['#6','#7','#8','#9','#16','#17','#35'] + p = [0.15,0.15,0.14,0.14,0.14,0.14,0.14] + + X = np.random.choice(choices, p=p) + while not mol.HasSubstructMatch(Chem.MolFromSmarts('['+X+']')): + X = np.random.choice(choices, p=p) + Y = np.random.choice(choices, p=p) + while Y == X: + Y = np.random.choice(choices, p=p) + + return '[X:1]>>[Y:1]'.replace('X',X).replace('Y',Y) + +def mutate(mol,mutation_rate): + + if random.random() > mutation_rate: + return mol + + Chem.Kekulize(mol,clearAromaticFlags=True) + p = [0.15,0.14,0.14,0.14,0.14,0.14,0.15] + for i in range(10): + rxn_smarts_list = 7*[''] + rxn_smarts_list[0] = insert_atom() + rxn_smarts_list[1] = change_bond_order() + rxn_smarts_list[2] = delete_cyclic_bond() + rxn_smarts_list[3] = add_ring() + rxn_smarts_list[4] = delete_atom() + rxn_smarts_list[5] = change_atom(mol) + rxn_smarts_list[6] = append_atom() + rxn_smarts = np.random.choice(rxn_smarts_list, p=p) + + #print('mutation',rxn_smarts) + + rxn = 
AllChem.ReactionFromSmarts(rxn_smarts) + + new_mol_trial = rxn.RunReactants((mol,)) + + new_mols = [] + for m in new_mol_trial: + m = m[0] + #print Chem.MolToSmiles(mol),mol_OK(mol) + if co.mol_OK(m) and co.ring_OK(m): + new_mols.append(m) + + if len(new_mols) > 0: + return random.choice(new_mols) + + return None + +if __name__ == "__main__": + pass diff --git a/MoleculeSTM/models/MLP.py b/MoleculeSTM/models/MLP.py new file mode 100644 index 0000000..b5175c2 --- /dev/null +++ b/MoleculeSTM/models/MLP.py @@ -0,0 +1,49 @@ +from torch import nn +from torch.nn import functional as F +from collections.abc import Sequence + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dims, batch_norm=False, activation="relu", dropout=0): + super(MLP, self).__init__() + + if not isinstance(hidden_dims, Sequence): + hidden_dims = [hidden_dims] + self.dims = [input_dim] + hidden_dims + + if isinstance(activation, str): + self.activation = getattr(F, activation) + else: + self.activation = activation + if dropout: + self.dropout = nn.Dropout(dropout) + else: + self.dropout = None + + self.layers = nn.ModuleList() + for i in range(len(self.dims) - 1): + self.layers.append(nn.Linear(self.dims[i], self.dims[i + 1])) + if batch_norm: + self.batch_norms = nn.ModuleList() + for i in range(len(self.dims) - 2): + self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1])) + else: + self.batch_norms = None + + def forward(self, input): + layer_input = input + + for i, layer in enumerate(self.layers): + hidden = layer(layer_input) + if i < len(self.layers) - 1: + if self.batch_norms: + x = hidden.flatten(0, -2) + hidden = self.batch_norms[i](x).view_as(hidden) + hidden = self.activation(hidden) + if self.dropout: + hidden = self.dropout(hidden) + if hidden.shape == layer_input.shape: + hidden = hidden + layer_input + layer_input = hidden + + return hidden \ No newline at end of file diff --git a/MoleculeSTM/models/__init__.py b/MoleculeSTM/models/__init__.py new file mode 100644 index 0000000..719b380 --- /dev/null +++ b/MoleculeSTM/models/__init__.py @@ -0,0 +1,2 @@ +from MoleculeSTM.models.molecule_gnn_model import GNN, GNN_graphpred +from MoleculeSTM.models.MLP import MLP \ No newline at end of file diff --git a/MoleculeSTM/models/mega_molbart/__init__.py b/MoleculeSTM/models/mega_molbart/__init__.py new file mode 100644 index 0000000..0ae0ec7 --- /dev/null +++ b/MoleculeSTM/models/mega_molbart/__init__.py @@ -0,0 +1 @@ +from MoleculeSTM.models.mega_molbart.megatron_bart import MegatronBART \ No newline at end of file diff --git a/MoleculeSTM/models/mega_molbart/decoder.py b/MoleculeSTM/models/mega_molbart/decoder.py new file mode 100644 index 0000000..b7aaad1 --- /dev/null +++ b/MoleculeSTM/models/mega_molbart/decoder.py @@ -0,0 +1,426 @@ +# coding=utf-8 + +import torch +from rdkit import Chem, RDLogger +from .util import DEFAULT_MAX_SEQ_LEN + +class DecodeSampler: + def __init__( + self, + tokenizer, + max_seq_len=DEFAULT_MAX_SEQ_LEN + ): + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + + assert max_seq_len > 1, f"Max sequence must be at least 2, got {max_seq_len}" + + self.begin_token_id = self.tokenizer.vocab[self.tokenizer.begin_token] + self.pad_token_id = self.tokenizer.vocab[self.tokenizer.pad_token] + self.end_token_id = self.tokenizer.vocab[self.tokenizer.end_token] + + self.bad_token_ll = -1e5 + + RDLogger.DisableLog("rdApp.*") + + + def decode(self, decode_fn, batch_size, sampling_alg="greedy", device="cpu", **kwargs): + """ Sample a molecule from a model by calling the decode 
function argument
+
+        Args:
+            decode_fn: A function mapping a batched sequence of token identifiers and their associated pad masks
+                to a log probability distribution over possible next tokens
+            batch_size: The number of elements to pass into the decode function in one batch
+            sampling_alg: Algorithm to use for sampling from the model
+
+        Returns:
+            (SMILES of sampled molecules (List[str]), log likelihoods (List[float]))
+        """
+
+        if sampling_alg == "greedy":
+            output = self.greedy_decode(decode_fn, batch_size, device)
+
+        elif sampling_alg == "beam":
+            output = self.beam_decode(decode_fn, batch_size, device, **kwargs)
+
+        else:
+            raise ValueError(f"Unknown sampling algorithm {sampling_alg}")
+
+        return output
+
+
+    def greedy_decode(self, decode_fn, batch_size, device="cpu"):
+        """ Sample molecules from the model using greedy search
+
+        Args:
+            decode_fn (fn): Function used to apply tokens to model and produce log probability distribution
+            batch_size (int): Number of molecules to sample
+            device: Torch device to create tensors on
+
+        Returns:
+            (List[str], List[float]): Tuple of (molecules, their log likelihoods)
+        """
+
+        # Create tensors which will be reused
+        token_ids = [self.begin_token_id] + ([self.pad_token_id] * (self.max_seq_len - 1))
+        token_ids = [token_ids] * batch_size
+        token_ids = torch.tensor(token_ids, device=device).transpose(0, 1)
+        pad_mask = torch.zeros((self.max_seq_len, batch_size), device=device, dtype=torch.bool)
+        log_lhs = torch.zeros((batch_size))
+
+        # Iteratively apply the tokens to the model and build up the sequence
+        for i in range(1, self.max_seq_len):
+            token_ids_seq = token_ids[:i, :]
+            pad_mask_seq = pad_mask[:i, :]
+
+            # Sample next id for each element in the batch
+            output_dist = decode_fn(token_ids_seq, pad_mask_seq)
+            probs, output_ids = output_dist.max(dim=2)
+            new_ids = output_ids[-1, :]
+            new_probs = probs[-1, :]
+
+            # Generate next elements in the pad mask. An element is padded if:
+            # 1. The previous token is an end token
+            # 2. The previous token is a pad token
+            is_end_token = token_ids[i-1, :] == self.end_token_id
+            is_pad_token = token_ids[i-1, :] == self.pad_token_id
+            new_pad_mask = torch.logical_or(is_end_token, is_pad_token)
+
+            # Break if sampling is complete
+            if new_pad_mask.sum().item() == new_pad_mask.numel():
+                break
+
+            # Ensure all sequences contain an end token
+            if i == self.max_seq_len - 1:
+                new_ids[~new_pad_mask] = self.end_token_id
+
+            # Set the token to pad where required, update the token ids and update lls
+            new_ids[new_pad_mask] = self.pad_token_id
+            token_ids[i, :] = new_ids
+            pad_mask[i, :] = new_pad_mask
+            log_lhs += new_probs.cpu()
+
+        tokens = token_ids.transpose(0, 1).tolist()
+        tokens = self.tokenizer.convert_ids_to_tokens(tokens)
+        mol_strs = self.tokenizer.detokenize(tokens)
+        log_lhs = log_lhs.tolist()
+
+        return mol_strs, log_lhs
+
+
+    def beam_decode(self, decode_fn, batch_size, device="cpu", k=5):
+        """ Sample molecules from the model using beam search
+
+        Samples molecules by iteratively building up the sequence of SMILES characters using beam search.
+        Molecules are returned in a 2D list where batch_size is the outer dimension and k is the inner dimension.
+ + Args: + decode_fn (fn): Function used to apply tokens to model and produce log probability distribution + batch_size (int): Number of molecules to sample + device: Torch device to create tensors on + k (int): Number of beams + + Returns: + (List[List[str]], List[List[float]]): Tuple of (molecules, their log likelihoods) + """ + + # Create tensors which will be reused + token_ids = [self.begin_token_id] + ([self.pad_token_id] * (self.max_seq_len - 1)) + token_ids = [token_ids] * batch_size + token_ids = torch.tensor(token_ids, device=device).transpose(0, 1) + pad_mask = torch.zeros((self.max_seq_len, batch_size), device=device, dtype=torch.bool) + + ts = token_ids[:1, :] + ms = pad_mask[:1, :] + ll = torch.zeros((batch_size)) + + # Apply starting token to model to get a distribution over next tokens + first_lls = self._beam_step(decode_fn, ts, ms, ll) + top_lls, top_idxs = torch.topk(first_lls, k, dim=1) + top_ids = list(top_idxs.T) + + # Setup tensors for each beam which will be reused + token_ids_list = [token_ids.clone() for _ in range(k)] + pad_mask_list = [pad_mask.clone() for _ in range(k)] + lls_list = list(top_lls.cpu().T) + + for beam_idx, ids in enumerate(top_ids): + token_ids_list[beam_idx][1, :] = ids + pad_mask_list[beam_idx][1, :] = 0 + + for i in range(2, self.max_seq_len): + complete = self._update_beams_(i, decode_fn, token_ids_list, pad_mask_list, lls_list) + if complete: + break + + tokens_list = [token_ids.transpose(0, 1).tolist() for token_ids in token_ids_list] + tokens_list = [self.tokenizer.convert_ids_to_tokens(tokens) for tokens in tokens_list] + mol_strs_list = [self.tokenizer.detokenize(tokens) for tokens in tokens_list] + log_lhs_list = [log_lhs.tolist() for log_lhs in lls_list] + + # Transpose and sort list of molecules based on ll + new_mol_strs = self._transpose_list(mol_strs_list) + new_log_lhs = self._transpose_list(log_lhs_list) + sorted_mols, sorted_lls = self._sort_beams(new_mol_strs, new_log_lhs) + + return sorted_mols, sorted_lls + + + def _update_beams_(self, i, decode_fn, token_ids_list, pad_mask_list, lls_list): + """ Update beam tokens and pad mask in-place using a single decode step + + Updates token ids and pad mask in-place by producing the probability distribution over next tokens + and choosing the top k (number of beams) log likelihoods to choose the next tokens. + Sampling is complete if every batch element in every beam has produced an end token. 
+ + Args: + i (int): The current iteration counter + decode_fn (fn): Function used to apply tokens to model and produce log probability distribution + token_ids_list (List[torch.Tensor]): List of token_ids, each of shape [seq_len, batch_size] + pad_mask_list (List[torch.Tensor]): List of pad_masks, each of shape [seq_len, batch_size] + lls_list (List[torch.Tensor]): List of log likelihoods, each of shape [batch_size] + + Returns: + (bool): Specifies whether all of the beams are complete + """ + + assert len(token_ids_list) == len(pad_mask_list) == len(lls_list) + + num_beams = len(token_ids_list) + + ts = [token_ids[:i, :] for token_ids in token_ids_list] + ms = [pad_mask[:i, :] for pad_mask in pad_mask_list] + + # Apply current seqs to model to get a distribution over next tokens + # new_lls is a tensor of shape [batch_size, vocab_size * num_beams] + new_lls = [self._beam_step(decode_fn, t, m, lls) for t, m, lls in zip(ts, ms, lls_list)] + _, vocab_size = new_lls[0].shape + new_lls = torch.cat(new_lls, dim=1) + + # Keep lists (of length num_beams) of tensors of shape [batch_size] + top_lls, top_idxs = torch.topk(new_lls, num_beams, dim=1) + new_ids_list = list((top_idxs % vocab_size).T) + beam_idxs_list = list((top_idxs // vocab_size).T) + top_lls = list(top_lls.T) + + beam_complete = [] + new_ts_list = [] + new_pm_list = [] + new_lls_list = [] + + # Set the sampled tokens, pad masks and log likelihoods for each of the new beams + for new_beam_idx, (new_ids, beam_idxs, lls) in enumerate(zip(new_ids_list, beam_idxs_list, top_lls)): + # Get the previous sequences corresponding to the new beams + token_ids = [token_ids_list[beam_idx][:, b_idx] for b_idx, beam_idx in enumerate(beam_idxs)] + token_ids = torch.stack(token_ids).transpose(0, 1) + + # Generate next elements in the pad mask. An element is padded if: + # 1. The previous token is an end token + # 2. 
The previous token is a pad token
+            is_end_token = token_ids[i-1, :] == self.end_token_id
+            is_pad_token = token_ids[i-1, :] == self.pad_token_id
+            new_pad_mask = torch.logical_or(is_end_token, is_pad_token)
+            beam_complete.append(new_pad_mask.sum().item() == new_pad_mask.numel())
+
+            # Ensure all sequences contain an end token
+            if i == self.max_seq_len - 1:
+                new_ids[~new_pad_mask] = self.end_token_id
+
+            # Set the tokens to pad if an end token has already been produced
+            new_ids[new_pad_mask] = self.pad_token_id
+            token_ids[i, :] = new_ids
+
+            # Generate full pad mask sequence for new token sequence
+            pad_mask = [pad_mask_list[beam_idx][:, b_idx] for b_idx, beam_idx in enumerate(beam_idxs)]
+            pad_mask = torch.stack(pad_mask).transpose(0, 1)
+            pad_mask[i, :] = new_pad_mask
+
+            # Add tokens, pad mask and lls to list to be updated after all beams have been processed
+            new_ts_list.append(token_ids)
+            new_pm_list.append(pad_mask)
+            new_lls_list.append(lls)
+
+        complete = sum(beam_complete) == len(beam_complete)
+
+        # Update all tokens, pad masks and lls
+        if not complete:
+            for beam_idx, (ts, pm, lls) in enumerate(zip(new_ts_list, new_pm_list, new_lls_list)):
+                token_ids_list[beam_idx] = ts
+                pad_mask_list[beam_idx] = pm
+                lls_list[beam_idx] = lls
+
+        return complete
+
+    def _beam_step(self, decode_fn, tokens, mask, lls):
+        """ Apply tokens to model to produce the log likelihoods for the full sequence
+
+        A single iteration of decode is applied to the model to produce the next tokens in the sequences
+        and the log likelihoods for the entire sequences (including the next token)
+        The lls are returned as a distribution over all possible next tokens
+
+        Args:
+            decode_fn (fn): Function used to apply tokens to model and produce log probability distribution
+            tokens (torch.Tensor): Tensor of shape [seq_len, batch_size] containing the current token ids
+            mask (torch.Tensor): BoolTensor of shape [seq_len, batch_size] containing the padding mask
+            lls (torch.Tensor): Tensor of shape [batch_size] containing log likelihoods for seqs so far
+
+        Returns:
+            seq_lls (torch.Tensor): Tensor of shape [batch_size, vocab_size]
+        """
+
+        output_dist = decode_fn(tokens, mask)
+        next_token_lls = output_dist[-1, :, :].cpu()
+
+        # Create a vector from which only a pad token can be sampled
+        # And use this vector in the output for sequences which are complete
+        _, vocab_size = tuple(next_token_lls.shape)
+        complete_seq_ll = torch.ones((1, vocab_size)) * self.bad_token_ll
+        complete_seq_ll[:, self.pad_token_id] = 0.0
+
+        is_end_token = tokens[-1, :] == self.end_token_id
+        is_pad_token = tokens[-1, :] == self.pad_token_id
+        ll_mask = torch.logical_or(is_end_token, is_pad_token).cpu().unsqueeze(1)
+        masked_lls = (ll_mask * complete_seq_ll) + (~ll_mask * next_token_lls)
+
+        seq_lls = (lls + masked_lls.T).T
+        return seq_lls
+
+    @staticmethod
+    def _transpose_list(l):
+        """ Transpose 2D list so that inner dimension is first
+
+        Args:
+            l (List[Any]): List to be transposed
+
+        Returns:
+            (List[Any]): Transposed list
+        """
+
+        outer_dim = len(l)
+        inner_dim = len(l[0])
+
+        transposed = [[[]] * outer_dim for _ in range(inner_dim)]
+        for outer_idx, inner in enumerate(l):
+            for inner_idx, item in enumerate(inner):
+                transposed[inner_idx][outer_idx] = item
+
+        return transposed
+
+    @staticmethod
+    def _sort_beams(mol_strs, log_lhs):
+        """ Return mols sorted by their log likelihood
+
+        Args:
+            mol_strs (List[List[str]]): SMILES encoding of molecules
+            log_lhs (List[List[float]]): Log likelihood for each molecule
+
+        Returns:
+            (List[str], List[float]): Tuple of sorted molecules and sorted log lhs
+        """
+
+        assert len(mol_strs) == len(log_lhs)
+
+        sorted_mols = []
+        sorted_lls = []
+
+        for mols, lls in zip(mol_strs, log_lhs):
+            mol_lls = sorted(zip(mols, lls), reverse=True, key=lambda mol_ll: mol_ll[1])
+            mols, lls = tuple(zip(*mol_lls))
+            sorted_mols.append(list(mols))
+            sorted_lls.append(list(lls))
+
+        return sorted_mols, sorted_lls
+
+    @staticmethod
+    def calc_sampling_metrics(sampled_smiles, target_smiles):
+        """ Calculate sampling metrics for the model
+
+        If sampled_smiles is a List[List[str]] then the following metrics for beam search are calculated (up to the
+        maximum given by the number of elements in the inner lists):
+            - "top_1_accuracy"
+            - "top_5_accuracy"
+            - "top_10_accuracy"
+            - "top_20_accuracy"
+            - "top_50_accuracy"
+        The SMILES strings must be sorted in decreasing order of their predicted likelihood
+
+        If sampled_smiles is a List[str] then "accuracy" is calculated
+
+        The number of invalid SMILES "invalid" is also returned (for beam search this is just from the top_1)
+
+        Args:
+            sampled_smiles: SMILES strings produced by decode function,
+            target_smiles: target molecules as canonicalised SMILES strings
+
+        Returns:
+            dict containing results
+        """
+
+        num_sampled = len(sampled_smiles)
+        num_target = len(target_smiles)
+        err_msg = f"The number of sampled and target molecules must be the same, got {num_sampled} and {num_target}"
+        assert num_sampled == num_target, err_msg
+
+        data_type = type(sampled_smiles[0])
+        if data_type == str:
+            results = DecodeSampler._calc_greedy_metrics(sampled_smiles, target_smiles)
+        elif data_type == list:
+            results = DecodeSampler._calc_beam_metrics(sampled_smiles, target_smiles)
+        else:
+            raise TypeError(f"Elements of sampled_smiles must be either a str or a list, got {data_type}")
+
+        return results
+
+    @staticmethod
+    def _calc_greedy_metrics(sampled_smiles, target_smiles):
+        sampled_mols = [Chem.MolFromSmiles(smi) for smi in sampled_smiles]
+        invalid = [mol is None for mol in sampled_mols]
+
+        canon_smiles = ["Unknown" if mol is None else Chem.MolToSmiles(mol) for mol in sampled_mols]
+        target_mols = [Chem.MolFromSmiles(smi) for smi in target_smiles]
+        canon_target_smiles = [Chem.MolToSmiles(mol) for mol in target_mols]
+        correct_smiles = [canon_target_smiles[idx] == smi for idx, smi in enumerate(canon_smiles)]
+
+        num_correct = sum(correct_smiles)
+        total = len(correct_smiles)
+        num_invalid = sum(invalid)
+        perc_invalid = num_invalid / total
+        accuracy = num_correct / total
+
+        # Todo: need to move accuracy and perc_invalid to cuda for reducing later
+        metrics = {
+            "accuracy": accuracy,
+            "invalid": perc_invalid
+        }
+
+        return metrics
+
+    @staticmethod
+    def _calc_beam_metrics(sampled_smiles, target_smiles):
+        top_1_samples = [mols[0] for mols in sampled_smiles]
+        top_1_results = DecodeSampler._calc_greedy_metrics(top_1_samples, target_smiles)
+
+        metrics = {
+            "top_1_accuracy": top_1_results["accuracy"],
+            "invalid": top_1_results["invalid"]
+        }
+
+        ks = [2, 3, 5, 10, 20, 50]
+        num_samples_list = [k for k in ks if k <= len(sampled_smiles[0])]
+
+        for num_samples in num_samples_list:
+            top_k_correct = []
+            num_mols = len(sampled_smiles)
+
+            for batch_idx, mols in enumerate(sampled_smiles):
+                samples = mols[:num_samples]
+                samples_mols = [Chem.MolFromSmiles(smi) for smi in samples]
+                samples_smiles = ["Unknown" if mol is None else Chem.MolToSmiles(mol) for mol in samples_mols]
+                correct_smiles = [smi == target_smiles[batch_idx] for smi in
samples_smiles] + is_correct = sum(correct_smiles) >= 1 + top_k_correct.append(is_correct) + + accuracy = sum(top_k_correct) / num_mols + metrics[f"top_{str(num_samples)}_accuracy"] = accuracy + + return metrics \ No newline at end of file diff --git a/MoleculeSTM/models/mega_molbart/mega_mol_bart.py b/MoleculeSTM/models/mega_molbart/mega_mol_bart.py new file mode 100644 index 0000000..088fb8d --- /dev/null +++ b/MoleculeSTM/models/mega_molbart/mega_mol_bart.py @@ -0,0 +1,454 @@ +''' +Credit to https://github.com/NVIDIA/cheminformatics/blob/master/megamolbart/megamolbart/inference.py +''' +import logging +from functools import partial +from pathlib import Path +from typing import List +from rdkit import Chem +import random +import numpy as np + +import torch +from torch.nn.parallel import DistributedDataParallel as torchDDP +import pandas as pd +from megatron.checkpointing import load_checkpoint +import megatron.checkpointing as megatron_checkpointing +from megatron.global_vars import set_global_variables +from MoleculeSTM.cuchemcommon.workflow import BaseGenerativeWorkflow, add_jitter +from .decoder import DecodeSampler +from megatron import get_args, mpu +from megatron.initialize import initialize_megatron +from .megatron_bart import MegatronBART +from .tokenizer import MolEncTokenizer +from .util import (REGEX, DEFAULT_CHEM_TOKEN_START, DEFAULT_MAX_SEQ_LEN, + DEFAULT_VOCAB_PATH, + DEFAULT_NUM_LAYERS, DEFAULT_D_MODEL, DEFAULT_NUM_HEADS) + + +logger = logging.getLogger(__name__) + + +@add_jitter.register(torch.Tensor) +def _(embedding, radius, cnt, shape): + if shape is not None: + embedding = torch.reshape(embedding, (1, shape[0], shape[1])).to(embedding.device) + permuted_emb = embedding.permute(1, 0, 2) + + distorteds = [] + for i in range(cnt): + noise = torch.normal(0, radius, permuted_emb.shape).to(embedding.device) + distorted = (noise + permuted_emb).permute(1, 0, 2) + distorteds.append(distorted) + + return distorteds + + +def use_model_module(model): + ''' Credit to https://github.com/MolecularAI/MolBART/blob/megatron-molbart-with-zinc/megatron_molbart/checkpointing.py#L20 ''' + use_model = isinstance(model, torchDDP) + try: + from deepspeed.runtime.engine import DeepSpeedEngine + except: + pass + else: + use_model = use_model | isinstance(model, DeepSpeedEngine) + return use_model + + +class MegaMolBART(BaseGenerativeWorkflow): + + def __init__(self, + input_dir=None, + output_dir=None, + max_seq_len=DEFAULT_MAX_SEQ_LEN, + vocab_path=DEFAULT_VOCAB_PATH, + regex=REGEX, + default_chem_token_start=DEFAULT_CHEM_TOKEN_START, + num_layers=DEFAULT_NUM_LAYERS, + hidden_size=DEFAULT_D_MODEL, + num_attention_heads=DEFAULT_NUM_HEADS, + decoder_max_seq_len=None, + grad_enabled=True) -> None: + super().__init__() + + torch.set_grad_enabled(grad_enabled) # Testing this instead of `with torch.no_grad():` context since it doesn't exit + + self.device = 'cuda' # Megatron arg loading seems to only work with GPU + self.min_jitter_radius = 1.0 + self.max_model_position_embeddings = max_seq_len + + args = { + 'num_layers': num_layers, + 'hidden_size': hidden_size, + 'num_attention_heads': num_attention_heads, + 'max_position_embeddings': self.max_model_position_embeddings, + 'tokenizer_type': 'GPT2BPETokenizer', + 'vocab_file': vocab_path, + } + if input_dir is not None: + args["load"] = input_dir + if output_dir is not None: + args["save"] = output_dir + args["save_interval"] = 1 + + initialize_megatron(args_defaults=args, ignore_unknown_args=True) + args = get_args() + self.tokenizer = 
self.load_tokenizer(args.vocab_file, regex, default_chem_token_start) + self.model = self.load_model(args, self.tokenizer, decoder_max_seq_len) + + def _compute_radius(self, scaled_radius): # TODO REMOVE + if scaled_radius: + return float(scaled_radius * self.min_jitter_radius) + else: + return self.min_jitter_radius + + def load_tokenizer(self, tokenizer_vocab_path, regex, default_chem_token_start): + """Load tokenizer from vocab file + + Params: + tokenizer_vocab_path: str, path to tokenizer vocab + + Returns: + MolEncTokenizer tokenizer object + """ + print("Loading vocab from {}.".format(tokenizer_vocab_path)) + tokenizer_vocab_path = Path(tokenizer_vocab_path) + tokenizer = MolEncTokenizer.from_vocab_file( + tokenizer_vocab_path, + regex, + default_chem_token_start) + + return tokenizer + + def load_model(self, args, tokenizer, decoder_max_seq_len=None): + """Load saved model checkpoint + + Params: + tokenizer: MolEncTokenizer tokenizer object + decoder_max_seq_len: int, maximum sequence length + args: Megatron initialized arguments + + Returns: + MegaMolBART trained model + """ + + vocab_size = len(tokenizer) + pad_token_idx = tokenizer.vocab[tokenizer.pad_token] + + # TODO how to handle length overrun for batch processing + if not decoder_max_seq_len: + decoder_max_seq_len = args.max_position_embeddings + + sampler = DecodeSampler(tokenizer, decoder_max_seq_len) + model = MegatronBART( + sampler, + pad_token_idx, + vocab_size, + args.hidden_size, + args.num_layers, + args.num_attention_heads, + args.hidden_size * 4, + args.max_position_embeddings, + dropout=0.1, + ) + if args.load is not None: + print("Loading from {}".format(args.load)) + self.iteration = load_checkpoint(model, None, None) + model = model.cuda() + return model + + def save_model(self, iteration, model, optimizer=None, lr_scheduler=None): + ''' Credit to https://github.com/MolecularAI/MolBART/blob/megatron-molbart-with-zinc/megatron_molbart/checkpointing.py#L46 ''' + + """Save a model checkpoint.""" + args = get_args() + + # Only rank zero of the data parallel writes to the disk. + if use_model_module(model): + model = model.module + + if mpu.get_data_parallel_rank() == 0: + + # Arguments, iteration, and model. + state_dict = {} + state_dict['args'] = args + state_dict['checkpoint_version'] = 2.0 + state_dict['iteration'] = iteration + state_dict['model'] = model.state_dict_for_save_checkpoint() + + # Optimizer stuff. + if not args.no_save_optim: + if optimizer is not None: + state_dict['optimizer'] = optimizer.state_dict() + if lr_scheduler is not None: + state_dict['lr_scheduler'] = lr_scheduler.state_dict() + + # RNG states. + if not args.no_save_rng: + state_dict['random_rng_state'] = random.getstate() + state_dict['np_rng_state'] = np.random.get_state() + state_dict['torch_rng_state'] = torch.get_rng_state() + state_dict['cuda_rng_state'] = torch.cuda.get_rng_state() + state_dict['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() + + # Save. + checkpoint_name = megatron_checkpointing.get_checkpoint_name(args.save, iteration) + print('global rank {} is saving checkpoint at iteration {:7d} to {}'. 
+ format(torch.distributed.get_rank(), iteration, + checkpoint_name)) + megatron_checkpointing.ensure_directory_exists(checkpoint_name) + torch.save(state_dict, checkpoint_name) + print(' successfully saved {}'.format(checkpoint_name)) + + # Wait so everyone is done (necessary) + torch.distributed.barrier() + # And update the latest iteration + if torch.distributed.get_rank() == 0: + tracker_filename = megatron_checkpointing.get_checkpoint_tracker_filename(args.save) + with open(tracker_filename, 'w') as f: + f.write(str(iteration)) + # Wait so everyone is done (not necessary) + torch.distributed.barrier() + return + + def smiles2embedding(self, smiles, pad_length=None): + """Calculate embedding and padding mask for smiles with optional extra padding + + Params + smiles: string, input SMILES molecule + pad_length: optional extra + + Returns + embedding array and boolean mask + """ + + assert isinstance(smiles, str) + if pad_length: + assert pad_length >= len(smiles) + 2 + + tokens = self.tokenizer.tokenize([smiles], pad=True) + + # Append to tokens and mask if appropriate + if pad_length: + for i in range(len(tokens['original_tokens'])): + n_pad = pad_length - len(tokens['original_tokens'][i]) + tokens['original_tokens'][i] += [self.tokenizer.pad_token] * n_pad + tokens['masked_pad_masks'][i] += [1] * n_pad + + token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T + pad_mask = torch.tensor(tokens['masked_pad_masks']).bool().cuda().T + token_ids = token_ids[:self.max_model_position_embeddings] + pad_mask = pad_mask[:self.max_model_position_embeddings] + encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask} + + embedding = self.model.encode(encode_input) + torch.cuda.empty_cache() + return embedding, pad_mask + + def smileslist2embedding(self, smiles_list): + tokens = self.tokenizer.tokenize(smiles_list, pad=True) + + token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T + pad_mask = torch.tensor(tokens['masked_pad_masks']).bool().cuda().T + token_ids = token_ids[:self.max_model_position_embeddings] + pad_mask = pad_mask[:self.max_model_position_embeddings] + encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask} + + embedding = self.model.encode(encode_input) + torch.cuda.empty_cache() + return embedding, pad_mask + + def smileslist2embedding_model_given(self, model, smiles_list): + tokens = self.tokenizer.tokenize(smiles_list, pad=True) + + token_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens['original_tokens'])).cuda().T + pad_mask = torch.tensor(tokens['masked_pad_masks']).bool().cuda().T + token_ids = token_ids[:self.max_model_position_embeddings] + pad_mask = pad_mask[:self.max_model_position_embeddings] + encode_input = {"encoder_input": token_ids, "encoder_pad_mask": pad_mask} + + embedding = model.encode(encode_input) + torch.cuda.empty_cache() + return embedding, pad_mask + + def inverse_transform(self, embeddings, mem_pad_mask, k=1, sanitize=True): + mem_pad_mask = mem_pad_mask.clone() + smiles_interp_list = [] + + batch_size = 1 # TODO: parallelize this loop as a batch + with torch.no_grad(): + for memory in embeddings: + + if isinstance(memory, list): + memory = torch.FloatTensor(memory).cuda() + + decode_fn = partial(self.model._decode_fn, + mem_pad_mask=mem_pad_mask.type(torch.LongTensor).cuda(), + memory=memory) + + mol_strs, _ = self.model.sampler.beam_decode(decode_fn, + batch_size=batch_size, + device='cuda', + k=k) + mol_strs = 
sum(mol_strs, []) # flatten list + + # TODO: add back sanitization and validity checking once model is trained + logger.warn('WARNING: MOLECULE VALIDATION AND SANITIZATION CURRENTLY DISABLED') + for smiles in mol_strs: + if sanitize: + mol = Chem.MolFromSmiles(smiles, sanitize=sanitize) + if mol: + sanitized_smiles = Chem.MolToSmiles(mol) + smiles_interp_list.append(sanitized_smiles) + logger.debug(f'Sanitized SMILES {sanitized_smiles} added...') + break + smiles_interp_list.append(smiles) + + return smiles_interp_list + + def interpolate_molecules(self, smiles1, smiles2, num_interp, tokenizer, k=1): + """Interpolate between two molecules in embedding space. + + Params + smiles1: str, input SMILES molecule + smiles2: str, input SMILES molecule + num_interp: int, number of molecules to interpolate + tokenizer: MolEncTokenizer tokenizer object + k: number of molecules for beam search, default 1. Can increase if there are issues with validity + + Returns + list of interpolated smiles molecules + """ + + pad_length = max(len(smiles1), len(smiles2)) + 2 # add 2 for start / stop + embedding1, pad_mask1 = self.smiles2embedding(smiles1, + pad_length=pad_length) + + embedding2, pad_mask2 = self.smiles2embedding(smiles2, + pad_length=pad_length) + + scale = torch.linspace(0.0, 1.0, num_interp + 2)[ + 1:-1] # skip first and last because they're the selected molecules + scale = scale.unsqueeze(0).unsqueeze(-1).cuda() + + interpolated_emb = torch.lerp(embedding1, embedding2, scale).cuda() # dims: batch, tokens, embedding + combined_mask = (pad_mask1 & pad_mask2).bool().cuda() + + embeddings = [] + dims = [] + for emb in interpolated_emb.permute(1, 0, 2): + dims.append(emb.shape) + embeddings.append(emb) + + generated_mols = self.inverse_transform(embeddings, + combined_mask, + k=k, + sanitize=True) + generated_mols = [smiles1] + generated_mols + [smiles2] + embeddings = [embedding1] + embeddings + [embedding2] + dims = [embedding1.shape] + dims + [embedding2.shape] + return generated_mols, embeddings, combined_mask, dims + + def find_similars_smiles_list(self, + smiles: str, + num_requested: int = 10, + scaled_radius=None, + force_unique=False): + distance = self._compute_radius(scaled_radius) + logger.info(f'Computing with distance {distance}...') + + embedding, pad_mask = self.smiles2embedding(smiles) + + neighboring_embeddings = self.addjitter(embedding, distance, cnt=num_requested) + + generated_mols = self.inverse_transform(neighboring_embeddings, + pad_mask.bool().cuda(), + k=1, sanitize=True) + if force_unique: + generated_mols = list(set(generated_mols)) + + generated_mols = [smiles] + generated_mols + neighboring_embeddings = [embedding] + neighboring_embeddings + return generated_mols, neighboring_embeddings, pad_mask + + def find_similars_smiles(self, + smiles: str, + num_requested: int = 10, + scaled_radius=None, + force_unique=False): + generated_mols, neighboring_embeddings, pad_mask = \ + self.find_similars_smiles_list(smiles, + num_requested=num_requested, + scaled_radius=scaled_radius, + force_unique=force_unique) + + # Rest of the applications and libraries use RAPIDS and cuPY libraries. + # For interoperability, we need to convert the embeddings to cupy. 
+ embeddings = [] + dims = [] + for neighboring_embedding in neighboring_embeddings: + dims.append(neighboring_embedding.shape) + embeddings.append(neighboring_embedding.flatten().tolist()) + + generated_df = pd.DataFrame({'SMILES': generated_mols, + 'embeddings': embeddings, + 'embeddings_dim': dims, + 'Generated': [True for i in range(len(generated_mols))]}) + generated_df.iat[0, 3] = False + + if force_unique: + inv_transform_funct = partial(self.inverse_transform, + mem_pad_mask=pad_mask) + generated_df = self.compute_unique_smiles(generated_df, + inv_transform_funct, + scaled_radius=scaled_radius) + return generated_df + + def interpolate_smiles(self, + smiles: List, + num_points: int = 10, + scaled_radius=None, + force_unique=False): + num_points = int(num_points) + if len(smiles) < 2: + raise Exception('At-least two or more smiles are expected') + + k = 1 + result_df = [] + for idx in range(len(smiles) - 1): + interpolated_mol, interpolated_embeddings, combined_mask, dims = \ + self.interpolate_molecules(smiles[idx], + smiles[idx + 1], + num_points, + self.tokenizer, + k=k) + + # Rest of the applications and libraries use RAPIDS and cuPY libraries. + # For interoperability, we need to convert the embeddings to cupy. + embeddings = [] + for interpolated_embedding in interpolated_embeddings: + embeddings.append(interpolated_embedding.cpu()) + + interp_df = pd.DataFrame({'SMILES': interpolated_mol, + 'embeddings': embeddings, + 'embeddings_dim': dims, + 'Generated': [True for i in range(len(interpolated_mol))]}) + + inv_transform_funct = partial(self.inverse_transform, mem_pad_mask=combined_mask) + + # Mark the source and desinations as not generated + interp_df.iat[0, 3] = False + interp_df.iat[-1, 3] = False + + if force_unique: + interp_df = self.compute_unique_smiles(interp_df, + inv_transform_funct, + scaled_radius=scaled_radius) + + result_df.append(interp_df) + + result_df = pd.concat(result_df) + smile_list = list(result_df['SMILES']) + + return result_df, smile_list \ No newline at end of file diff --git a/MoleculeSTM/models/mega_molbart/megatron_bart.py b/MoleculeSTM/models/mega_molbart/megatron_bart.py new file mode 100644 index 0000000..e780307 --- /dev/null +++ b/MoleculeSTM/models/mega_molbart/megatron_bart.py @@ -0,0 +1,800 @@ +from megatron.module import MegatronModule +from apex.normalization import FusedLayerNorm +from megatron import mpu +from torch.nn import init +import torch.nn as nn +import torch.nn.functional as F +import torch +import math +from functools import partial +from .tokenizer import load_tokenizer +from .util import DEFAULT_CHEM_TOKEN_START, DEFAULT_VOCAB_PATH, REGEX + + +class MultiheadAttention(MegatronModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + cross_attention=False, + init_method=init.xavier_uniform_, + ): + + super(MultiheadAttention, self).__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.attn_dropout = nn.Dropout(p=dropout) + self.bias = bias + self.cross_attention = cross_attention + self.head_dim = self.embed_dim // self.num_heads + self.scaling = self.head_dim ** -0.5 + self.init_method = init_method + self.skip_bias = not bias + + # Self-Attention is Column Parallelized + self.query_key_value = mpu.ColumnParallelLinear(self.embed_dim, + 3 * self.embed_dim, gather_output=True, + init_method=self.init_method, + skip_bias_add=self.skip_bias) + + # Cross-Attention is Row and Column Parallelized + self.q_proj = mpu.RowParallelLinear(self.embed_dim, + self.embed_dim, 
input_is_parallel=False, + init_method=self.init_method, bias=bias, + skip_bias_add=self.skip_bias) + self.key_value = mpu.ColumnParallelLinear(self.embed_dim, 2 + * self.embed_dim, gather_output=True, + init_method=self.init_method, + skip_bias_add=self.skip_bias) + + # Final projection is Row Parallelized + self.out_proj = mpu.RowParallelLinear(self.embed_dim, + self.embed_dim, input_is_parallel=False, + init_method=self.init_method, bias=bias) + + def forward( + self, + query, + key=None, + value=None, + key_padding_mask=None, + attn_mask=None, + ): + """Input shape: Time x Batch x Channel + + Args: + query - tokens/states of shape [Time x Batch x Channel] + key - tokens/states of shape [Time x Batch x Channel] + value - tokens/states of shape [Time x Batch x Channel] + key_padding_mask - keys that are pads where padding + elements are indicated by 1s. Shape: [batch, src_len]. + attn_mask - typically used to implement causal attention, where + the mask prevents the attention from looking forward in time. + Shape: [tgt_len, src_len]. + Returns: + outputs - attention probability scores of shape (Time x Batch x Channel) + """ + + (tgt_len, bsz, embed_dim) = query.size() + + # Compute attention projections + if not self.cross_attention: + (q_k_v, bias) = self.query_key_value(query) + (q, k, v) = mpu.split_tensor_along_last_dim(q_k_v, 3) + else: + q, _ = self.q_proj(query) + if key is None: + assert value is None, \ + 'Cross attention mode: since key is None, value must also be None.' + k = v = None + else: + (k_v, bias) = self.key_value(key) + (k, v) = mpu.split_tensor_along_last_dim(k_v, 2) + + # Scale query and reshape + q = q.contiguous() + q *= self.scaling + q = q.view(tgt_len, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, + self.head_dim).transpose(0, 1) + + # Compute attention scores + src_len = k.size(1) + attn_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_weights.size()) == [bsz * self.num_heads, + tgt_len, src_len] + + # Apply causal attention mask + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_weights += attn_mask + + # Apply padding mask + if key_padding_mask is not None: + attn_weights = attn_weights.view(bsz, self.num_heads, + tgt_len, src_len) + attn_weights = \ + attn_weights.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float('-inf')) + attn_weights = attn_weights.view(bsz * self.num_heads, + tgt_len, src_len) + + # Compute attention probabilities + attn_weights = F.softmax(attn_weights, dim=-1) + attn_probs = self.attn_dropout(attn_weights) + + # Compute context and output projection + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, + self.head_dim] + if attn.size(1) == 1: # a single decoder step (sequence length == 1) + attn = attn.contiguous().view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, + embed_dim) + (attn, bias) = self.out_proj(attn) + attn_output_weights = attn_probs.view(bsz, self.num_heads, + tgt_len, src_len) + attn_output_weights = attn_output_weights.sum(dim=1) \ + / self.num_heads + return (attn, attn_output_weights) + + +class EncoderLayer(MegatronModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(EncoderLayer, 
self).__init__() + self.self_attn = MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + cross_attention=False, + init_method=init_method, + ) + self.self_attn_layer_norm = FusedLayerNorm(embed_dim) + self.attn_dropout = nn.Dropout(p=dropout) + self.activation_fn = F.gelu + self.activation_dropout = nn.Dropout(p=dropout) + self.fc1 = mpu.ColumnParallelLinear(embed_dim, 4 + * embed_dim, gather_output=False, + init_method=init_method, skip_bias_add=False) + self.fc2 = mpu.RowParallelLinear(4 * embed_dim, + embed_dim, input_is_parallel=True, + init_method=init_method, skip_bias_add=False) + self.final_layer_norm = FusedLayerNorm(embed_dim) + + def forward( + self, + x, + encoder_padding_mask=None, + attn_mask=None, + ): + """ + Args: + x: input to the layer of shape (seq_len, batch, embed_dim) + encoder_padding_mask: binary ByteTensor of shape + (batch, seq_len) where padding elements are indicated by 1. + attn_mask: binary tensor of shape (tgt_len, src_len), + where tgt_len is the length of output and src_len is the + length of input, though here both are equal to seq_len. + Returns: + encoded output of shape (seq_len, batch, embed_dim) + """ + + if attn_mask is not None: + attn_mask = attn_mask.masked_fill(attn_mask.to(torch.bool), + -1e8) + residual = x + x = self.self_attn_layer_norm(x) + (x, weights) = self.self_attn(query=x, key=x, value=x, + key_padding_mask=encoder_padding_mask, + attn_mask=attn_mask) + x = self.attn_dropout(x) + x = x + residual + residual = x + x = self.final_layer_norm(x) + x, _ = self.fc1(x) + x = self.activation_fn(x) + x = self.activation_dropout(x) + x, _ = self.fc2(x) + x = self.attn_dropout(x) + x = x + residual + return x + + +class DecoderLayer(MegatronModule): + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(DecoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + cross_attention=False, + init_method=init_method, + ) + self.self_attn_layer_norm = FusedLayerNorm(embed_dim) + self.encoder_attn = MultiheadAttention( + embed_dim, + num_heads, + dropout=dropout, + bias=bias, + cross_attention=True, + init_method=init_method, + ) + self.encoder_attn_layer_norm = FusedLayerNorm(embed_dim) + self.dropout = nn.Dropout(p=dropout) + self.activation_fn = F.gelu + self.activation_dropout = nn.Dropout(p=dropout) + self.fc1 = mpu.ColumnParallelLinear(embed_dim, 4 + * embed_dim, gather_output=False, + init_method=init_method, skip_bias_add=False) + self.fc2 = mpu.RowParallelLinear(4 * embed_dim, + embed_dim, input_is_parallel=True, + init_method=init_method, skip_bias_add=False) + self.final_layer_norm = FusedLayerNorm(embed_dim) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + self_attn_mask=None, + self_attn_padding_mask=None, + ): + """ + Args: + x: input to decoder layer of shape (seq_len, batch, embed_dim) + encoder_out: output from the encoder + encoder_padding_mask: binary ByteTensor of shape + (batch, seq_len) where padding elements are indicated by 1 + self_attn_mask: binary tensor of shape (tgt_len, src_len), + where tgt_lent is the length of output and src_len is the + length of input, though here both are equal to seq_len. + self_attn_padding_mask: binary ByteTensor of shape + (batch, seq_len) where padding elements are indicated by 1. 
+ Returns: + encoded output of shape (seq_len, batch, embed_dim) + """ + + residual = x + x = self.self_attn_layer_norm(x) + + # Self-Attention block + + (x, weights) = self.self_attn(query=x, key=x, value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask) + x = self.dropout(x) + x = x + residual + + # Cross-Attention block + if encoder_out is not None: + residual = x + x = self.encoder_attn_layer_norm(x) + (x, attn) = self.encoder_attn(query=x, key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask) + x = self.dropout(x) + x = x + residual + residual = x + x = self.final_layer_norm(x) + + # Fully-connected block + x, _ = self.fc1(x) + x = self.activation_fn(x) + x = self.activation_dropout(x) + x, _ = self.fc2(x) + x = self.dropout(x) + x = x + residual + return x + + +class ParallelTransformerEncoder(MegatronModule): + + def __init__( + self, + num_layers, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(ParallelTransformerEncoder, self).__init__() + self.layers = nn.ModuleList([]) + self.num_layers = num_layers + self.embed_dim = embed_dim + self.num_heads = num_heads + self.attn_dropout = dropout + self.bias = bias + self.init_method = init_method + self.layers.extend([self.build_encoder_layer() for i in + range(self.num_layers)]) + self.norm = FusedLayerNorm(self.embed_dim) + + def build_encoder_layer(self): + layer = EncoderLayer(self.embed_dim, self.num_heads, + dropout=self.attn_dropout, bias=self.bias, + init_method=self.init_method) + return layer + + def forward( + self, + src, + mask=None, + src_key_padding_mask=None, + ): + """Pass the input through the encoder layers in turn. + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + Returns: + encoded output of shape (src_len, batch, embed_dim) + """ + + output = src + for mod in self.layers: + output = mod(output, attn_mask=mask, + encoder_padding_mask=src_key_padding_mask) + output = self.norm(output) + return output + + +class ParallelTransformerDecoder(MegatronModule): + + def __init__( + self, + num_layers, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + init_method=init.xavier_uniform_, + ): + + super(ParallelTransformerDecoder, self).__init__() + self.layers = nn.ModuleList([]) + self.num_layers = num_layers + self.embed_dim = embed_dim + self.num_heads = num_heads + self.attn_dropout = dropout + self.bias = bias + self.init_method = init_method + self.layers.extend([self.build_decoder_layer() for i in + range(self.num_layers)]) + self.norm = FusedLayerNorm(self.embed_dim) + + def build_decoder_layer(self): + layer = DecoderLayer(self.embed_dim, self.num_heads, + dropout=self.attn_dropout, bias=self.bias, + init_method=self.init_method) + return layer + + def forward( + self, + tgt, + memory, + tgt_mask=None, + tgt_key_padding_mask=None, + memory_key_padding_mask=None, + ): + """Pass the inputs (and mask) through the decoder layer in turn. + Args: + tgt: the sequence to the decoder (required). + memory: the sequence from the last layer of the encoder (required). + tgt_mask: the mask for the tgt sequence (optional). + tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + memory_key_padding_mask: the mask for the memory keys per batch (optional). 
+ Returns: + decoded output of shape (tgt_len, batch, embed_dim) + """ + + output = tgt + for mod in self.layers: + output = mod(output, encoder_out=memory, + encoder_padding_mask=memory_key_padding_mask, + self_attn_mask=tgt_mask, + self_attn_padding_mask=tgt_key_padding_mask) + output = self.norm(output) + return output + + +class MegatronBART(MegatronModule): + + def __init__( + self, + decode_sampler, + pad_token_idx, + vocab_size, + d_model, + num_layers, + num_heads, + d_feedforward, + max_seq_len, + dropout=0.0, + ): + + super().__init__() + + self.sampler = decode_sampler + self.pad_token_idx = pad_token_idx + self.val_sampling_alg = 'greedy' + self.num_beams = 5 + self.vocab_size = vocab_size + self.d_model = d_model + self.num_layers = num_layers + self.num_heads = num_heads + self.d_feedforward = d_feedforward + self.max_seq_len = max_seq_len + self.dropout = dropout + self.emb_dropout = nn.Dropout(p=dropout) + init_method = init.xavier_uniform_ + + self.emb = nn.Embedding(vocab_size, d_model) + self.dropout = dropout + self.encoder = ParallelTransformerEncoder( + self.num_layers, + self.d_model, + self.num_heads, + self.dropout, + bias=True, + init_method=init_method, + ) + self.decoder = ParallelTransformerDecoder( + self.num_layers, + self.d_model, + self.num_heads, + self.dropout, + bias=True, + init_method=init_method, + ) + self.token_fc = mpu.RowParallelLinear(d_model, vocab_size, + input_is_parallel=False, init_method=init_method, + skip_bias_add=False) + self.loss_fn = nn.CrossEntropyLoss(reduction='none', + ignore_index=pad_token_idx) + self.log_softmax = nn.LogSoftmax(dim=2) + self._init_params(init_method) + self.register_buffer('pos_emb', self._positional_embs()) + + def forward(self, x): + """ Apply SMILES strings to model + + The dictionary returned will be passed to other functions, so its contents are fairly flexible, + except that it must contain the key "token_output" which is the output of the model + (possibly after any fully connected layers) for each token. 
+ + Arg: + x (dict { + "encoder_input": tensor of token_ids of shape (src_len, batch_size), + "encoder_pad_mask": bool tensor of padded elems of shape (src_len, batch_size), + "decoder_input": tensor of decoder token_ids of shape (tgt_len, batch_size) + "decoder_pad_mask": bool tensor of decoder padding mask of shape (tgt_len, batch_size) + }): + + Returns: + Output from model (dict containing key "token_output" and "model_output") + """ + + encoder_input = x['encoder_input'] + decoder_input = x['decoder_input'] + encoder_pad_mask = x['encoder_pad_mask'].transpose(0, 1) + decoder_pad_mask = x['decoder_pad_mask'].transpose(0, 1) + + encoder_embs = self._construct_input(encoder_input) + decoder_embs = self._construct_input(decoder_input) + + (seq_len, _, _) = tuple(decoder_embs.size()) + tgt_mask = \ + self._generate_square_subsequent_mask(seq_len).to(decoder_embs.device) + + memory = self.encoder(encoder_embs, + src_key_padding_mask=encoder_pad_mask) + model_output = self.decoder(decoder_embs, memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=decoder_pad_mask, + memory_key_padding_mask=encoder_pad_mask.clone()) + + token_output, _ = self.token_fc(model_output) + output = {'model_output': model_output, + 'token_output': token_output} + + return output + + def encode(self, batch): + """ Construct the memory embedding for an encoder input + + Args: + batch (dict { + "encoder_input": tensor of token_ids of shape (src_len, batch_size), + "encoder_pad_mask": bool tensor of padded elems of shape (src_len, batch_size), + }) + + Returns: + encoder memory (Tensor of shape (seq_len, batch_size, d_model)) + """ + + encoder_input = batch['encoder_input'] + encoder_pad_mask = batch['encoder_pad_mask'].transpose(0, 1) + encoder_embs = self._construct_input(encoder_input) + model_output = self.encoder(encoder_embs, + src_key_padding_mask=encoder_pad_mask) + return model_output + + def decode(self, batch): + """ Construct an output from a given decoder input + + Args: + batch (dict { + "decoder_input": tensor of decoder token_ids of shape (tgt_len, batch_size) + "decoder_pad_mask": bool tensor of decoder padding mask of shape (tgt_len, batch_size) + "memory_input": tensor from encoded input of shape (src_len, batch_size, d_model) + "memory_pad_mask": bool tensor of memory padding mask of shape (src_len, batch_size) + }) + """ + + decoder_input = batch['decoder_input'] + decoder_pad_mask = batch['decoder_pad_mask'].transpose(0, 1) + memory_input = batch['memory_input'] + memory_pad_mask = batch['memory_pad_mask'].transpose(0, 1) + + decoder_embs = self._construct_input(decoder_input) + + (seq_len, _, _) = tuple(decoder_embs.size()) + tgt_mask = \ + self._generate_square_subsequent_mask(seq_len).to(decoder_embs.device) + + model_output = self.decoder(decoder_embs, memory_input, + tgt_key_padding_mask=decoder_pad_mask, + memory_key_padding_mask=memory_pad_mask, + tgt_mask=tgt_mask) + token_output, _ = self.token_fc(model_output) + token_probs = self.log_softmax(token_output) + return token_probs + + def validation_step(self, batch, batch_idx=None): + self.eval() + # TODO: This can be further optimized + tokenizer = load_tokenizer(vocab_path=DEFAULT_VOCAB_PATH, chem_token_start=DEFAULT_CHEM_TOKEN_START, regex=REGEX) + + with torch.no_grad(): + model_output = self.forward(batch) + #target_smiles = batch['target_smiles'] + token_ids = batch['target'] + tokens = token_ids.transpose(0, 1).tolist() + tokens = tokenizer.convert_ids_to_tokens(tokens) + target_smiles = tokenizer.detokenize(tokens) + + loss = 
self._calc_loss(batch, model_output) + token_acc = self._calc_char_acc(batch, model_output) + perplexity = self._calc_perplexity(batch, model_output) + (mol_strs, log_lhs) = self.sample_molecules(batch, + sampling_alg=self.val_sampling_alg) + metrics = self.sampler.calc_sampling_metrics(mol_strs, + target_smiles) + + self.train() + + val_outputs = { + 'val_loss': loss.item(), + 'val_token_acc': token_acc, + 'val_perplexity': perplexity, + 'val_molecular_accuracy': metrics['accuracy'], + 'val_invalid_smiles': metrics['invalid'], + } + return val_outputs + + def _calc_loss(self, batch_input, model_output): + """ Calculate the loss for the model + + Args: + batch_input (dict): Input given to model, + model_output (dict): Output from model + + Returns: + loss (singleton tensor), + """ + + tokens = batch_input['target'] + pad_mask = batch_input['target_pad_mask'] + token_output = model_output['token_output'] + token_mask_loss = self._calc_mask_loss(token_output, tokens, + pad_mask) + return token_mask_loss + + def _calc_mask_loss( + self, + token_output, + target, + target_mask, + ): + """ Calculate the loss for the token prediction task + + Args: + token_output (Tensor of shape (seq_len, batch_size, vocab_size)): token output from transformer + target (Tensor of shape (seq_len, batch_size)): Original (unmasked) SMILES token ids from the tokenizer + target_mask (Tensor of shape (seq_len, batch_size)): Pad mask for target tokens + + Output: + loss (singleton Tensor): Loss computed using cross-entropy, + """ + + (seq_len, batch_size) = tuple(target.size()) + token_pred = token_output.reshape((seq_len * batch_size, + -1)).float() + loss = self.loss_fn(token_pred, + target.reshape(-1)).reshape((seq_len, + batch_size)) + inv_target_mask = ~(target_mask > 0) + num_tokens = inv_target_mask.sum() + loss = loss.sum() / num_tokens + return loss + + def _calc_perplexity(self, batch_input, model_output): + target_ids = batch_input['target'] + target_mask = batch_input['target_pad_mask'] + vocab_dist_output = model_output['token_output'] + inv_target_mask = ~(target_mask > 0) + log_probs = vocab_dist_output.gather(2, + target_ids.unsqueeze(2)).squeeze(2) + log_probs = log_probs * inv_target_mask + log_probs = log_probs.sum(dim=0) + seq_lengths = inv_target_mask.sum(dim=0) + exp = -(1 / seq_lengths) + perp = torch.pow(log_probs.exp(), exp) + return perp.mean().item() + + def _calc_char_acc(self, batch_input, model_output): + token_ids = batch_input['target'] + target_mask = batch_input['target_pad_mask'] + token_output = model_output['token_output'] + target_mask = ~(target_mask > 0) + (_, pred_ids) = torch.max(token_output.float(), dim=2) + correct_ids = torch.eq(token_ids, pred_ids) + correct_ids = correct_ids * target_mask + num_correct = correct_ids.sum() + total = target_mask.sum() + accuracy = num_correct / total + return accuracy + + def sample_molecules(self, batch_input, sampling_alg='greedy'): + """ Sample molecules from the model + + Args: + batch_input (dict): Input given to model + sampling_alg (str): Algorithm to use to sample SMILES strings from model + + Returns: + ([[str]], [[float]]): Tuple of molecule SMILES strings and log lhs (outer dimension is batch) + """ + + enc_input = batch_input['encoder_input'] + enc_mask = batch_input['encoder_pad_mask'] + + # Freezing the weights reduces the amount of memory leakage in the transformer + #model.eval() + + with torch.no_grad(): + + encode_input = {'encoder_input': enc_input, + 'encoder_pad_mask': enc_mask} + memory = self.encode(encode_input) + 
mem_mask = enc_mask.clone() + (_, batch_size, _) = tuple(memory.size()) + decode_fn = partial(self._decode_fn, memory=memory, + mem_pad_mask=mem_mask) + #self.sampler.device = self.device + if sampling_alg == 'greedy': + (mol_strs, log_lhs) = \ + self.sampler.greedy_decode(decode_fn, batch_size,device=memory.device) + elif sampling_alg == 'beam': + (mol_strs, log_lhs) = \ + self.sampler.beam_decode(decode_fn, batch_size, + self.num_beams,device=memory.device) + + # Must remember to unfreeze! + #model.train() + + return (mol_strs, log_lhs) + + def _decode_fn( + self, + token_ids, + pad_mask, + memory, + mem_pad_mask, + ): + decode_input = { + 'decoder_input': token_ids, + 'decoder_pad_mask': pad_mask, + 'memory_input': memory, + 'memory_pad_mask': mem_pad_mask, + } + model_output = self.decode(decode_input) + return model_output + + def _construct_input(self, token_ids, sentence_masks=None): + (seq_len, _) = tuple(token_ids.size()) + token_embs = self.emb(token_ids) + + # Scaling the embeddings like this is done in other transformer libraries + token_embs = token_embs * math.sqrt(self.d_model) + positional_embs = self.pos_emb[:seq_len, : + ].unsqueeze(0).transpose(0, 1) + embs = token_embs + positional_embs + embs = self.emb_dropout(embs) + return embs + + def _positional_embs(self): + """ Produces a tensor of positional embeddings for the model + + Returns a tensor of shape (self.max_seq_len, self.d_model) filled with positional embeddings, + which are created from sine and cosine waves of varying wavelength + """ + + encs = torch.tensor([dim / self.d_model for dim in range(0, + self.d_model, 2)]) + encs = 10000 ** encs + encs = [(torch.sin(pos / encs), torch.cos(pos / encs)) + for pos in range(self.max_seq_len)] + encs = [torch.stack(enc, dim=1).flatten()[:self.d_model] + for enc in encs] + encs = torch.stack(encs) + return encs + + def _generate_square_subsequent_mask(self, sz): + """ + Method copied from Pytorch nn.Transformer. + Generate a square mask for the sequence. The masked positions are filled with float('-inf'). + Unmasked positions are filled with float(0.0). 
+ + Args: + sz (int): Size of mask to generate + + Returns: + torch.Tensor: Square autoregressive mask for decode + """ + + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = mask.float().masked_fill(mask == 0, float('-inf' + )).masked_fill(mask == 1, float(0.0)) + return mask + + def _init_params(self, method): + """ + Apply initialisation of learnable weights + """ + + for p in self.parameters(): + if p.dim() > 1: + method(p) diff --git a/MoleculeSTM/models/mega_molbart/tokenizer.py b/MoleculeSTM/models/mega_molbart/tokenizer.py new file mode 100644 index 0000000..14c8db3 --- /dev/null +++ b/MoleculeSTM/models/mega_molbart/tokenizer.py @@ -0,0 +1,412 @@ +# coding=utf-8 + +import re +import torch +import random +from pathlib import Path +from .util import (DEFAULT_BEGIN_TOKEN, DEFAULT_END_TOKEN, DEFAULT_PAD_TOKEN, \ + DEFAULT_UNK_TOKEN, DEFAULT_MASK_TOKEN, DEFAULT_SEP_TOKEN, \ + DEFAULT_MASK_PROB, DEFAULT_SHOW_MASK_TOKEN_PROB, DEFAULT_MASK_SCHEME, \ + DEFAULT_SPAN_LAMBDA, DEFAULT_VOCAB_PATH, DEFAULT_CHEM_TOKEN_START, REGEX) + + +class MolEncTokenizer: + def __init__( + self, + vocab, + chem_token_idxs, + prog, + begin_token=DEFAULT_BEGIN_TOKEN, + end_token=DEFAULT_END_TOKEN, + pad_token=DEFAULT_PAD_TOKEN, + unk_token=DEFAULT_UNK_TOKEN, + mask_token=DEFAULT_MASK_TOKEN, + sep_token=DEFAULT_SEP_TOKEN, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Initialise the tokenizer + + Args: + vocab (List[str]): Vocabulary for tokenizer + chem_token_idxs (List[int]): List of idxs of chemical tokens + prog (re.Pattern): Regex object for tokenizing + begin_token (str): Token to use at start of each sequence + end_token (str): Token to use at end of each sequence + pad_token (str): Token to use when padding batches of sequences + unk_token (str): Token to use for tokens which are not in the vocabulary + mask_token (str): Token to use when masking pieces of the sequence + sep_token (str): Token to use when sepatating two sentences + mask_prob (float): Probability of token being masked when masking is enabled + show_mask_token_prob (float): Probability of a masked token being replaced with mask token + mask_scheme (str): Masking scheme used by the tokenizer when masking + span_lambda (float): Mean for poisson distribution when sampling a span of tokens + """ + + self.vocab = {t: i for i, t in enumerate(vocab)} + self.decode_vocab = {i: t for t, i in self.vocab.items()} + self.chem_token_idxs = chem_token_idxs + self.prog = prog + + self.begin_token = begin_token + self.end_token = end_token + self.pad_token = pad_token + self.unk_token = unk_token + self.mask_token = mask_token + self.sep_token = sep_token + + self.mask_prob = mask_prob + self.show_mask_token_prob = show_mask_token_prob + self.mask_scheme = mask_scheme + self.span_lambda = span_lambda + + self.unk_id = self.vocab[unk_token] + self.unk_token_cnt = {} + + @staticmethod + def from_vocab_file( + vocab_path, + regex, + chem_tokens_start_idx, + pad_token_idx=0, + unk_token_idx=1, + begin_token_idx=2, + end_token_idx=3, + mask_token_idx=4, + sep_token_idx=5, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Load the tokenizer object from a vocab file and regex + + Reads a newline separated list of tokens from a file to use as the vocabulary + Note: Assumes that the chemical tokens run from 
chem_tokens_start_idx to the end of the tokens list + Anything after the defined tokens and before chem_tokens_start_idx is assumed to be an extra token + and is added to the regex for tokenizing + + Args: + vocab_path (str): Path to vocab file + regex (str): Regex to use for tokenizing + chem_tokens_start_idx (int): Index of the start of the chemical tokens in the tokens list + + Returns: + MolEncTokenizer object + """ + + text = Path(vocab_path).read_text() + tokens = text.split("\n") + tokens = [t for t in tokens if t is not None and t != ""] + + token_idxs = [pad_token_idx, unk_token_idx, begin_token_idx, end_token_idx, mask_token_idx, sep_token_idx] + extra_tokens_idxs = range(max(token_idxs) + 1, chem_tokens_start_idx) + extra_tokens = [tokens[idx] for idx in extra_tokens_idxs] + prog = MolEncTokenizer._get_compiled_regex(regex, extra_tokens) + + pad_token = tokens[pad_token_idx] + unk_token = tokens[unk_token_idx] + begin_token = tokens[begin_token_idx] + end_token = tokens[end_token_idx] + mask_token = tokens[mask_token_idx] + sep_token = tokens[sep_token_idx] + + chem_tokens_idxs = list(range(chem_tokens_start_idx, len(tokens))) + tokenizer = MolEncTokenizer( + tokens, + chem_tokens_idxs, + prog, + begin_token=begin_token, + end_token=end_token, + pad_token=pad_token, + unk_token=unk_token, + mask_token=mask_token, + sep_token=sep_token, + mask_prob=mask_prob, + show_mask_token_prob=show_mask_token_prob, + mask_scheme=mask_scheme, + span_lambda=span_lambda + ) + return tokenizer + + @staticmethod + def from_smiles( + smiles, + regex, + extra_tokens=None, + begin_token=DEFAULT_BEGIN_TOKEN, + end_token=DEFAULT_END_TOKEN, + pad_token=DEFAULT_PAD_TOKEN, + unk_token=DEFAULT_UNK_TOKEN, + mask_token=DEFAULT_MASK_TOKEN, + sep_token=DEFAULT_SEP_TOKEN, + mask_prob=DEFAULT_MASK_PROB, + show_mask_token_prob=DEFAULT_SHOW_MASK_TOKEN_PROB, + mask_scheme=DEFAULT_MASK_SCHEME, + span_lambda=DEFAULT_SPAN_LAMBDA + ): + """ Build the tokenizer from smiles strings and a regex + + Args: + smiles (List[str]): SMILES strings to use to build vocabulary + regex (str): Regex to use for tokenizing + extra_tokens (Optional[List[str]]): Additional tokens to add to the vocabulary that + may not appear in the SMILES strings + """ + + vocab = { + pad_token: 0, + unk_token: 1, + begin_token: 2, + end_token: 3, + mask_token: 4, + sep_token: 5 + } + + extra_tokens = [] if extra_tokens is None else extra_tokens + [vocab.setdefault(token, len(vocab)) for token in extra_tokens] + + chem_start_idx = len(vocab) + prog = MolEncTokenizer._get_compiled_regex(regex, extra_tokens) + print(f"Chemistry tokens start at index {chem_start_idx}") + + for smi in smiles: + for token in prog.findall(smi): + vocab.setdefault(token, len(vocab)) + + chem_token_idxs = list(range(chem_start_idx, len(vocab))) + + vocab = sorted(vocab.items(), key=lambda k_v: k_v[1]) + vocab = [key for key, val in vocab] + + tokenizer = MolEncTokenizer( + vocab, + chem_token_idxs, + prog, + begin_token=begin_token, + end_token=end_token, + pad_token=pad_token, + unk_token=unk_token, + mask_token=mask_token, + sep_token=sep_token, + mask_prob=mask_prob, + show_mask_token_prob=show_mask_token_prob, + mask_scheme=mask_scheme, + span_lambda=span_lambda + ) + return tokenizer + + def save_vocab(self, vocab_path): + tokens = sorted(self.vocab.items(), key=lambda k_v: k_v[1]) + tokens = [key for key, val in tokens] + + tokens_str = "" + for token in tokens: + tokens_str += f"{token}\n" + + p = Path(vocab_path) + p.write_text(tokens_str) + + def __len__(self): + 
return len(self.vocab) + + def tokenize(self, sents1, sents2=None, mask=False, pad=False): + if sents2 is not None and len(sents1) != len(sents2): + raise ValueError("Sentence 1 batch and sentence 2 batch must have the same number of elements") + + tokens = self._regex_match(sents1) + m_tokens, token_masks = self._mask_tokens(tokens, empty_mask=not mask) + + sent_masks = None + if sents2 is not None: + sents2_tokens = self._regex_match(sents2) + sents2_m_tokens, sents2_masks = self._mask_tokens(sents2_tokens, empty_mask=not mask) + tokens, sent_masks = self._concat_sentences(tokens, sents2_tokens, self.sep_token) + m_tokens, _ = self._concat_sentences(m_tokens, sents2_m_tokens, self.sep_token) + token_masks, _ = self._concat_sentences(token_masks, sents2_masks, False) + + tokens = [[self.begin_token] + ts + [self.end_token] for ts in tokens] + m_tokens = [[self.begin_token] + ts + [self.end_token] for ts in m_tokens] + token_masks = [[False] + ts + [False] for ts in token_masks] + sent_masks = [[0] + mask + [1] for mask in sent_masks] if sent_masks is not None else None + + output = {} + + if pad: + tokens, orig_pad_masks = self._pad_seqs(tokens, self.pad_token) + m_tokens, masked_pad_masks = self._pad_seqs(m_tokens, self.pad_token) + token_masks, _ = self._pad_seqs(token_masks, False) + sent_masks, _ = self._pad_seqs(sent_masks, False) if sent_masks is not None else (None, None) + output["original_pad_masks"] = orig_pad_masks + output["masked_pad_masks"] = masked_pad_masks + + output["original_tokens"] = tokens + + if mask: + output["masked_tokens"] = m_tokens + output["token_masks"] = token_masks + + if sent_masks is not None: + output["sentence_masks"] = sent_masks + + return output + + def _regex_match(self, smiles): + tokenized = [] + for smi in smiles: + tokens = self.prog.findall(smi) + tokenized.append(tokens) + + return tokenized + + @staticmethod + def _get_compiled_regex(regex, extra_tokens): + regex_string = r"(" + for token in extra_tokens: + processed_token = token + for special_character in "()[].|": + processed_token = processed_token.replace(special_character, f"\\{special_character}") + regex_string += processed_token + r"|" + + regex_string += regex + r"|" + regex_string += r".)" + return re.compile(regex_string) + + def _concat_sentences(self, tokens1, tokens2, sep): + tokens = [ts1 + [sep] + ts2 for ts1, ts2 in zip(tokens1, tokens2)] + sent_masks = [([0] * len(ts1)) + [0] + ([1] * len(ts2)) for ts1, ts2 in zip(tokens1, tokens2)] + return tokens, sent_masks + + def detokenize(self, tokens_list): + new_tokens_list = [] + for tokens in tokens_list: + if tokens[0] == self.begin_token: + tokens = tokens[1:] + + # Remove any tokens after the end token (and end token) if it's there + if self.end_token in tokens: + end_token_idx = tokens.index(self.end_token) + tokens = tokens[:end_token_idx] + + new_tokens_list.append(tokens) + + strs = ["".join(tokens) for tokens in new_tokens_list] + return strs + + def convert_tokens_to_ids(self, token_data): + ids_list = [] + for tokens in token_data: + for token in tokens: + token_id = self.vocab.get(token) + if token_id is None: + self._inc_in_dict(self.unk_token_cnt, token) + + ids = [self.vocab.get(token, self.unk_id) for token in tokens] + ids_list.append(ids) + + return ids_list + + def convert_ids_to_tokens(self, token_ids): + tokens_list = [] + for ids in token_ids: + for token_id in ids: + token = self.decode_vocab.get(token_id) + if token is None: + raise ValueError(f"Token id {token_id} is not recognised") + + tokens = 
[self.decode_vocab.get(token_id) for token_id in ids] + tokens_list.append(tokens) + + return tokens_list + + def print_unknown_tokens(self): + print(f"{'Token':<10}Count") + for token, cnt in self.unk_token_cnt.items(): + print(f"{token:<10}{cnt}") + + print() + + @staticmethod + def _inc_in_dict(coll, item): + cnt = coll.get(item, 0) + cnt += 1 + coll[item] = cnt + + def _mask_tokens(self, tokens, empty_mask=False): + if empty_mask: + mask = [[False] * len(ts) for ts in tokens] + return tokens, mask + + masked_tokens = [] + token_masks = [] + + for ts in tokens: + if self.mask_scheme == "replace": + masked, token_mask = self._mask_replace(ts) + elif self.mask_scheme == "span": + masked, token_mask = self._mask_span(ts) + else: + raise ValueError(f"Unrecognised mask scheme: {self.mask_scheme}") + + masked_tokens.append(masked) + token_masks.append(token_mask) + + return masked_tokens, token_masks + + def _mask_replace(self, ts): + mask_bools = [True, False] + weights = [self.mask_prob, 1 - self.mask_prob] + token_mask = random.choices(mask_bools, weights=weights, k=len(ts)) + masked = [self._mask_token(ts[i]) if m else ts[i] for i, m in enumerate(token_mask)] + return masked, token_mask + + def _mask_span(self, ts): + curr_token = 0 + masked = [] + token_mask = [] + + mask_bools = [True, False] + weights = [self.mask_prob, 1 - self.mask_prob] + sampled_mask = random.choices(mask_bools, weights=weights, k=len(ts)) + + while curr_token < len(ts): + # If mask, sample from a poisson dist to get length of mask + if sampled_mask[curr_token]: + mask_len = torch.poisson(torch.tensor(self.span_lambda)).long().item() + masked.append(self.mask_token) + token_mask.append(True) + curr_token += mask_len + + # Otherwise don't mask + else: + masked.append(ts[curr_token]) + token_mask.append(False) + curr_token += 1 + + return masked, token_mask + + def _mask_token(self, token): + rand = random.random() + if rand < self.show_mask_token_prob: + return self.mask_token + + elif rand < self.show_mask_token_prob + ((1 - self.show_mask_token_prob) / 2): + token_idx = random.choice(self.chem_token_idxs) + return self.decode_vocab[token_idx] + + else: + return token + + @staticmethod + def _pad_seqs(seqs, pad_token): + pad_length = max([len(seq) for seq in seqs]) + padded = [seq + ([pad_token] * (pad_length - len(seq))) for seq in seqs] + masks = [([0] * len(seq)) + ([1] * (pad_length - len(seq))) for seq in seqs] + return padded, masks + + +def load_tokenizer(vocab_path=DEFAULT_VOCAB_PATH, chem_token_start=DEFAULT_CHEM_TOKEN_START, regex=REGEX): + tokenizer = MolEncTokenizer.from_vocab_file(vocab_path, regex, chem_token_start) + return tokenizer \ No newline at end of file diff --git a/MoleculeSTM/models/mega_molbart/util.py b/MoleculeSTM/models/mega_molbart/util.py new file mode 100644 index 0000000..37807ad --- /dev/null +++ b/MoleculeSTM/models/mega_molbart/util.py @@ -0,0 +1,21 @@ +DEFAULT_VOCAB_PATH = "bart_vocab.txt" + +# Tokenization and vocabulary +DEFAULT_MAX_SEQ_LEN = 512 +DEFAULT_CHEM_TOKEN_START = 272 +DEFAULT_BEGIN_TOKEN = "^" +DEFAULT_END_TOKEN = "&" +DEFAULT_PAD_TOKEN = "" +DEFAULT_UNK_TOKEN = "?" 
+DEFAULT_MASK_TOKEN = "" +DEFAULT_SEP_TOKEN = "" +DEFAULT_MASK_PROB = 0.15 +DEFAULT_SHOW_MASK_TOKEN_PROB = 1.0 +DEFAULT_MASK_SCHEME = "span" +DEFAULT_SPAN_LAMBDA = 3.0 +REGEX = "\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9]" + +# Model parameters +DEFAULT_D_MODEL = 256 +DEFAULT_NUM_LAYERS = 4 +DEFAULT_NUM_HEADS = 8 \ No newline at end of file diff --git a/MoleculeSTM/models/molecule_gnn_model.py b/MoleculeSTM/models/molecule_gnn_model.py new file mode 100644 index 0000000..eb2fdd9 --- /dev/null +++ b/MoleculeSTM/models/molecule_gnn_model.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import (MessagePassing, global_add_pool, + global_max_pool, global_mean_pool) +from torch_geometric.nn.inits import glorot, zeros +from torch_geometric.utils import add_self_loops, softmax, degree +from torch_scatter import scatter_add +from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder +from collections import OrderedDict + + +class GINConv(MessagePassing): + def __init__(self, emb_dim, aggr="add"): + ''' + emb_dim (int): node embedding dimensionality + ''' + super(GINConv, self).__init__(aggr=aggr) + + self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2*emb_dim), torch.nn.BatchNorm1d(2*emb_dim), torch.nn.ReLU(), torch.nn.Linear(2*emb_dim, emb_dim)) + self.eps = torch.nn.Parameter(torch.Tensor([0])) + + self.bond_encoder = BondEncoder(emb_dim = emb_dim) + + def forward(self, x, edge_index, edge_attr): + edge_embedding = self.bond_encoder(edge_attr) + out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) + return out + + def message(self, x_j, edge_attr): + return F.relu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + + +class GCNConv(MessagePassing): + def __init__(self, emb_dim, aggr="add"): + super(GCNConv, self).__init__(aggr=aggr) + + self.linear = torch.nn.Linear(emb_dim, emb_dim) + self.root_emb = torch.nn.Embedding(1, emb_dim) + self.bond_encoder = BondEncoder(emb_dim = emb_dim) + + def forward(self, x, edge_index, edge_attr): + x = self.linear(x) + edge_embedding = self.bond_encoder(edge_attr) + + row, col = edge_index + + #edge_weight = torch.ones((edge_index.size(1), ), device=edge_index.device) + deg = degree(row, x.size(0), dtype = x.dtype) + 1 + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + return self.propagate(edge_index, x=x, edge_attr = edge_embedding, norm=norm) + F.relu(x + self.root_emb.weight) * 1./deg.view(-1,1) + + def message(self, x_j, edge_attr, norm): + return norm.view(-1, 1) * F.relu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + + +class GNN(nn.Module): + def __init__(self, num_layer, emb_dim, JK="last", drop_ratio=0., gnn_type="gin"): + + if num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + super(GNN, self).__init__() + self.drop_ratio = drop_ratio + self.num_layer = num_layer + self.JK = JK + + self.atom_encoder = AtomEncoder(emb_dim) + + ###List of MLPs + self.gnns = nn.ModuleList() + for layer in range(num_layer): + if gnn_type == "gin": + self.gnns.append(GINConv(emb_dim, aggr="add")) + elif gnn_type == "gcn": + self.gnns.append(GCNConv(emb_dim)) + + ###List of batchnorms + self.batch_norms = nn.ModuleList() + for layer in range(num_layer): + self.batch_norms.append(nn.BatchNorm1d(emb_dim)) + + # def forward(self, x, edge_index, 
edge_attr): + def forward(self, *argv): + if len(argv) == 3: + x, edge_index, edge_attr = argv[0], argv[1], argv[2] + elif len(argv) == 1: + data = argv[0] + x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr + else: + raise ValueError("unmatched number of arguments.") + + x = self.atom_encoder(x) + + h_list = [x] + for layer in range(self.num_layer): + h = self.gnns[layer](h_list[layer], edge_index, edge_attr) + h = self.batch_norms[layer](h) + # h = F.dropout(F.relu(h), self.drop_ratio, training = self.training) + if layer == self.num_layer - 1: + # remove relu for the last layer + h = F.dropout(h, self.drop_ratio, training=self.training) + else: + h = F.dropout(F.relu(h), self.drop_ratio, training=self.training) + h_list.append(h) + + ### Different implementations of Jk-concat + if self.JK == "concat": + node_representation = torch.cat(h_list, dim=1) + elif self.JK == "last": + node_representation = h_list[-1] + elif self.JK == "max": + h_list = [h.unsqueeze_(0) for h in h_list] + node_representation = torch.max(torch.cat(h_list, dim=0), dim=0)[0] + elif self.JK == "sum": + h_list = [h.unsqueeze_(0) for h in h_list] + node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0] + else: + raise ValueError("not implemented.") + return node_representation + + +class GNN_graphpred(nn.Module): + """ + Extension of GIN to incorporate edge information by concatenation. + + Args: + num_layer (int): the number of GNN layers + arg.emb_dim (int): dimensionality of embeddings + num_tasks (int): number of tasks in multi-task learning scenario + JK (str): last, concat, max or sum. + graph_pooling (str): sum, mean, max, attention, set2set + + See https://arxiv.org/abs/1810.00826 + JK-net: https://arxiv.org/abs/1806.03536 """ + + def __init__(self, num_layer, emb_dim, num_tasks, JK, graph_pooling, molecule_node_model=None): + super(GNN_graphpred, self).__init__() + + if num_layer < 2: + raise ValueError("# layers must > 1.") + + self.molecule_node_model = molecule_node_model + self.num_layer = num_layer + self.emb_dim = emb_dim + self.num_tasks = num_tasks + self.JK = JK + + # Different kind of graph pooling + if graph_pooling == "sum": + self.pool = global_add_pool + elif graph_pooling == "mean": + self.pool = global_mean_pool + elif graph_pooling == "max": + self.pool = global_max_pool + else: + raise ValueError("Invalid graph pooling type.") + + # For graph-level binary classification + self.mult = 1 + + if self.JK == "concat": + self.graph_pred_linear = nn.Linear(self.mult * (self.num_layer + 1) * self.emb_dim, + self.num_tasks) + else: + self.graph_pred_linear = nn.Linear(self.mult * self.emb_dim, self.num_tasks) + return + + def from_pretrained(self, model_file): + print("Loading from {} ...".format(model_file)) + state_dict = torch.load(model_file) + self.molecule_node_model.load_state_dict(state_dict) + return + + def forward(self, *argv): + if len(argv) == 4: + x, edge_index, edge_attr, batch = argv[0], argv[1], argv[2], argv[3] + elif len(argv) == 1: + data = argv[0] + x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch + else: + raise ValueError("unmatched number of arguments.") + + node_representation = self.molecule_node_model(x, edge_index, edge_attr) + graph_representation = self.pool(node_representation, batch) + output = self.graph_pred_linear(graph_representation) + return graph_representation, output \ No newline at end of file diff --git a/MoleculeSTM/splitters.py b/MoleculeSTM/splitters.py new file mode 100644 index 
0000000..0f9ac8d --- /dev/null +++ b/MoleculeSTM/splitters.py @@ -0,0 +1,93 @@ +import random +from collections import defaultdict +from itertools import compress + +import numpy as np +import torch +from rdkit.Chem.Scaffolds import MurckoScaffold +from sklearn.model_selection import StratifiedKFold + +from torch.utils.data import Subset + + +def generate_scaffold(smiles, include_chirality=False): + """ Obtain Bemis-Murcko scaffold from smiles + :return: smiles of scaffold """ + scaffold = MurckoScaffold.MurckoScaffoldSmiles( + smiles=smiles, includeChirality=include_chirality) + return scaffold + + +def scaffold_split(dataset, smiles_list, task_idx=None, null_value=0, + frac_train=0.8, frac_valid=0.1, frac_test=0.1, + pyg_dataset=True): + """ + Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py + Split dataset by Bemis-Murcko scaffolds + This function can also ignore examples containing null values for a + selected task when splitting. Deterministic split + :param dataset: pytorch geometric dataset obj + :param smiles_list: list of smiles corresponding to the dataset obj + :param task_idx: column idx of the data.y tensor. Will filter out + examples with null value in specified task column of the data.y tensor + prior to splitting. If None, then no filtering + :param null_value: float that specifies null value in data.y to filter if + task_idx is provided + :param frac_train, frac_valid, frac_test: fractions + :param pyg_dataset: if this is pytorch or pytorch-gemetric dataset + :return: train, valid, test slices of the input dataset obj. """ + np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) + + if task_idx is not None: + # filter based on null values in task_idx + # get task array + y_task = np.array([data.y[task_idx].item() for data in dataset]) + # boolean array that correspond to non null values + non_null = y_task != null_value + smiles_list = list(compress(enumerate(smiles_list), non_null)) + else: + non_null = np.ones(len(dataset)) == 1 + smiles_list = list(compress(enumerate(smiles_list), non_null)) + + # create dict of the form {scaffold_i: [idx1, idx....]} + all_scaffolds = {} + for i, smiles in smiles_list: + scaffold = generate_scaffold(smiles, include_chirality=True) + if scaffold not in all_scaffolds: + all_scaffolds[scaffold] = [i] + else: + all_scaffolds[scaffold].append(i) + + # sort from largest to smallest sets + all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()} + all_scaffold_sets = [ + scaffold_set for (scaffold, scaffold_set) in sorted( + all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True) + ] + + # get train, valid test indices + train_cutoff = frac_train * len(smiles_list) + valid_cutoff = (frac_train + frac_valid) * len(smiles_list) + train_idx, valid_idx, test_idx = [], [], [] + for scaffold_set in all_scaffold_sets: + if len(train_idx) + len(scaffold_set) > train_cutoff: + if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff: + test_idx.extend(scaffold_set) + else: + valid_idx.extend(scaffold_set) + else: + train_idx.extend(scaffold_set) + + assert len(set(train_idx).intersection(set(valid_idx))) == 0 + assert len(set(test_idx).intersection(set(valid_idx))) == 0 + + if pyg_dataset: + train_dataset = dataset[torch.tensor(train_idx)] + valid_dataset = dataset[torch.tensor(valid_idx)] + test_dataset = dataset[torch.tensor(test_idx)] + return train_dataset, valid_dataset, test_dataset + else: + train_dataset = Subset(dataset, train_idx) + 
valid_dataset = Subset(dataset, valid_idx) + test_dataset = Subset(dataset, test_idx) + return train_dataset, valid_dataset, test_dataset diff --git a/MoleculeSTM/utils.py b/MoleculeSTM/utils.py new file mode 100644 index 0000000..b0086a5 --- /dev/null +++ b/MoleculeSTM/utils.py @@ -0,0 +1,71 @@ +import numpy as np +import torch + + +# This is for BERT +def padarray(A, size, value=0): + t = size - len(A) + return np.pad(A, pad_width=(0, t), mode='constant', constant_values = value) + + +# This is for BERT +def preprocess_each_sentence(sentence, tokenizer, max_seq_len): + text_input = tokenizer( + sentence, truncation=True, max_length=max_seq_len, + padding='max_length', return_tensors='np') + + input_ids = text_input['input_ids'].squeeze() + attention_mask = text_input['attention_mask'].squeeze() + + sentence_tokens_ids = padarray(input_ids, max_seq_len) + sentence_masks = padarray(attention_mask, max_seq_len) + return [sentence_tokens_ids, sentence_masks] + + +# This is for BERT +def prepare_text_tokens(device, description, tokenizer, max_seq_len): + B = len(description) + tokens_outputs = [preprocess_each_sentence(description[idx], tokenizer, max_seq_len) for idx in range(B)] + tokens_ids = [o[0] for o in tokens_outputs] + masks = [o[1] for o in tokens_outputs] + tokens_ids = torch.Tensor(tokens_ids).long().to(device) + masks = torch.Tensor(masks).bool().to(device) + return tokens_ids, masks + + +def get_molecule_repr_MoleculeSTM(molecule_data, mol2latent=None, molecule_type="SMILES", MegaMolBART_wrapper=None, molecule_model=None): + if molecule_type == "SMILES": + embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding(molecule_data) # [pad, B, d], [pad, B] + molecule_repr = embedding[0, :, :] # [B, d] + else: + molecule_repr, _ = molecule_model(molecule_data) + + if mol2latent is not None: + molecule_repr = mol2latent(molecule_repr) + return molecule_repr + + +def freeze_network(model): + for param in model.parameters(): + param.requires_grad = False + return + + +def get_num_task_and_type(dataset): + if dataset in ["esol", "freesolv", "lipophilicity"]: + return 1, "regression" + elif dataset in ["hiv", "bace", "bbbp"]: + return 1, "classification" + elif dataset == "tox21": + return 12, "classification" + elif dataset == "pcba": + return 92, "classification" + elif dataset == "muv": + return 17, "classification" + elif dataset == "toxcast": + return 617, "classification" + elif dataset == "sider": + return 27, "classification" + elif dataset == "clintox": + return 2, "classification" + raise ValueError("Invalid dataset name.") diff --git a/README.md b/README.md index 7e7d64b..d040464 100644 --- a/README.md +++ b/README.md @@ -1 +1,288 @@ -# Working on the approval now. Please stay tuned! 
+# MoleculeSTM: Multi-modal Molecule Structure-text Model for Text-based Editing and Retrieval + +Authors: Shengchao Liu, Weili Nie, Chengpeng Wang, Jiarui Lu, Zhuoran Qiao, Ling Liu, Jian Tang\*, Chaowei Xiao\*, Anima Anandkumar\* + +\* Equal advising + +[[Project Page](https://chao1224.github.io/MoleculeSTM)] [[ArXiv](https://arxiv.org/abs/2212.10789)] +[[Datasets on Hugging Face](https://huggingface.co/datasets/chao1224/MoleculeSTM/tree/main)] [[Checkpoints on Hugging Face](https://huggingface.co/chao1224/MoleculeSTM/tree/main)] + + +## 1 Environment + +First install conda: +``` +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +bash Miniconda3-latest-Linux-x86_64.sh +``` + +Then create virtual environment and install packages: +``` +conda create -n MoleculeSTM python=3.7 +conda activate MoleculeSTM + +conda install -y -c rdkit rdkit=2020.09.1.0 +conda install -y -c conda-forge -c pytorch pytorch=1.9.1 +conda install -y -c pyg -c conda-forge pyg==2.0.3 + +pip install requests +pip install tqdm +pip install matplotlib +pip install spacy +pip install Levenshtein + +# for SciBert +conda install -y boto3 +pip install transformers + +# for MoleculeNet +pip install ogb==1.2.0 + +# install pysmilesutils +python -m pip install git+https://github.com/MolecularAI/pysmilesutils.git + +pip install deepspeed + +# install metagron +# pip install megatron-lm==1.1.5 +git clone https://github.com/MolecularAI/MolBART.git --branch megatron-molbart-with-zinc +cd MolBART/megatron_molbart/Megatron-LM-v1.1.5-3D_parallelism +pip install . +cd ../../.. + +# install apex +# wget https://github.com/NVIDIA/apex/archive/refs/tags/22.03.zip +# unzip 22.03.zip +git clone https://github.com/chao1224/apex.git +cd apex +pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ +cd .. +``` + +We also provide the docker in `Dockerfile`. + +## 2 Datasets and Preprocessing + +We provide the raw dataset (after preprocessing) at [this Hugging Face link](https://huggingface.co/datasets/chao1224/MoleculeSTM). Or you can use the following python script (see `data/download.py`): +``` +from huggingface_hub import HfApi, snapshot_download +api = HfApi() +snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir='.') +``` + +Then you can move all the downloaded datasets under `./data` folder. + +### 2.1 Pretraining Dataset: PubChemSTM + +Useful resources: +- For molecular structure information (SMILES, 2D molecular graph etc), we can download it from PubChem in SDF format [here](https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/CURRENT-Full/SDF/). +- For textual data, we may first refer to this [PubChem RDF tutorial](https://ftp.ncbi.nlm.nih.gov/pubchem/presentations/pubchem_rdf_tutorial.pdf). +- `The RDF data on the PubChem FTP site is arranged in such a way that you only need to download the type of information in which you are interested, thus allowing you to avoid downloading parts of PubChem data you will not use. For example, if you are just interested in computed chemical properties, you only need to download PubChemRDF data in the compound descriptor directory.` The link is [here](https://ftp.ncbi.nlm.nih.gov/pubchem/RDF/descriptor/compound/). +- Guidance on using `RDF` and `REST` API can be found [here](https://ftp.ncbi.nlm.nih.gov/pubchem/presentations/pubchem_rdf_details.pdf). 
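
The SDF archives referenced in the resources above can be sanity-checked with RDKit before running the preprocessing scripts. Below is a minimal sketch, not part of the repo's pipeline; the archive name only illustrates PubChem's SDF naming scheme and is an assumption:
```
import gzip
from rdkit import Chem

# Hypothetical archive name following PubChem's Compound SDF naming convention.
with gzip.open("Compound_000000001_000500000.sdf.gz", "rb") as f:
    for mol in Chem.ForwardSDMolSupplier(f):
        if mol is None:          # skip records RDKit cannot sanitize
            continue
        cid = mol.GetProp("PUBCHEM_COMPOUND_CID")
        smiles = Chem.MolToSmiles(mol)
        print(cid, smiles)
        break                    # inspect only the first valid record
```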
+ +As confirmed with the PubChem group, performing research on these data does not violate their license; however, PubChem does not possess the license for the textual data, which necessitates an extensive evaluation of the license for each structure-text pair in PubChemSTM. This task poses a substantial workload and has hindered the release of PubChemSTM. However, we have tried our best to upload the structure part of the PubChemSTM data to Hugging Face, and we also provide all the details to generate PubChemSTM as follows: +1. Go to the `preprocessing/PubChemSTM` folder. +2. `python step_01_description_extraction.py`. This step extracts and merges all the textual descriptions into a single JSON file. We ran this on May 30th, 2022. The APIs keep updating, so you may get slightly different versions if you run this script yourself. +3. `bash step_02.sh`. This will download all the SDF files, with SMILES, 2D graphs, and computed molecular properties. This may take hours. +4. `python step_03_filter_out_SDF.py`. This will filter all the molecules with textual descriptions and save them in the SDF file. This may take <2 hours. +5. `python step_04_merge_SDF.py`. This will gather all the molecules into a single SDF file. + +### 2.2 Downstream Datasets + +We have included them in [the Hugging Face link](https://huggingface.co/datasets/chao1224/MoleculeSTM). We briefly list the details below: + +- `DrugBank_data` for zero-shot structure-text retrieval +- `ZINC250K_data` for space alignment (step 1 in editing) +- `Editing_data` for zero-shot text-guided editing (step 2 in editing) + - `single_multi_property_SMILES.txt` for single-objective, multi-objective, binding-affinity-based, and drug relevance editing + - `neighbor2drug` for neighborhood searching for patent drug molecules +- `MoleculeNet_data` for molecular property prediction + +## 3 Pre-trained Checkpoints from Previous Works + +### 3.1 SciBERT +This can be done by simply calling the following for SciBERT: +``` +SciBERT_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) +SciBERT_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device) +``` + +### 3.2 MegaMolBART +Run `download_MegaMolBART.sh`, and the output structure looks like: +``` +. +├── bart_vocab.txt +├── checkpoints +│   ├── iter_0134000 +│   │   ├── mp_rank_00 +│   │   │   └── model_optim_rng.pt +│   │   ├── mp_rank_00_model_states.pt +│   │   ├── zero_pp_rank_0_mp_rank_00optim_states.pt +│   │   ├── zero_pp_rank_1_mp_rank_00optim_states.pt +│   │   ├── zero_pp_rank_2_mp_rank_00optim_states.pt +│   │   ├── zero_pp_rank_3_mp_rank_00optim_states.pt +│   │   ├── zero_pp_rank_4_mp_rank_00optim_states.pt +│   │   ├── zero_pp_rank_5_mp_rank_00optim_states.pt +│   │   ├── zero_pp_rank_6_mp_rank_00optim_states.pt +│   │   └── zero_pp_rank_7_mp_rank_00optim_states.pt +│   └── latest_checkpointed_iteration.txt +└── megamolbart_0.1.zip +``` + +### 3.3 GNN and GraphMVP +For GraphMVP, check this [repo](https://github.com/chao1224/GraphMVP), and the checkpoints on [Google Drive link](https://drive.google.com/drive/u/1/folders/1uPsBiQF3bfeCAXSDd4JfyXiTh-qxYfu6) (a loading sketch is given after Section 3.4). +``` +pretrained_GraphMVP/ +├── GraphMVP_C +│   └── model.pth +└── GraphMVP_G + └── model.pth +``` + +### 3.4 Baseline KV-PLM +For KV-PLM, check this [repo](https://github.com/thunlp/KV-PLM) and checkpoints on [Google Drive link](https://drive.google.com/drive/folders/1xig3-3JG63kR-Xqj1b9wkPEdxtfD_4IX).
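For reference, the GraphMVP checkpoints from Section 3.3 are consumed by the same 2D GNN backbone used throughout this repo. A minimal loading sketch, where the hyperparameters mirror the defaults of the demo notebooks and are assumptions rather than part of the checkpoint itself:

```
from MoleculeSTM.models import GNN, GNN_graphpred

# Sketch only: layer count, embedding dimension, JK mode, dropout, and pooling
# below are the demo-notebook defaults; adjust them to match your own setup.
molecule_node_model = GNN(
    num_layer=5, emb_dim=300, JK="last", drop_ratio=0.5, gnn_type="gin")
model = GNN_graphpred(
    num_layer=5, emb_dim=300, JK="last", graph_pooling="mean",
    num_tasks=1, molecule_node_model=molecule_node_model)

# GraphMVP checkpoints are loaded through from_pretrained, as in the
# property-prediction demo notebook.
model.from_pretrained("pretrained_GraphMVP/GraphMVP_G/model.pth")
```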
+ +### 3.5 Toy Checkpoints for MoleculeSTM +We provide two sets of demo checkpoints at [this Hugging Face link](https://huggingface.co/chao1224/MoleculeSTM). Or you can use the following Python script: +``` +from huggingface_hub import HfApi, snapshot_download +api = HfApi() +snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="model", cache_dir='.') +``` + +## 4 Scripts and Demos + +All the running scripts and demos can be found under the `scripts` folder and the `demos` folder, respectively. + +### 4.1 Pretraining + +MoleculeSTM-SMILES +``` +python pretrain.py \ + --verbose --batch_size=8 \ + --molecule_type=SMILES +``` + +MoleculeSTM-Graph +``` +python pretrain.py \ + --verbose --batch_size=8 \ + --molecule_type=Graph +``` + +### 4.2 Downstream: Zero-shot Structure-text Retrieval + +**For DrugBank-Description** + +MoleculeSTM-SMILES +``` +python downstream_01_retrieval_Description_Pharmacodynamics.py \ + --task=molecule_description_removed_PubChem \ + --molecule_type=SMILES \ + --input_model_dir=../data/demo/demo_checkpoints_SMILES +``` + +MoleculeSTM-Graph +``` +python downstream_01_retrieval_Description_Pharmacodynamics.py \ + --task=molecule_description_removed_PubChem \ + --molecule_type=Graph \ + --input_model_dir=../data/demo/demo_checkpoints_Graph +``` + +**For DrugBank-Pharmacodynamics** + +MoleculeSTM-SMILES +``` +python downstream_01_retrieval_Description_Pharmacodynamics.py \ + --task=molecule_pharmacodynamics_removed_PubChem \ + --molecule_type=SMILES \ + --input_model_dir=../data/demo/demo_checkpoints_SMILES +``` + +MoleculeSTM-Graph +``` +python downstream_01_retrieval_Description_Pharmacodynamics.py \ + --task=molecule_pharmacodynamics_removed_PubChem \ + --molecule_type=Graph \ + --input_model_dir=../data/demo/demo_checkpoints_Graph +``` + +**For DrugBank-ATC** + + +MoleculeSTM-SMILES +``` +python downstream_01_retrieval_ATC.py \ + --molecule_type=SMILES \ + --input_model_dir=../data/demo/demo_checkpoints_SMILES +``` + +MoleculeSTM-Graph +``` +python downstream_01_retrieval_ATC.py \ + --molecule_type=Graph \ + --input_model_dir=../data/demo/demo_checkpoints_Graph +``` + +### 4.3 Downstream: Zero-shot Text-based Molecule Editing + +The list of description IDs can be found in `MoleculeSTM/downstream_molecule_edit_utils.py`.
+ +MoleculeSTM-SMILES +``` +python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \ + --MoleculeSTM_molecule_type=SMILES \ + --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_SMILES + + +python downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py \ + --MoleculeSTM_molecule_type=SMILES \ + --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_SMILES \ + --language_edit_model_dir=../data/demo/demo_checkpoints_SMILES \ + --input_description_id=101 +``` + +MoleculeSTM-Graph +``` +python downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py \ + --MoleculeSTM_molecule_type=Graph \ + --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph + + +python downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py \ + --MoleculeSTM_molecule_type=Graph \ + --MoleculeSTM_model_dir=../data/demo/demo_checkpoints_Graph \ + --language_edit_model_dir=../data/demo/demo_checkpoints_Graph \ + --input_description_id=101 +``` + +### 4.4 Downstream: Molecular Property Prediction + +MoleculeSTM-SMILES +``` +python downstream_03_property_prediction.py \ + --dataset=bace --molecule_type=SMILES +``` + +MoleculeSTM-Graph +``` +python downstream_03_property_prediction.py \ + --dataset=bace --molecule_type=Graph +``` + +### 4.5 Demo +Please check the `demos` folder. This may require you to download the dataset and checkpoints first: +- raw dataset (after preprocessing) at [this Hugging Face link](https://huggingface.co/datasets/chao1224/MoleculeSTM). +- checkpoints at [this Hugging Face link](https://huggingface.co/chao1224/MoleculeSTM). + +## Cite Us +Feel free to cite this work if you find it useful! +``` +@article{liu2022moleculestm, + title={Multi-modal molecule structure-text model for text-based retrieval and editing}, + author={Liu, Shengchao and Nie, Weili and Wang, Chengpeng and Lu, Jiarui and Qiao, Zhuoran and Liu, Ling and Tang, Jian and Xiao, Chaowei and Anandkumar, Anima}, + journal={arXiv preprint arXiv:2212.10789}, + year={2022} +} +``` \ No newline at end of file diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000..8da8042 --- /dev/null +++ b/data/README.md @@ -0,0 +1,17 @@ +# Dataset Specifications for MoleculeSTM + +We provide the raw dataset (after preprocessing) at [this Hugging Face link](https://huggingface.co/datasets/chao1224/MoleculeSTM). Or you can download them by running `python download.py`. + +## 1. Pretraining Dataset: PubChemSTM + +For PubChemSTM, please note that we can only release the chemical structure information. If you need the textual data, please follow our preprocessing scripts. + +## 2. 
Downstream Datasets + +Please refer to the following for three downstream tasks: +- `DrugBank_data` for zero-shot structure-text retrieval +- `ZINC250K_data` for space alignment (step 1 in editing) +- `Editing_data` for zero-shot text-guided editing (step 2 in editing) + - `single_multi_property_SMILES.txt` for single-objective, multi-objective, binding-affinity-based, and drug relevance editing + - `neighbor2drug` for neighborhood searching for patent drug molecules +- `MoleculeNet_data` for molecular property prediction diff --git a/data/download.py b/data/download.py new file mode 100644 index 0000000..34ca755 --- /dev/null +++ b/data/download.py @@ -0,0 +1,5 @@ +from huggingface_hub import HfApi, snapshot_download + +api = HfApi() + +snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="dataset", local_dir='.') diff --git a/demos/README.md b/demos/README.md new file mode 100644 index 0000000..5a7e7f9 --- /dev/null +++ b/demos/README.md @@ -0,0 +1,60 @@ +Here we show demos on how to run MoleculeSTM pretraining and downstream tasks. + +## Checkpoints for Demo + +First, please check [this Hugging Face link](https://huggingface.co/chao1224/MoleculeSTM/tree/main/demo) for toy checkpoints. + +Or you can run the following (also in `download.py`): +``` +from huggingface_hub import HfApi, snapshot_download +api = HfApi() +snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="model", local_dir='.', allow_patterns="*demo*") +``` + +Then move the folders under the downloaded `demo` directory into this `demos` folder. The folder structure is the following: +``` +. +├── demo_checkpoints_Graph +│   ├── foundation2generation_model.pth +│   ├── generation2foundation_model.pth +│   ├── mol2latent_model_final.pth +│   ├── mol2latent_model.pth +│   ├── molecule_model_final.pth +│   ├── molecule_model.pth +│   ├── text2latent_model_final.pth +│   ├── text2latent_model.pth +│   ├── text_model_final.pth +│   └── text_model.pth +├── demo_checkpoints_MegaMolBART +│   ├── foundation2generation_model.pth +│   ├── generation2foundation_model.pth +│   ├── mol2latent_model_final.pth +│   ├── mol2latent_model.pth +│   ├── molecule_model_final.pth +│   ├── molecule_model.pth +│   ├── text2latent_model_final.pth +│   ├── text2latent_model.pth +│   ├── text_model_final.pth +│   └── text_model.pth +├── demo_downstream_property_prediction_Graph.ipynb +├── demo_downstream_property_prediction_SMILES.ipynb +├── demo_downstream_retrieval_Graph.ipynb +├── demo_downstream_retrieval_SMILES.ipynb +├── demo_downstream_zero_shot_molecule_edit.ipynb +├── demo_pretrain_Graph.ipynb +├── demo_pretrain_SMILES.ipynb +├── download.py +└── README.md +``` + +## Pretraining + +Please check `demo_pretrain_Graph.ipynb` and `demo_pretrain_SMILES.ipynb`. + +## Downstream + +We provide notebooks for three types of downstream tasks: +- For zero-shot structure-text retrieval: `demo_downstream_retrieval_SMILES.ipynb` and `demo_downstream_retrieval_Graph.ipynb` (a short sketch of how the retrieval pieces fit together follows this list). +- For zero-shot text-based molecule editing: `demo_downstream_zero_shot_molecule_edit.ipynb` + - Notice that at this step, we are only using the textual branch (SciBERT) and a pretrained molecule generative model (MegaMolBART). The MoleculeSTM chemical branch (MegaMolBART or GraphMVP) is only used at the module alignment phase, and we can change it in the `MoleculeSTM_model_dir` argument. +- For molecular property prediction: `demo_downstream_property_prediction_SMILES.ipynb` and `demo_downstream_property_prediction_Graph.ipynb`.
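As referenced above, here is a rough sketch of how the helpers in `MoleculeSTM/utils.py` are combined in the retrieval notebooks. Using SciBERT's pooled output and a plain cosine similarity for ranking are simplifying assumptions made here for brevity (the notebooks remain the authoritative reference); `text2latent` and `mol2latent` refer to the projection heads shipped with the demo checkpoints:

```
import torch.nn.functional as F
from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM

def retrieval_scores(description, SMILES_list, text_tokenizer, text_model,
                     text2latent, mol2latent, MegaMolBART_wrapper, device,
                     max_seq_len=512):
    # Tokenize the text query with the same helper used by the retrieval code.
    tokens_ids, masks = prepare_text_tokens(device, [description], text_tokenizer, max_seq_len)
    text_output = text_model(input_ids=tokens_ids, attention_mask=masks)
    text_repr = text2latent(text_output["pooler_output"])  # [1, d]; pooled output assumed

    # Encode candidate molecules with the SMILES branch and project into the joint space.
    molecule_repr = get_molecule_repr_MoleculeSTM(
        SMILES_list, mol2latent=mol2latent,
        molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper)  # [B, d]

    # Rank candidates by cosine similarity to the text query (illustrative choice).
    return F.cosine_similarity(text_repr, molecule_repr, dim=-1)
```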
diff --git a/demos/demo_downstream_property_prediction_Graph.ipynb b/demos/demo_downstream_property_prediction_Graph.ipynb new file mode 100644 index 0000000..40828a6 --- /dev/null +++ b/demos/demo_downstream_property_prediction_Graph.ipynb @@ -0,0 +1,605 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5692e778", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM Downstream: Property Prediction\n", + "\n", + "## Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7d13d71d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-08-30 12:22:37,424] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import os\n", + "import argparse\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, mean_absolute_error, mean_squared_error\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader as torch_DataLoader\n", + "from torch_geometric.loader import DataLoader as pyg_DataLoader\n", + "\n", + "from MoleculeSTM.datasets import MoleculeNetSMILESDataset, MoleculeNetGraphDataset\n", + "from MoleculeSTM.splitters import scaffold_split\n", + "from MoleculeSTM.utils import get_num_task_and_type, get_molecule_repr_MoleculeSTM\n", + "from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART\n", + "from MoleculeSTM.models import GNN, GNN_graphpred" + ] + }, + { + "cell_type": "markdown", + "id": "c5ae9b29", + "metadata": {}, + "source": [ + "## Setup Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e579793f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arguments\t Namespace(JK='last', batch_size=32, dataset='bace', dataspace_path='../data', device=0, dropout_ratio=0.5, epochs=5, eval_train=0, gnn_emb_dim=300, gnn_type='gin', graph_pooling='mean', input_model_path='demo_checkpoints_Graph/molecule_model.pth', lr=0.0001, lr_scale=1, molecule_type='Graph', num_layer=5, num_workers=1, output_model_dir=None, schedule='cycle', seed=42, split='scaffold', training_mode='fine_tuning', verbose=1, warm_up_steps=10, weight_decay=0)\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser()\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "parser.add_argument(\"--training_mode\", type=str, default=\"fine_tuning\", choices=[\"fine_tuning\", \"linear_probing\"])\n", + "parser.add_argument(\"--molecule_type\", type=str, default=\"Graph\", choices=[\"SMILES\", \"Graph\"])\n", + "\n", + "########## for dataset and split ##########\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--dataset\", type=str, default=\"bace\")\n", + "parser.add_argument(\"--split\", type=str, default=\"scaffold\")\n", + "\n", + "########## for optimization ##########\n", + "parser.add_argument(\"--batch_size\", type=int, default=32)\n", + "parser.add_argument(\"--lr\", type=float, default=1e-4)\n", + "parser.add_argument(\"--lr_scale\", type=float, default=1)\n", + "parser.add_argument(\"--num_workers\", type=int, default=1)\n", + "parser.add_argument(\"--epochs\", 
type=int, default=5)\n", + "parser.add_argument(\"--weight_decay\", type=float, default=0)\n", + "parser.add_argument(\"--schedule\", type=str, default=\"cycle\")\n", + "parser.add_argument(\"--warm_up_steps\", type=int, default=10)\n", + "\n", + "########## for 2D GNN ##########\n", + "parser.add_argument(\"--gnn_emb_dim\", type=int, default=300)\n", + "parser.add_argument(\"--num_layer\", type=int, default=5)\n", + "parser.add_argument('--JK', type=str, default='last')\n", + "parser.add_argument(\"--dropout_ratio\", type=float, default=0.5)\n", + "parser.add_argument(\"--gnn_type\", type=str, default=\"gin\")\n", + "parser.add_argument('--graph_pooling', type=str, default='mean')\n", + "\n", + "########## for saver ##########\n", + "parser.add_argument(\"--eval_train\", type=int, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=1)\n", + "\n", + "parser.add_argument(\"--input_model_path\", type=str, default=\"demo_checkpoints_Graph/molecule_model.pth\")\n", + "parser.add_argument(\"--output_model_dir\", type=str, default=None)\n", + "\n", + "args = parser.parse_args(\"\")\n", + "print(\"arguments\\t\", args)" + ] + }, + { + "cell_type": "markdown", + "id": "fda8e87f", + "metadata": {}, + "source": [ + "## Setup Seed" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e440e4b6", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "1bcb6df8", + "metadata": {}, + "source": [ + "## Setup Dataset and Dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "465ef86d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset: bace\n", + "Data: Data(x=[51577, 9], edge_index=[2, 111536], edge_attr=[111536, 3], id=[1513], y=[1513])\n" + ] + } + ], + "source": [ + "num_tasks, task_mode = get_num_task_and_type(args.dataset)\n", + "dataset_folder = os.path.join(args.dataspace_path, \"MoleculeNet_data\", args.dataset)\n", + "\n", + "dataset = MoleculeNetGraphDataset(dataset_folder, args.dataset)\n", + "dataloader_class = pyg_DataLoader\n", + "use_pyg_dataset = True\n", + "\n", + "smiles_list = pd.read_csv(\n", + " dataset_folder + \"/processed/smiles.csv\", header=None)[0].tolist()\n", + "train_dataset, valid_dataset, test_dataset = scaffold_split(\n", + " dataset, smiles_list, null_value=0, frac_train=0.8,\n", + " frac_valid=0.1, frac_test=0.1, pyg_dataset=use_pyg_dataset)\n", + "\n", + "\n", + "train_loader = dataloader_class(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)\n", + "val_loader = dataloader_class(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)\n", + "test_loader = dataloader_class(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)" + ] + }, + { + "cell_type": "markdown", + "id": "4e191498", + "metadata": {}, + "source": [ + "## Initialize and Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4b2363c6", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start from pretrained model (MoleculeSTM) in demo_checkpoints_Graph/molecule_model.pth.\n" + ] + } 
+ ], + "source": [ + "molecule_node_model = GNN(\n", + " num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,\n", + " JK=args.JK, drop_ratio=args.dropout_ratio,\n", + " gnn_type=args.gnn_type)\n", + "model = GNN_graphpred(\n", + " num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,\n", + " num_tasks=1, molecule_node_model=molecule_node_model) \n", + "molecule_dim = args.gnn_emb_dim\n", + "\n", + "if \"GraphMVP\" in args.input_model_path:\n", + " print(\"Start from pretrained model (GraphMVP) in {}.\".format(args.input_model_path))\n", + " model.from_pretrained(args.input_model_path)\n", + "else:\n", + " print(\"Start from pretrained model (MoleculeSTM) in {}.\".format(args.input_model_path))\n", + " state_dict = torch.load(args.input_model_path, map_location='cpu')\n", + " model.load_state_dict(state_dict)\n", + "\n", + "\n", + "model = model.to(device)\n", + "linear_model = nn.Linear(molecule_dim, num_tasks).to(device)\n", + "\n", + "# Rewrite the seed by MegaMolBART\n", + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)" + ] + }, + { + "cell_type": "markdown", + "id": "24ae1f5c", + "metadata": {}, + "source": [ + "## Setup Optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "05837d7b", + "metadata": {}, + "outputs": [], + "source": [ + "if args.training_mode == \"fine_tuning\":\n", + " model_param_group = [\n", + " {\"params\": model.parameters()},\n", + " {\"params\": linear_model.parameters(), 'lr': args.lr * args.lr_scale}\n", + " ]\n", + "else:\n", + " model_param_group = [\n", + " {\"params\": linear_model.parameters(), 'lr': args.lr * args.lr_scale}\n", + " ]\n", + "optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay)" + ] + }, + { + "cell_type": "markdown", + "id": "b26e511d", + "metadata": {}, + "source": [ + "## Define Support Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6d471e8d", + "metadata": {}, + "outputs": [], + "source": [ + "def train_classification(model, device, loader, optimizer):\n", + " if args.training_mode == \"fine_tuning\":\n", + " model.train()\n", + " else:\n", + " model.eval()\n", + " linear_model.train()\n", + " total_loss = 0\n", + "\n", + " if args.verbose:\n", + " L = tqdm(loader)\n", + " else:\n", + " L = loader\n", + " for step, batch in enumerate(L):\n", + " if args.molecule_type == \"MegaMolBART\":\n", + " SMILES_list, y = batch\n", + " SMILES_list = list(SMILES_list)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " SMILES_list, mol2latent=None,\n", + " molecule_type=\"MegaMolBART\", MegaMolBART_wrapper=MegaMolBART_wrapper)\n", + " pred = linear_model(molecule_repr)\n", + " pred = pred.float()\n", + " y = y.to(device).float()\n", + " else:\n", + " batch = batch.to(device)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " batch, mol2latent=None,\n", + " molecule_type=\"Graph\", molecule_model=model)\n", + " pred = linear_model(molecule_repr)\n", + " pred = pred.float()\n", + " y = batch.y.view(pred.shape).to(device).float()\n", + "\n", + " is_valid = y ** 2 > 0\n", + " loss_mat = criterion(pred, (y + 1) / 2)\n", + " loss_mat = torch.where(\n", + " is_valid, loss_mat,\n", + " torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype))\n", + "\n", + " optimizer.zero_grad()\n", + " loss = torch.sum(loss_mat) / torch.sum(is_valid)\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss 
+= loss.detach().item()\n", + "\n", + " return total_loss / len(loader)\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def eval_classification(model, device, loader):\n", + " model.eval()\n", + " linear_model.eval()\n", + " y_true, y_scores = [], []\n", + "\n", + " if args.verbose:\n", + " L = tqdm(loader)\n", + " else:\n", + " L = loader\n", + " for step, batch in enumerate(L):\n", + " if args.molecule_type == \"MegaMolBART\":\n", + " SMILES_list, y = batch\n", + " SMILES_list = list(SMILES_list)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " SMILES_list, mol2latent=None,\n", + " molecule_type=\"MegaMolBART\", MegaMolBART_wrapper=MegaMolBART_wrapper)\n", + " pred = linear_model(molecule_repr)\n", + " pred = pred.float()\n", + " y = y.to(device).float()\n", + " else:\n", + " batch = batch.to(device)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " batch, mol2latent=None,\n", + " molecule_type=\"Graph\", molecule_model=model)\n", + " pred = linear_model(molecule_repr)\n", + " pred = pred.float()\n", + " y = batch.y.view(pred.shape).to(device).float()\n", + "\n", + " y_true.append(y)\n", + " y_scores.append(pred)\n", + "\n", + " y_true = torch.cat(y_true, dim=0).cpu().numpy()\n", + " y_scores = torch.cat(y_scores, dim=0).cpu().numpy()\n", + "\n", + " roc_list = []\n", + " for i in range(y_true.shape[1]):\n", + " # AUC is only defined when there is at least one positive data.\n", + " if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:\n", + " is_valid = y_true[:, i] ** 2 > 0\n", + " roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))\n", + " else:\n", + " print(\"{} is invalid\".format(i))\n", + "\n", + " if len(roc_list) < y_true.shape[1]:\n", + " print(len(roc_list))\n", + " print(\"Some target is missing!\")\n", + " print(\"Missing ratio: %f\" %(1 - float(len(roc_list)) / y_true.shape[1]))\n", + "\n", + " return sum(roc_list) / len(roc_list), 0, y_true, y_scores" + ] + }, + { + "cell_type": "markdown", + "id": "fa37ee43", + "metadata": {}, + "source": [ + "## Start Training" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7408a546", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 19.51it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 1\n", + "Loss: 0.6760892538647902\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.19it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.642125\ttest: 0.663189\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.04it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 2\n", + "Loss: 0.6383239313175804\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.31it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.58it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.676190\ttest: 0.720049\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 22.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 3\n", + "Loss: 0.6019486816305863\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.11it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.19it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.683516\ttest: 0.752043\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.23it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 4\n", + "Loss: 0.5672228501031273\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.20it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:01<00:00, 5.00it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.686447\ttest: 0.774474\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.79it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 5\n", + "Loss: 0.5250759069856844\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7.85it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.60it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.689377\ttest: 0.788211\n", + "\n", + "best train: 0.000000\tval: 0.642125\ttest: 0.663189\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_func = train_classification\n", + "eval_func = eval_classification\n", + "\n", + "train_roc_list, val_roc_list, test_roc_list = [], [], []\n", + "train_acc_list, val_acc_list, test_acc_list = [], [], []\n", + "best_val_roc, best_val_idx = -1, 0\n", + "criterion = nn.BCEWithLogitsLoss(reduction=\"none\")\n", + "\n", + "for epoch in range(1, args.epochs + 1):\n", + " loss_acc = train_func(model, device, 
train_loader, optimizer)\n", + " print(\"Epoch: {}\\nLoss: {}\".format(epoch, loss_acc))\n", + "\n", + " if args.eval_train:\n", + " train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader)\n", + " else:\n", + " train_roc = train_acc = 0\n", + " val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader)\n", + " test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader)\n", + "\n", + " train_roc_list.append(train_roc)\n", + " train_acc_list.append(train_acc)\n", + " val_roc_list.append(val_roc)\n", + " val_acc_list.append(val_acc)\n", + " test_roc_list.append(test_roc)\n", + " test_acc_list.append(test_acc)\n", + " print(\"train: {:.6f}\\tval: {:.6f}\\ttest: {:.6f}\".format(train_roc, val_roc, test_roc))\n", + " print()\n", + "\n", + "print(\"best train: {:.6f}\\tval: {:.6f}\\ttest: {:.6f}\".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demos/demo_downstream_property_prediction_SMILES.ipynb b/demos/demo_downstream_property_prediction_SMILES.ipynb new file mode 100644 index 0000000..1c1597b --- /dev/null +++ b/demos/demo_downstream_property_prediction_SMILES.ipynb @@ -0,0 +1,726 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5692e778", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM Downstream: Property Prediction\n", + "\n", + "## Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7d13d71d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-08-30 12:23:17,997] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import os\n", + "import argparse\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, mean_absolute_error, mean_squared_error\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "from torch.utils.data import DataLoader as torch_DataLoader\n", + "from torch_geometric.loader import DataLoader as pyg_DataLoader\n", + "\n", + "from MoleculeSTM.datasets import MoleculeNetSMILESDataset, MoleculeNetGraphDataset\n", + "from MoleculeSTM.splitters import scaffold_split\n", + "from MoleculeSTM.utils import get_num_task_and_type, get_molecule_repr_MoleculeSTM\n", + "from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART\n", + "from MoleculeSTM.models import GNN, GNN_graphpred" + ] + }, + { + "cell_type": "markdown", + "id": "c5ae9b29", + "metadata": {}, + "source": [ + "## Setup Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e579793f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arguments\t Namespace(batch_size=32, dataset='bace', dataspace_path='../data', device=0, epochs=5, eval_train=0, 
input_model_path='demo_checkpoints_SMILES/molecule_model.pth', lr=0.0001, lr_scale=1, megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints', molecule_type='SMILES', num_workers=1, output_model_dir=None, schedule='cycle', seed=42, split='scaffold', training_mode='fine_tuning', verbose=1, vocab_path='../MoleculeSTM/bart_vocab.txt', warm_up_steps=10, weight_decay=0)\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser()\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "parser.add_argument(\"--training_mode\", type=str, default=\"fine_tuning\", choices=[\"fine_tuning\", \"linear_probing\"])\n", + "parser.add_argument(\"--molecule_type\", type=str, default=\"SMILES\", choices=[\"SMILES\", \"Graph\"])\n", + "\n", + "########## for dataset and split ##########\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--dataset\", type=str, default=\"bace\")\n", + "parser.add_argument(\"--split\", type=str, default=\"scaffold\")\n", + "\n", + "########## for optimization ##########\n", + "parser.add_argument(\"--batch_size\", type=int, default=32)\n", + "parser.add_argument(\"--lr\", type=float, default=1e-4)\n", + "parser.add_argument(\"--lr_scale\", type=float, default=1)\n", + "parser.add_argument(\"--num_workers\", type=int, default=1)\n", + "parser.add_argument(\"--epochs\", type=int, default=5)\n", + "parser.add_argument(\"--weight_decay\", type=float, default=0)\n", + "parser.add_argument(\"--schedule\", type=str, default=\"cycle\")\n", + "parser.add_argument(\"--warm_up_steps\", type=int, default=10)\n", + "\n", + "########## for MegaMolBART ##########\n", + "parser.add_argument(\"--megamolbart_input_dir\", type=str, default=\"../data/pretrained_MegaMolBART/checkpoints\", help=\"This is only for MegaMolBART.\")\n", + "parser.add_argument(\"--vocab_path\", type=str, default=\"../MoleculeSTM/bart_vocab.txt\")\n", + "\n", + "########## for saver ##########\n", + "parser.add_argument(\"--eval_train\", type=int, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=1)\n", + "\n", + "parser.add_argument(\"--input_model_path\", type=str, default=\"demo_checkpoints_SMILES/molecule_model.pth\")\n", + "parser.add_argument(\"--output_model_dir\", type=str, default=None)\n", + "\n", + "args = parser.parse_args(\"\")\n", + "print(\"arguments\\t\", args)" + ] + }, + { + "cell_type": "markdown", + "id": "fda8e87f", + "metadata": {}, + "source": [ + "## Setup Seed" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e440e4b6", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "1bcb6df8", + "metadata": {}, + "source": [ + "## Setup Dataset and Dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "465ef86d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1513 \t (1513, 1)\n" + ] + } + ], + "source": [ + "num_tasks, task_mode = get_num_task_and_type(args.dataset)\n", + "dataset_folder = os.path.join(args.dataspace_path, \"MoleculeNet_data\", args.dataset)\n", + "\n", + "\n", + "dataset = 
MoleculeNetSMILESDataset(dataset_folder)\n", + "dataloader_class = torch_DataLoader\n", + "use_pyg_dataset = False\n", + "\n", + "smiles_list = pd.read_csv(\n", + " dataset_folder + \"/processed/smiles.csv\", header=None)[0].tolist()\n", + "train_dataset, valid_dataset, test_dataset = scaffold_split(\n", + " dataset, smiles_list, null_value=0, frac_train=0.8,\n", + " frac_valid=0.1, frac_test=0.1, pyg_dataset=use_pyg_dataset)\n", + "\n", + "\n", + "train_loader = dataloader_class(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)\n", + "val_loader = dataloader_class(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)\n", + "test_loader = dataloader_class(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)" + ] + }, + { + "cell_type": "markdown", + "id": "4e191498", + "metadata": {}, + "source": [ + "## Initialize and Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4b2363c6", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "using world size: 1 and model-parallel size: 1 \n", + "using torch.float32 for parameters ...\n", + "-------------------- arguments --------------------\n", + " adam_beta1 ...................... 0.9\n", + " adam_beta2 ...................... 0.999\n", + " adam_eps ........................ 1e-08\n", + " adlr_autoresume ................. False\n", + " adlr_autoresume_interval ........ 1000\n", + " apply_query_key_layer_scaling ... False\n", + " apply_residual_connection_post_layernorm False\n", + " attention_dropout ............... 0.1\n", + " attention_softmax_in_fp32 ....... False\n", + " batch_size ...................... None\n", + " bert_load ....................... None\n", + " bias_dropout_fusion ............. False\n", + " bias_gelu_fusion ................ False\n", + " block_data_path ................. None\n", + " checkpoint_activations .......... False\n", + " checkpoint_in_cpu ............... False\n", + " checkpoint_num_layers ........... 1\n", + " clip_grad ....................... 1.0\n", + " contigious_checkpointing ........ False\n", + " cpu_optimizer ................... False\n", + " cpu_torch_adam .................. False\n", + " data_impl ....................... infer\n", + " data_path ....................... None\n", + " dataset_path .................... None\n", + " DDP_impl ........................ local\n", + " deepscale ....................... False\n", + " deepscale_config ................ None\n", + " deepspeed ....................... False\n", + " deepspeed_activation_checkpointing False\n", + " deepspeed_config ................ None\n", + " deepspeed_mpi ................... False\n", + " distribute_checkpointed_activations False\n", + " distributed_backend ............. nccl\n", + " dynamic_loss_scale .............. True\n", + " eod_mask_loss ................... False\n", + " eval_interval ................... 1000\n", + " eval_iters ...................... 100\n", + " exit_interval ................... None\n", + " faiss_use_gpu ................... False\n", + " finetune ........................ False\n", + " fp16 ............................ False\n", + " fp16_lm_cross_entropy ........... False\n", + " fp32_allreduce .................. False\n", + " gas ............................. 1\n", + " hidden_dropout .................. 0.1\n", + " hidden_size ..................... 256\n", + " hysteresis ...................... 
2\n", + " ict_head_size ................... None\n", + " ict_load ........................ None\n", + " indexer_batch_size .............. 128\n", + " indexer_log_interval ............ 1000\n", + " init_method_std ................. 0.02\n", + " layernorm_epsilon ............... 1e-05\n", + " lazy_mpu_init ................... None\n", + " load ............................ ../data/pretrained_MegaMolBART/checkpoints\n", + " local_rank ...................... None\n", + " log_interval .................... 100\n", + " loss_scale ...................... None\n", + " loss_scale_window ............... 1000\n", + " lr .............................. None\n", + " lr_decay_iters .................. None\n", + " lr_decay_style .................. linear\n", + " make_vocab_size_divisible_by .... 128\n", + " mask_prob ....................... 0.15\n", + " max_position_embeddings ......... 512\n", + " merge_file ...................... None\n", + " min_lr .......................... 0.0\n", + " min_scale ....................... 1\n", + " mmap_warmup ..................... False\n", + " model_parallel_size ............. 1\n", + " no_load_optim ................... False\n", + " no_load_rng ..................... False\n", + " no_save_optim ................... False\n", + " no_save_rng ..................... False\n", + " num_attention_heads ............. 8\n", + " num_layers ...................... 4\n", + " num_unique_layers ............... None\n", + " num_workers ..................... 2\n", + " onnx_safe ....................... None\n", + " openai_gelu ..................... False\n", + " override_lr_scheduler ........... False\n", + " param_sharing_style ............. grouped\n", + " params_dtype .................... torch.float32\n", + " partition_activations ........... False\n", + " pipe_parallel_size .............. 0\n", + " profile_backward ................ False\n", + " query_in_block_prob ............. 0.1\n", + " rank ............................ 0\n", + " report_topk_accuracies .......... []\n", + " reset_attention_mask ............ False\n", + " reset_position_ids .............. False\n", + " save ............................ None\n", + " save_interval ................... None\n", + " scaled_masked_softmax_fusion .... False\n", + " scaled_upper_triang_masked_softmax_fusion False\n", + " seed ............................ 1234\n", + " seq_length ...................... None\n", + " short_seq_prob .................. 0.1\n", + " split ........................... 969, 30, 1\n", + " synchronize_each_layer .......... False\n", + " tensorboard_dir ................. None\n", + " titles_data_path ................ None\n", + " tokenizer_type .................. GPT2BPETokenizer\n", + " train_iters ..................... None\n", + " use_checkpoint_lr_scheduler ..... False\n", + " use_cpu_initialization .......... False\n", + " use_one_sent_docs ............... False\n", + " vocab_file ...................... ../MoleculeSTM/bart_vocab.txt\n", + " warmup .......................... 0.01\n", + " weight_decay .................... 0.01\n", + " world_size ...................... 1\n", + " zero_allgather_bucket_size ...... 0.0\n", + " zero_contigious_gradients ....... False\n", + " zero_reduce_bucket_size ......... 0.0\n", + " zero_reduce_scatter ............. False\n", + " zero_stage ...................... 
1.0\n", + "---------------- end of arguments ----------------\n", + "> initializing torch distributed ...\n", + "> initializing model parallel with size 1\n", + "> setting random seeds to 1234 ...\n", + "> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234\n", + "Loading vocab from ../MoleculeSTM/bart_vocab.txt.\n", + "Loading from ../data/pretrained_MegaMolBART/checkpoints\n", + "global rank 0 is loading checkpoint ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n", + "could not find arguments in the checkpoint ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W ProcessGroupNCCL.cpp:1569] Rank 0 using best-guess GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " successfully loaded ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n", + "Start from pretrained MegaMolBART using MLM.\n", + "Update MegaMolBART with pretrained MoleculeSTM. Loading from demo_checkpoints_SMILES/molecule_model.pth...\n" + ] + } + ], + "source": [ + "if args.megamolbart_input_dir is not None:\n", + " # This is loading from the pretarined_MegaMolBART\n", + " # --megamolbart_input_dir=../../Datasets/pretrained_MegaMolBART/checkpoints\n", + " # TODO: or maybe --input_model_path=../../Datasets/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n", + " MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.megamolbart_input_dir, output_dir=None)\n", + " print(\"Start from pretrained MegaMolBART using MLM.\")\n", + "else:\n", + " # This is starting from scratch\n", + " MegaMolBART_wrapper = MegaMolBART(input_dir=None, output_dir=None)\n", + " print(\"Start from randomly initialized MegaMolBART.\")\n", + "\n", + "model = MegaMolBART_wrapper.model\n", + "print(\"Update MegaMolBART with pretrained MoleculeSTM. 
Loading from {}...\".format(args.input_model_path))\n", + "state_dict = torch.load(args.input_model_path, map_location='cpu')\n", + "model.load_state_dict(state_dict)\n", + "molecule_dim = 256\n", + "\n", + "\n", + "model = model.to(device)\n", + "linear_model = nn.Linear(molecule_dim, num_tasks).to(device)\n", + "\n", + "# Rewrite the seed by MegaMolBART\n", + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)" + ] + }, + { + "cell_type": "markdown", + "id": "24ae1f5c", + "metadata": {}, + "source": [ + "## Setup Optimizer" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "05837d7b", + "metadata": {}, + "outputs": [], + "source": [ + "if args.training_mode == \"fine_tuning\":\n", + " model_param_group = [\n", + " {\"params\": model.parameters()},\n", + " {\"params\": linear_model.parameters(), 'lr': args.lr * args.lr_scale}\n", + " ]\n", + "else:\n", + " model_param_group = [\n", + " {\"params\": linear_model.parameters(), 'lr': args.lr * args.lr_scale}\n", + " ]\n", + "optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay)" + ] + }, + { + "cell_type": "markdown", + "id": "b26e511d", + "metadata": {}, + "source": [ + "## Define Support Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6d471e8d", + "metadata": {}, + "outputs": [], + "source": [ + "def train_classification(model, device, loader, optimizer):\n", + " if args.training_mode == \"fine_tuning\":\n", + " model.train()\n", + " else:\n", + " model.eval()\n", + " linear_model.train()\n", + " total_loss = 0\n", + "\n", + " if args.verbose:\n", + " L = tqdm(loader)\n", + " else:\n", + " L = loader\n", + " for step, batch in enumerate(L):\n", + " SMILES_list, y = batch\n", + " SMILES_list = list(SMILES_list)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " SMILES_list, mol2latent=None,\n", + " molecule_type=\"SMILES\", MegaMolBART_wrapper=MegaMolBART_wrapper)\n", + " pred = linear_model(molecule_repr)\n", + " pred = pred.float()\n", + " y = y.to(device).float()\n", + " \n", + " is_valid = y ** 2 > 0\n", + " loss_mat = criterion(pred, (y + 1) / 2)\n", + " loss_mat = torch.where(\n", + " is_valid, loss_mat,\n", + " torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype))\n", + "\n", + " optimizer.zero_grad()\n", + " loss = torch.sum(loss_mat) / torch.sum(is_valid)\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.detach().item()\n", + "\n", + " return total_loss / len(loader)\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def eval_classification(model, device, loader):\n", + " model.eval()\n", + " linear_model.eval()\n", + " y_true, y_scores = [], []\n", + "\n", + " if args.verbose:\n", + " L = tqdm(loader)\n", + " else:\n", + " L = loader\n", + " for step, batch in enumerate(L):\n", + " SMILES_list, y = batch\n", + " SMILES_list = list(SMILES_list)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " SMILES_list, mol2latent=None,\n", + " molecule_type=\"SMILES\", MegaMolBART_wrapper=MegaMolBART_wrapper)\n", + " pred = linear_model(molecule_repr)\n", + " pred = pred.float()\n", + " y = y.to(device).float()\n", + " \n", + " y_true.append(y)\n", + " y_scores.append(pred)\n", + "\n", + " y_true = torch.cat(y_true, dim=0).cpu().numpy()\n", + " y_scores = torch.cat(y_scores, dim=0).cpu().numpy()\n", + "\n", + " roc_list = []\n", + " for i in range(y_true.shape[1]):\n", + " # AUC is only defined when there is at least one 
positive data.\n", + " if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0:\n", + " is_valid = y_true[:, i] ** 2 > 0\n", + " roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i]))\n", + " else:\n", + " print(\"{} is invalid\".format(i))\n", + "\n", + " if len(roc_list) < y_true.shape[1]:\n", + " print(len(roc_list))\n", + " print(\"Some target is missing!\")\n", + " print(\"Missing ratio: %f\" %(1 - float(len(roc_list)) / y_true.shape[1]))\n", + "\n", + " return sum(roc_list) / len(roc_list), 0, y_true, y_scores" + ] + }, + { + "cell_type": "markdown", + "id": "fa37ee43", + "metadata": {}, + "source": [ + "## Start Training" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "7408a546", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 20.98it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 1\n", + "Loss: 0.6168293129456671\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.69it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.62it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.716484\ttest: 0.721788\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 25.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 2\n", + "Loss: 0.4680136606881493\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.49it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.08it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.759707\ttest: 0.791167\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 3\n", + "Loss: 0.4001527561953193\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.11it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.06it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.763736\ttest: 0.785950\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.44it/s]\n" + ] + }, + { + 
"name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 4\n", + "Loss: 0.35615202117907374\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.30it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.01it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.759341\ttest: 0.796035\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:01<00:00, 23.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 5\n", + "Loss: 0.31917470811229004\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.10it/s]\n", + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5.08it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train: 0.000000\tval: 0.766300\ttest: 0.786646\n", + "\n", + "best train: 0.000000\tval: 0.716484\ttest: 0.721788\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_func = train_classification\n", + "eval_func = eval_classification\n", + "\n", + "train_roc_list, val_roc_list, test_roc_list = [], [], []\n", + "train_acc_list, val_acc_list, test_acc_list = [], [], []\n", + "best_val_roc, best_val_idx = -1, 0\n", + "criterion = nn.BCEWithLogitsLoss(reduction=\"none\")\n", + "\n", + "for epoch in range(1, args.epochs + 1):\n", + " loss_acc = train_func(model, device, train_loader, optimizer)\n", + " print(\"Epoch: {}\\nLoss: {}\".format(epoch, loss_acc))\n", + "\n", + " if args.eval_train:\n", + " train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader)\n", + " else:\n", + " train_roc = train_acc = 0\n", + " val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader)\n", + " test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader)\n", + "\n", + " train_roc_list.append(train_roc)\n", + " train_acc_list.append(train_acc)\n", + " val_roc_list.append(val_roc)\n", + " val_acc_list.append(val_acc)\n", + " test_roc_list.append(test_roc)\n", + " test_acc_list.append(test_acc)\n", + " print(\"train: {:.6f}\\tval: {:.6f}\\ttest: {:.6f}\".format(train_roc, val_roc, test_roc))\n", + " print()\n", + "\n", + "print(\"best train: {:.6f}\\tval: {:.6f}\\ttest: {:.6f}\".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx]))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demos/demo_downstream_retrieval_Graph.ipynb b/demos/demo_downstream_retrieval_Graph.ipynb new 
file mode 100644 index 0000000..8a613e3 --- /dev/null +++ b/demos/demo_downstream_retrieval_Graph.ipynb @@ -0,0 +1,508 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8a41a864", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM Downstream: Structure-Text Retrieval\n", + "\n", + "## Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "b9bc8496", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-08-30 12:24:25,704] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import os\n", + "import time\n", + "import argparse\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader as torch_DataLoader\n", + "from torch_geometric.loader import DataLoader as pyg_DataLoader\n", + "\n", + "from transformers import AutoModel, AutoTokenizer\n", + "from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval\n", + "from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART\n", + "from MoleculeSTM.models import GNN, GNN_graphpred\n", + "from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network\n", + "\n", + "# Set-up the environment variable to ignore warnings\n", + "os.environ['TOKENIZERS_PARALLELISM'] = 'False'" + ] + }, + { + "cell_type": "markdown", + "id": "7c3a7eee", + "metadata": {}, + "source": [ + "## Setup Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a2d76596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arguments\t Namespace(CL_neg_samples=1, JK='last', SSL_emb_dim=256, SSL_loss='EBM_NCE', T=0.1, T_list=[4, 10, 20], batch_size=32, dataset='PubChem', dataspace_path='../data', decay=0, device=0, dropout_ratio=0.5, epochs=1, eval_train=0, gnn_emb_dim=300, gnn_type='gin', graph_pooling='mean', input_model_dir='demo_checkpoints_Graph', input_model_path='demo_checkpoints_Graph/molecule_model.pth', load_latent_projector=1, max_seq_len=512, mol_lr=1e-05, mol_lr_scale=0.1, molecule_type='Graph', normalize=True, num_layer=5, num_workers=8, seed=42, task='molecule_description', test_mode='given_text', text_lr=1e-05, text_lr_scale=0.1, text_type='SciBERT', training_mode='zero_shot', verbose=0)\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser()\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "parser.add_argument(\"--SSL_emb_dim\", type=int, default=256)\n", + "parser.add_argument(\"--text_type\", type=str, default=\"SciBERT\", choices=[\"SciBERT\", \"BioBERT\"])\n", + "parser.add_argument(\"--load_latent_projector\", type=int, default=1)\n", + "parser.add_argument(\"--training_mode\", type=str, default=\"zero_shot\", choices=[\"zero_shot\"])\n", + "\n", + "########## for dataset and split ##########\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--dataset\", type=str, default=\"PubChem\")\n", + "parser.add_argument(\"--task\", type=str, default=\"molecule_description\",\n", + " 
choices=[\n", + " \"molecule_description\", \"molecule_description_Raw\",\n", + " \"molecule_description_removed_PubChem\", \"molecule_description_removed_PubChem_Raw\",\n", + " \"molecule_pharmacodynamics\", \"molecule_pharmacodynamics_Raw\",\n", + " \"molecule_pharmacodynamics_removed_PubChem\", \"molecule_pharmacodynamics_removed_PubChem_Raw\"])\n", + "parser.add_argument(\"--test_mode\", type=str, default=\"given_text\", choices=[\"given_text\", \"given_molecule\"])\n", + "\n", + "########## for optimization ##########\n", + "parser.add_argument(\"--T_list\", type=int, nargs=\"+\", default=[4, 10, 20])\n", + "parser.add_argument(\"--batch_size\", type=int, default=32)\n", + "parser.add_argument(\"--num_workers\", type=int, default=8)\n", + "parser.add_argument(\"--epochs\", type=int, default=1)\n", + "parser.add_argument(\"--text_lr\", type=float, default=1e-5)\n", + "parser.add_argument(\"--mol_lr\", type=float, default=1e-5)\n", + "parser.add_argument(\"--text_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--mol_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--decay\", type=float, default=0)\n", + "\n", + "########## for contrastive objective ##########\n", + "parser.add_argument(\"--SSL_loss\", type=str, default=\"EBM_NCE\", choices=[\"EBM_NCE\", \"InfoNCE\"])\n", + "parser.add_argument(\"--CL_neg_samples\", type=int, default=1)\n", + "parser.add_argument(\"--T\", type=float, default=0.1)\n", + "parser.add_argument('--normalize', dest='normalize', action='store_true')\n", + "parser.add_argument('--no_normalize', dest='normalize', action='store_false')\n", + "parser.set_defaults(normalize=True)\n", + "\n", + "########## for BERT model ##########\n", + "parser.add_argument(\"--max_seq_len\", type=int, default=512)\n", + "\n", + "########## for molecule model ##########\n", + "parser.add_argument(\"--molecule_type\", type=str, default=\"Graph\", choices=[\"SMILES\", \"Graph\"])\n", + "\n", + "########## for 2D GNN ##########\n", + "parser.add_argument(\"--gnn_emb_dim\", type=int, default=300)\n", + "parser.add_argument(\"--num_layer\", type=int, default=5)\n", + "parser.add_argument('--JK', type=str, default='last')\n", + "parser.add_argument(\"--dropout_ratio\", type=float, default=0.5)\n", + "parser.add_argument(\"--gnn_type\", type=str, default=\"gin\")\n", + "parser.add_argument('--graph_pooling', type=str, default='mean')\n", + "\n", + "########## for saver ##########\n", + "parser.add_argument(\"--eval_train\", type=int, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=0)\n", + "\n", + "parser.add_argument(\"--input_model_dir\", type=str, default=\"demo_checkpoints_Graph\")\n", + "parser.add_argument(\"--input_model_path\", type=str, default=\"demo_checkpoints_Graph/molecule_model.pth\")\n", + "\n", + "\n", + "args = parser.parse_args(\"\")\n", + "print(\"arguments\\t\", args)" + ] + }, + { + "cell_type": "markdown", + "id": "e3f80fc0", + "metadata": {}, + "source": [ + "## Setup Seed" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b65ca274", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(args.seed)\n", + "torch.random.manual_seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "54d7cb65", + "metadata": {}, + "source": [ + "## Load SciBERT" + ] + }, + { + 
"cell_type": "code", + "execution_count": 4, + "id": "a8e70ba5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from demo_checkpoints_Graph/text_model.pth...\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')\n", + "text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)\n", + "text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)\n", + "text_dim = 768\n", + "\n", + "input_model_path = os.path.join(args.input_model_dir, \"text_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "text_model.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "c99247bc", + "metadata": {}, + "source": [ + "## Load MoleculeSTM-Graph" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4964eb40", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from demo_checkpoints_Graph/molecule_model.pth...\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "molecule_node_model = GNN(\n", + " num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,\n", + " JK=args.JK, drop_ratio=args.dropout_ratio,\n", + " gnn_type=args.gnn_type)\n", + "molecule_model = GNN_graphpred(\n", + " num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,\n", + " num_tasks=1, molecule_node_model=molecule_node_model) \n", + "molecule_dim = args.gnn_emb_dim\n", + "\n", + "input_model_path = os.path.join(args.input_model_dir, \"molecule_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "molecule_model.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "15a4a0cf", + "metadata": {}, + "source": [ + "## Load Projection Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1d28fd67", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from 
demo_checkpoints_Graph/text2latent_model.pth...\n", + "Loading from demo_checkpoints_Graph/mol2latent_model.pth...\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text2latent = nn.Linear(text_dim, args.SSL_emb_dim)\n", + "input_model_path = os.path.join(args.input_model_dir, \"text2latent_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "text2latent.load_state_dict(state_dict)\n", + "\n", + "mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim)\n", + "input_model_path = os.path.join(args.input_model_dir, \"mol2latent_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "mol2latent.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "3f5cd050", + "metadata": {}, + "source": [ + "## Define Support Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "146e5c08", + "metadata": {}, + "outputs": [], + "source": [ + "def cycle_index(num, shift):\n", + " arr = torch.arange(num) + shift\n", + " arr[-shift:] = torch.arange(shift)\n", + " return arr\n", + "\n", + "\n", + "def do_CL_eval(X, Y, neg_Y, args):\n", + " X = F.normalize(X, dim=-1)\n", + " X = X.unsqueeze(1) # B, 1, d\n", + "\n", + " Y = Y.unsqueeze(0)\n", + " Y = torch.cat([Y, neg_Y], dim=0) # T, B, d\n", + " Y = Y.transpose(0, 1) # B, T, d\n", + " Y = F.normalize(Y, dim=-1)\n", + "\n", + " logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T\n", + " B = X.size()[0]\n", + " labels = torch.zeros(B).long().to(logits.device) # B*1\n", + "\n", + " criterion = nn.CrossEntropyLoss()\n", + "\n", + " CL_loss = criterion(logits, labels)\n", + " pred = logits.argmax(dim=1, keepdim=False)\n", + " confidence = logits\n", + " CL_conf = confidence.max(dim=1)[0]\n", + " CL_conf = CL_conf.cpu().numpy()\n", + "\n", + " CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B\n", + " return CL_loss, CL_conf, CL_acc\n", + "\n", + "\n", + "def get_text_repr(text):\n", + " text_tokens_ids, text_masks = prepare_text_tokens(\n", + " device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)\n", + " text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)\n", + " text_repr = text_output[\"pooler_output\"]\n", + " text_repr = text2latent(text_repr)\n", + " return text_repr\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def eval_epoch(dataloader):\n", + " text_model.eval()\n", + " molecule_model.eval()\n", + " text2latent.eval()\n", + " mol2latent.eval()\n", + "\n", + " accum_acc_list = [0 for _ in args.T_list]\n", + " if args.verbose:\n", + " L = tqdm(dataloader)\n", + " else:\n", + " L = dataloader\n", + " for batch in L:\n", + " text = batch[0]\n", + " molecule_data = batch[1]\n", + " neg_text = batch[2]\n", + " neg_molecule_data = batch[3]\n", + "\n", + " text_repr = get_text_repr(text)\n", + "\n", + " molecule_data = molecule_data.to(device)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " molecule_data, mol2latent=mol2latent,\n", + " molecule_type=\"Graph\", molecule_model=molecule_model)\n", + "\n", + " if test_mode == \"given_text\":\n", + " neg_molecule_repr = [\n", + " get_molecule_repr_MoleculeSTM(\n", + " neg_molecule_data[idx].to(device), mol2latent=mol2latent,\n", + " molecule_type=\"Graph\", molecule_model=molecule_model) for idx in range(T_max)\n", + " ]\n", + " neg_molecule_repr = torch.stack(neg_molecule_repr)\n", + " for T_idx, T in enumerate(args.T_list):\n", + " _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args)\n", + " accum_acc_list[T_idx] += acc\n", + " elif test_mode == \"given_molecule\":\n", + " neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)]\n", + " neg_text_repr = torch.stack(neg_text_repr)\n", + " for T_idx, T in enumerate(args.T_list):\n", + " _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args)\n", + " accum_acc_list[T_idx] += acc\n", + " else:\n", + " raise Exception\n", + " \n", + " accum_acc_list = np.array(accum_acc_list)\n", + " accum_acc_list /= len(dataloader)\n", + " return accum_acc_list" + ] + }, + { + "cell_type": "markdown", + "id": "96b41532", + "metadata": {}, + "source": [ + "## Start Retrieval" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ebfb842a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data: Data(x=[40309, 2], edge_index=[2, 85886], edge_attr=[85886, 2], id=[1168])\n", + "Index(['text', 'smiles'], dtype='object')\n", + "Loading negative samples from ../data/DrugBank_data/index/SMILES_description_full.txt\n", + "Results [0.96030405 0.9214527 0.87584459]\n" + ] + } + ], + "source": [ + "text_model = text_model.to(device)\n", + "molecule_model = molecule_model.to(device)\n", + "text2latent = text2latent.to(device)\n", + "mol2latent = mol2latent.to(device)\n", + "\n", + "T_max = max(args.T_list) - 1\n", + "\n", + "initial_test_acc_list = []\n", + "test_mode = args.test_mode\n", + "dataset_folder = os.path.join(args.dataspace_path, \"DrugBank_data\")\n", + "\n", + "\n", + "dataset_class = DrugBank_Datasets_Graph_retrieval\n", + "dataloader_class = pyg_DataLoader\n", + "processed_dir_prefix = args.task\n", + "\n", + "if args.task == \"molecule_description\":\n", + " template = \"SMILES_description_{}.txt\"\n", + "elif args.task == \"molecule_description_removed_PubChem\":\n", + " template = 
\"SMILES_description_removed_from_PubChem_{}.txt\"\n", + "elif args.task == \"molecule_description_Raw\":\n", + " template = \"SMILES_description_{}_Raw.txt\"\n", + "elif args.task == \"molecule_description_removed_PubChem_Raw\":\n", + " template = \"SMILES_description_removed_from_PubChem_{}_Raw.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics\":\n", + " template = \"SMILES_pharmacodynamics_{}.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics_removed_PubChem\":\n", + " template = \"SMILES_pharmacodynamics_removed_from_PubChem_{}.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics_Raw\":\n", + " template = \"SMILES_pharmacodynamics_{}_Raw.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics_removed_PubChem_Raw\":\n", + " template = \"SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt\"\n", + "\n", + "full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, processed_dir_prefix=processed_dir_prefix, template=template)\n", + "full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blocked with non-zero num_workers\n", + "\n", + "initial_test_acc_list = eval_epoch(full_dataloader)\n", + "print('Results', initial_test_acc_list)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demos/demo_downstream_retrieval_SMILES.ipynb b/demos/demo_downstream_retrieval_SMILES.ipynb new file mode 100644 index 0000000..0fa0a13 --- /dev/null +++ b/demos/demo_downstream_retrieval_SMILES.ipynb @@ -0,0 +1,613 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8a41a864", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM Downstream: Structure-Text Retrieval\n", + "\n", + "## Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "b9bc8496", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-08-30 12:26:27,252] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import os\n", + "import time\n", + "import argparse\n", + "import numpy as np\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "from collections import defaultdict\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader as torch_DataLoader\n", + "from torch_geometric.loader import DataLoader as pyg_DataLoader\n", + "\n", + "from transformers import AutoModel, AutoTokenizer\n", + "from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval\n", + "from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART\n", + "from MoleculeSTM.models import GNN, GNN_graphpred\n", + "from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network\n", + "\n", + "# Set-up the environment variable to ignore warnings\n", + "os.environ['TOKENIZERS_PARALLELISM'] = 'False'" + ] + }, + { 
+ "cell_type": "markdown", + "id": "7c3a7eee", + "metadata": {}, + "source": [ + "## Setup Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a2d76596", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arguments\t Namespace(CL_neg_samples=1, SSL_emb_dim=256, SSL_loss='EBM_NCE', T=0.1, T_list=[4, 10, 20], batch_size=32, dataspace_path='../data', decay=0, device=0, epochs=1, eval_train=0, input_model_dir='demo_checkpoints_SMILES', input_model_path='demo_checkpoints_SMILES/molecule_model.pth', load_latent_projector=1, max_seq_len=512, mol_lr=1e-05, mol_lr_scale=0.1, molecule_type='SMILES', normalize=True, num_workers=8, seed=42, task='molecule_description', test_mode='given_text', text_lr=1e-05, text_lr_scale=0.1, text_type='SciBERT', training_mode='zero_shot', verbose=0, vocab_path='../MoleculeSTM/bart_vocab.txt')\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser()\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "parser.add_argument(\"--SSL_emb_dim\", type=int, default=256)\n", + "parser.add_argument(\"--text_type\", type=str, default=\"SciBERT\", choices=[\"SciBERT\", \"BioBERT\"])\n", + "parser.add_argument(\"--load_latent_projector\", type=int, default=1)\n", + "parser.add_argument(\"--training_mode\", type=str, default=\"zero_shot\", choices=[\"zero_shot\"])\n", + "\n", + "########## for dataset and split ##########\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--task\", type=str, default=\"molecule_description\",\n", + " choices=[\n", + " \"molecule_description\", \"molecule_description_Raw\",\n", + " \"molecule_description_removed_PubChem\", \"molecule_description_removed_PubChem_Raw\",\n", + " \"molecule_pharmacodynamics\", \"molecule_pharmacodynamics_Raw\",\n", + " \"molecule_pharmacodynamics_removed_PubChem\", \"molecule_pharmacodynamics_removed_PubChem_Raw\"])\n", + "parser.add_argument(\"--test_mode\", type=str, default=\"given_text\", choices=[\"given_text\", \"given_molecule\"])\n", + "\n", + "########## for optimization ##########\n", + "parser.add_argument(\"--T_list\", type=int, nargs=\"+\", default=[4, 10, 20])\n", + "parser.add_argument(\"--batch_size\", type=int, default=32)\n", + "parser.add_argument(\"--num_workers\", type=int, default=8)\n", + "parser.add_argument(\"--epochs\", type=int, default=1)\n", + "parser.add_argument(\"--text_lr\", type=float, default=1e-5)\n", + "parser.add_argument(\"--mol_lr\", type=float, default=1e-5)\n", + "parser.add_argument(\"--text_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--mol_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--decay\", type=float, default=0)\n", + "\n", + "########## for contrastive objective ##########\n", + "parser.add_argument(\"--SSL_loss\", type=str, default=\"EBM_NCE\", choices=[\"EBM_NCE\", \"InfoNCE\"])\n", + "parser.add_argument(\"--CL_neg_samples\", type=int, default=1)\n", + "parser.add_argument(\"--T\", type=float, default=0.1)\n", + "parser.add_argument('--normalize', dest='normalize', action='store_true')\n", + "parser.add_argument('--no_normalize', dest='normalize', action='store_false')\n", + "parser.set_defaults(normalize=True)\n", + "\n", + "########## for BERT model ##########\n", + "parser.add_argument(\"--max_seq_len\", type=int, default=512)\n", + "\n", + "########## for molecule model ##########\n", + 
"parser.add_argument(\"--molecule_type\", type=str, default=\"SMILES\", choices=[\"SMILES\", \"Graph\"])\n", + "\n", + "########## for MegaMolBART ##########\n", + "parser.add_argument(\"--vocab_path\", type=str, default=\"../MoleculeSTM/bart_vocab.txt\")\n", + "\n", + "########## for saver ##########\n", + "parser.add_argument(\"--eval_train\", type=int, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=0)\n", + "\n", + "parser.add_argument(\"--input_model_dir\", type=str, default=\"demo_checkpoints_SMILES\")\n", + "parser.add_argument(\"--input_model_path\", type=str, default=\"demo_checkpoints_SMILES/molecule_model.pth\")\n", + "\n", + "\n", + "args = parser.parse_args(\"\")\n", + "print(\"arguments\\t\", args)" + ] + }, + { + "cell_type": "markdown", + "id": "e3f80fc0", + "metadata": {}, + "source": [ + "## Setup Seed" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b65ca274", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(args.seed)\n", + "torch.random.manual_seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "54d7cb65", + "metadata": {}, + "source": [ + "## Load SciBERT" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a8e70ba5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from demo_checkpoints_SMILES/text_model.pth...\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')\n", + "text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)\n", + "text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)\n", + "text_dim = 768\n", + "\n", + "input_model_path = os.path.join(args.input_model_dir, \"text_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "text_model.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "c99247bc", + "metadata": {}, + "source": [ + "## Load MoleculeSTM-SMILES" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4964eb40", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from demo_checkpoints_SMILES/molecule_model.pth...\n", + "using world size: 1 and model-parallel size: 1 \n", + "using torch.float32 for parameters ...\n", + "-------------------- arguments --------------------\n", + " adam_beta1 ...................... 0.9\n", + " adam_beta2 ...................... 0.999\n", + " adam_eps ........................ 1e-08\n", + " adlr_autoresume ................. False\n", + " adlr_autoresume_interval ........ 1000\n", + " apply_query_key_layer_scaling ... False\n", + " apply_residual_connection_post_layernorm False\n", + " attention_dropout ............... 0.1\n", + " attention_softmax_in_fp32 ....... False\n", + " batch_size ...................... None\n", + " bert_load ....................... None\n", + " bias_dropout_fusion ............. False\n", + " bias_gelu_fusion ................ False\n", + " block_data_path ................. None\n", + " checkpoint_activations .......... False\n", + " checkpoint_in_cpu ............... False\n", + " checkpoint_num_layers ........... 1\n", + " clip_grad ....................... 1.0\n", + " contigious_checkpointing ........ False\n", + " cpu_optimizer ................... False\n", + " cpu_torch_adam .................. False\n", + " data_impl ....................... infer\n", + " data_path ....................... None\n", + " dataset_path .................... None\n", + " DDP_impl ........................ local\n", + " deepscale ....................... False\n", + " deepscale_config ................ None\n", + " deepspeed ....................... False\n", + " deepspeed_activation_checkpointing False\n", + " deepspeed_config ................ None\n", + " deepspeed_mpi ................... False\n", + " distribute_checkpointed_activations False\n", + " distributed_backend ............. nccl\n", + " dynamic_loss_scale .............. True\n", + " eod_mask_loss ................... False\n", + " eval_interval ................... 1000\n", + " eval_iters ...................... 
100\n", + " exit_interval ................... None\n", + " faiss_use_gpu ................... False\n", + " finetune ........................ False\n", + " fp16 ............................ False\n", + " fp16_lm_cross_entropy ........... False\n", + " fp32_allreduce .................. False\n", + " gas ............................. 1\n", + " hidden_dropout .................. 0.1\n", + " hidden_size ..................... 256\n", + " hysteresis ...................... 2\n", + " ict_head_size ................... None\n", + " ict_load ........................ None\n", + " indexer_batch_size .............. 128\n", + " indexer_log_interval ............ 1000\n", + " init_method_std ................. 0.02\n", + " layernorm_epsilon ............... 1e-05\n", + " lazy_mpu_init ................... None\n", + " load ............................ None\n", + " local_rank ...................... None\n", + " log_interval .................... 100\n", + " loss_scale ...................... None\n", + " loss_scale_window ............... 1000\n", + " lr .............................. None\n", + " lr_decay_iters .................. None\n", + " lr_decay_style .................. linear\n", + " make_vocab_size_divisible_by .... 128\n", + " mask_prob ....................... 0.15\n", + " max_position_embeddings ......... 512\n", + " merge_file ...................... None\n", + " min_lr .......................... 0.0\n", + " min_scale ....................... 1\n", + " mmap_warmup ..................... False\n", + " model_parallel_size ............. 1\n", + " no_load_optim ................... False\n", + " no_load_rng ..................... False\n", + " no_save_optim ................... False\n", + " no_save_rng ..................... False\n", + " num_attention_heads ............. 8\n", + " num_layers ...................... 4\n", + " num_unique_layers ............... None\n", + " num_workers ..................... 2\n", + " onnx_safe ....................... None\n", + " openai_gelu ..................... False\n", + " override_lr_scheduler ........... False\n", + " param_sharing_style ............. grouped\n", + " params_dtype .................... torch.float32\n", + " partition_activations ........... False\n", + " pipe_parallel_size .............. 0\n", + " profile_backward ................ False\n", + " query_in_block_prob ............. 0.1\n", + " rank ............................ 0\n", + " report_topk_accuracies .......... []\n", + " reset_attention_mask ............ False\n", + " reset_position_ids .............. False\n", + " save ............................ None\n", + " save_interval ................... None\n", + " scaled_masked_softmax_fusion .... False\n", + " scaled_upper_triang_masked_softmax_fusion False\n", + " seed ............................ 1234\n", + " seq_length ...................... None\n", + " short_seq_prob .................. 0.1\n", + " split ........................... 969, 30, 1\n", + " synchronize_each_layer .......... False\n", + " tensorboard_dir ................. None\n", + " titles_data_path ................ None\n", + " tokenizer_type .................. GPT2BPETokenizer\n", + " train_iters ..................... None\n", + " use_checkpoint_lr_scheduler ..... False\n", + " use_cpu_initialization .......... False\n", + " use_one_sent_docs ............... False\n", + " vocab_file ...................... ../MoleculeSTM/bart_vocab.txt\n", + " warmup .......................... 0.01\n", + " weight_decay .................... 0.01\n", + " world_size ...................... 
1\n", + " zero_allgather_bucket_size ...... 0.0\n", + " zero_contigious_gradients ....... False\n", + " zero_reduce_bucket_size ......... 0.0\n", + " zero_reduce_scatter ............. False\n", + " zero_stage ...................... 1.0\n", + "---------------- end of arguments ----------------\n", + "> initializing torch distributed ...\n", + "> initializing model parallel with size 1\n", + "> setting random seeds to 1234 ...\n", + "> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234\n", + "Loading vocab from ../MoleculeSTM/bart_vocab.txt.\n" + ] + } + ], + "source": [ + "input_model_path = os.path.join(args.input_model_dir, \"molecule_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None)\n", + "molecule_model = MegaMolBART_wrapper.model\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "molecule_model.load_state_dict(state_dict)\n", + "molecule_dim = 256\n", + "\n", + "# Rewrite the seed by MegaMolBART\n", + "np.random.seed(args.seed)\n", + "torch.random.manual_seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)" + ] + }, + { + "cell_type": "markdown", + "id": "15a4a0cf", + "metadata": {}, + "source": [ + "## Load Projection Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1d28fd67", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from demo_checkpoints_SMILES/text2latent_model.pth...\n", + "Loading from demo_checkpoints_SMILES/mol2latent_model.pth...\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text2latent = nn.Linear(text_dim, args.SSL_emb_dim)\n", + "input_model_path = os.path.join(args.input_model_dir, \"text2latent_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "text2latent.load_state_dict(state_dict)\n", + "\n", + "mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim)\n", + "input_model_path = os.path.join(args.input_model_dir, \"mol2latent_model.pth\")\n", + "print(\"Loading from {}...\".format(input_model_path))\n", + "state_dict = torch.load(input_model_path, map_location='cpu')\n", + "mol2latent.load_state_dict(state_dict)" + ] + }, + { + "cell_type": "markdown", + "id": "3f5cd050", + "metadata": {}, + "source": [ + "## Define Support Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "146e5c08", + "metadata": {}, + "outputs": [], + "source": [ + "def cycle_index(num, shift):\n", + " arr = torch.arange(num) + shift\n", + " arr[-shift:] = torch.arange(shift)\n", + " return arr\n", + "\n", + "\n", + "def do_CL_eval(X, Y, neg_Y, args):\n", + " X = F.normalize(X, dim=-1)\n", + " X = X.unsqueeze(1) # B, 1, d\n", + "\n", + " Y = Y.unsqueeze(0)\n", + " Y = torch.cat([Y, neg_Y], dim=0) # T, B, d\n", + " Y = Y.transpose(0, 1) # B, T, d\n", + " Y = F.normalize(Y, dim=-1)\n", + "\n", + " logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T\n", + " B = X.size()[0]\n", + " labels = torch.zeros(B).long().to(logits.device) # B*1\n", + "\n", + " criterion = nn.CrossEntropyLoss()\n", + "\n", + " CL_loss = criterion(logits, labels)\n", + " pred = 
logits.argmax(dim=1, keepdim=False)\n", + " confidence = logits\n", + " CL_conf = confidence.max(dim=1)[0]\n", + " CL_conf = CL_conf.cpu().numpy()\n", + "\n", + " CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B\n", + " return CL_loss, CL_conf, CL_acc\n", + "\n", + "\n", + "def get_text_repr(text):\n", + " text_tokens_ids, text_masks = prepare_text_tokens(\n", + " device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)\n", + " text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)\n", + " text_repr = text_output[\"pooler_output\"]\n", + " text_repr = text2latent(text_repr)\n", + " return text_repr\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def eval_epoch(dataloader):\n", + " text_model.eval()\n", + " molecule_model.eval()\n", + " text2latent.eval()\n", + " mol2latent.eval()\n", + "\n", + " accum_acc_list = [0 for _ in args.T_list]\n", + " if args.verbose:\n", + " L = tqdm(dataloader)\n", + " else:\n", + " L = dataloader\n", + " for batch in L:\n", + " text = batch[0]\n", + " molecule_data = batch[1]\n", + " neg_text = batch[2]\n", + " neg_molecule_data = batch[3]\n", + "\n", + " text_repr = get_text_repr(text)\n", + " SMILES_list = list(molecule_data)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " SMILES_list, mol2latent=mol2latent,\n", + " molecule_type=\"SMILES\", MegaMolBART_wrapper=MegaMolBART_wrapper)\n", + "\n", + " if test_mode == \"given_text\":\n", + " neg_molecule_repr = [\n", + " get_molecule_repr_MoleculeSTM(\n", + " list(neg_molecule_data[idx]), mol2latent=mol2latent,\n", + " molecule_type=\"SMILES\", MegaMolBART_wrapper=MegaMolBART_wrapper) for idx in range(T_max)\n", + " ]\n", + " neg_molecule_repr = torch.stack(neg_molecule_repr)\n", + "\n", + " for T_idx, T in enumerate(args.T_list):\n", + " _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args)\n", + " accum_acc_list[T_idx] += acc\n", + " elif test_mode == \"given_molecule\":\n", + " neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)]\n", + " neg_text_repr = torch.stack(neg_text_repr)\n", + " for T_idx, T in enumerate(args.T_list):\n", + " _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args)\n", + " accum_acc_list[T_idx] += acc\n", + " else:\n", + " raise Exception\n", + " \n", + " accum_acc_list = np.array(accum_acc_list)\n", + " accum_acc_list /= len(dataloader)\n", + " return accum_acc_list" + ] + }, + { + "cell_type": "markdown", + "id": "96b41532", + "metadata": {}, + "source": [ + "## Start Retrieval" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "ebfb842a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading negative samples from ../data/DrugBank_data/index/SMILES_description_full.txt\n", + "Results [0.94256757 0.89864865 0.84797297]\n" + ] + } + ], + "source": [ + "text_model = text_model.to(device)\n", + "molecule_model = molecule_model.to(device)\n", + "text2latent = text2latent.to(device)\n", + "mol2latent = mol2latent.to(device)\n", + "\n", + "T_max = max(args.T_list) - 1\n", + "\n", + "initial_test_acc_list = []\n", + "test_mode = args.test_mode\n", + "dataset_folder = os.path.join(args.dataspace_path, \"DrugBank_data\")\n", + "\n", + "dataset_class = DrugBank_Datasets_SMILES_retrieval\n", + "dataloader_class = torch_DataLoader\n", + "\n", + "if args.task == \"molecule_description\":\n", + " template = \"SMILES_description_{}.txt\"\n", + "elif args.task == 
\"molecule_description_removed_PubChem\":\n", + " template = \"SMILES_description_removed_from_PubChem_{}.txt\"\n", + "elif args.task == \"molecule_description_Raw\":\n", + " template = \"SMILES_description_{}_Raw.txt\"\n", + "elif args.task == \"molecule_description_removed_PubChem_Raw\":\n", + " template = \"SMILES_description_removed_from_PubChem_{}_Raw.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics\":\n", + " template = \"SMILES_pharmacodynamics_{}.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics_removed_PubChem\":\n", + " template = \"SMILES_pharmacodynamics_removed_from_PubChem_{}.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics_Raw\":\n", + " template = \"SMILES_pharmacodynamics_{}_Raw.txt\"\n", + "elif args.task == \"molecule_pharmacodynamics_removed_PubChem_Raw\":\n", + " template = \"SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt\"\n", + "\n", + "full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, template=template)\n", + "full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blocked with non-zero num_workers\n", + "\n", + "initial_test_acc_list = eval_epoch(full_dataloader)\n", + "print('Results', initial_test_acc_list)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demos/demo_downstream_zero_shot_molecule_edit.ipynb b/demos/demo_downstream_zero_shot_molecule_edit.ipynb new file mode 100644 index 0000000..e7b96e3 --- /dev/null +++ b/demos/demo_downstream_zero_shot_molecule_edit.ipynb @@ -0,0 +1,589 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ef1d4052", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM Downstream: Molecule Editing\n", + "\n", + "## Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d3e81e82-8fd8-4c68-be10-7e0e3760d6c1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-08-30 12:31:55,780] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "import argparse\n", + "import math\n", + "import numpy as np\n", + "import os\n", + "\n", + "import torch\n", + "from torch import optim\n", + "import torch.nn.functional as F\n", + "from tqdm import tqdm\n", + "\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem, Descriptors\n", + "from rdkit import DataStructs\n", + "lg = RDLogger.logger()\n", + "lg.setLevel(RDLogger.CRITICAL)\n", + "\n", + "from MoleculeSTM.utils import prepare_text_tokens\n", + "from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, load_language_molecule_and_edit_models, clip_loss_for_edit\n", + "\n", + "import sys\n", + "sys.path.insert(0, \"../scripts\")\n", + "from downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization import get_lr, mean_pooling" + ] + }, + { + "cell_type": "markdown", + "id": "db7bb3d4", + "metadata": {}, + "source": [ + "## Setup 
Arguments\n", + "\n", + "Notice that at this step, we are only using the textual branch (SciBERT) and a pretrained molecule generative model (MegaMolBART). The MoleculeSTM chemical branch (MegaMolBART or GraphMVP) is only used at the module alignment phase, and we can change it in the `MoleculeSTM_model_dir` argument." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a9c84026-7791-450d-a86c-59f86281fea8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Namespace(MegaMolBART_generation_model_dir='../data/pretrained_MegaMolBART/checkpoints', MoleculeSTM_model_dir='demo_checkpoints_SMILES', MoleculeSTM_molecule_type='SMILES', SSL_emb_dim=256, dataspace_path='../data', device=0, epochs=100, input_SMILES='OC1C2C1CC2', input_SMILES_file=None, input_description=None, input_description_id=None, language_edit_model_dir='demo_checkpoints_SMILES', lr=0.1, lr_rampup=0.05, max_seq_len=512, mode='edit', normalize=True, output_model_dir=None, seed=42, use_noise_for_init=True, verbose=1, vocab_path='../MoleculeSTM/bart_vocab.txt')\n" + ] + } + ], + "source": [ + "parser = argparse.ArgumentParser()\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=1)\n", + "\n", + "########## for editing ##########\n", + "parser.add_argument(\"--input_description\", type=str, default=None)\n", + "parser.add_argument(\"--input_description_id\", type=int, default=None)\n", + "parser.add_argument(\"--input_SMILES\", type=str, default=\"OC1C2C1CC2\")\n", + "parser.add_argument(\"--input_SMILES_file\", type=str, default=None)\n", + "parser.add_argument(\"--output_model_dir\", type=str, default=None)\n", + "parser.add_argument(\"--mode\", type=str, default=\"edit\", choices=[\"edit\", \"free_generation\"])\n", + "parser.add_argument(\"--use_noise_for_init\", dest=\"use_noise_for_init\", action=\"store_true\")\n", + "parser.add_argument(\"--no_noise_for_init\", dest=\"use_noise_for_init\", action=\"store_false\")\n", + "parser.set_defaults(use_noise_for_init=True)\n", + "parser.add_argument('--normalize', dest='normalize', action='store_true')\n", + "parser.add_argument('--no_normalize', dest='normalize', action='store_false')\n", + "parser.set_defaults(normalize=True)\n", + "\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--SSL_emb_dim\", type=int, default=256)\n", + "parser.add_argument(\"--max_seq_len\", type=int, default=512)\n", + "\n", + "########## for foundation ##########\n", + "parser.add_argument(\"--MoleculeSTM_model_dir\", type=str, default=\"demo_checkpoints_SMILES\")\n", + "parser.add_argument(\"--MoleculeSTM_molecule_type\", type=str, default=\"SMILES\", choices=[\"SMILES\", \"Graph\"])\n", + "parser.add_argument(\"--vocab_path\", type=str, default=\"../MoleculeSTM/bart_vocab.txt\")\n", + "\n", + "########## for generation ##########\n", + "parser.add_argument(\"--MegaMolBART_generation_model_dir\", type=str, default=\"../data/pretrained_MegaMolBART/checkpoints\")\n", + "\n", + "########## for foundation and generation projection ##########\n", + "parser.add_argument(\"--language_edit_model_dir\", type=str, default=\"demo_checkpoints_SMILES\") \n", + "\n", + "########## for editing ##########\n", + "parser.add_argument(\"--lr_rampup\", type=float, default=0.05)\n", + "parser.add_argument(\"--lr\", type=float, default=0.1)\n", + 
"parser.add_argument(\"--epochs\", type=int, default=100)\n", + "args, unknown = parser.parse_known_args()\n", + "\n", + "print(args)" + ] + }, + { + "cell_type": "markdown", + "id": "0083fa5c", + "metadata": {}, + "source": [ + "## Load Models" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c5089f92-fec1-45c2-9035-32579ad8725a", + "metadata": { + "scrolled": false, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from demo_checkpoints_SMILES/text_model.pth...\n", + "using world size: 1 and model-parallel size: 1 \n", + "using torch.float32 for parameters ...\n", + "-------------------- arguments --------------------\n", + " adam_beta1 ...................... 0.9\n", + " adam_beta2 ...................... 0.999\n", + " adam_eps ........................ 1e-08\n", + " adlr_autoresume ................. False\n", + " adlr_autoresume_interval ........ 1000\n", + " apply_query_key_layer_scaling ... False\n", + " apply_residual_connection_post_layernorm False\n", + " attention_dropout ............... 0.1\n", + " attention_softmax_in_fp32 ....... False\n", + " batch_size ...................... None\n", + " bert_load ....................... None\n", + " bias_dropout_fusion ............. False\n", + " bias_gelu_fusion ................ False\n", + " block_data_path ................. None\n", + " checkpoint_activations .......... False\n", + " checkpoint_in_cpu ............... False\n", + " checkpoint_num_layers ........... 1\n", + " clip_grad ....................... 1.0\n", + " contigious_checkpointing ........ False\n", + " cpu_optimizer ................... False\n", + " cpu_torch_adam .................. False\n", + " data_impl ....................... infer\n", + " data_path ....................... None\n", + " dataset_path .................... None\n", + " DDP_impl ........................ local\n", + " deepscale ....................... False\n", + " deepscale_config ................ None\n", + " deepspeed ....................... False\n", + " deepspeed_activation_checkpointing False\n", + " deepspeed_config ................ None\n", + " deepspeed_mpi ................... False\n", + " distribute_checkpointed_activations False\n", + " distributed_backend ............. nccl\n", + " dynamic_loss_scale .............. True\n", + " eod_mask_loss ................... False\n", + " eval_interval ................... 1000\n", + " eval_iters ...................... 100\n", + " exit_interval ................... None\n", + " faiss_use_gpu ................... 
False\n", + " finetune ........................ False\n", + " fp16 ............................ False\n", + " fp16_lm_cross_entropy ........... False\n", + " fp32_allreduce .................. False\n", + " gas ............................. 1\n", + " hidden_dropout .................. 0.1\n", + " hidden_size ..................... 256\n", + " hysteresis ...................... 2\n", + " ict_head_size ................... None\n", + " ict_load ........................ None\n", + " indexer_batch_size .............. 128\n", + " indexer_log_interval ............ 1000\n", + " init_method_std ................. 0.02\n", + " layernorm_epsilon ............... 1e-05\n", + " lazy_mpu_init ................... None\n", + " load ............................ ../data/pretrained_MegaMolBART/checkpoints\n", + " local_rank ...................... None\n", + " log_interval .................... 100\n", + " loss_scale ...................... None\n", + " loss_scale_window ............... 1000\n", + " lr .............................. None\n", + " lr_decay_iters .................. None\n", + " lr_decay_style .................. linear\n", + " make_vocab_size_divisible_by .... 128\n", + " mask_prob ....................... 0.15\n", + " max_position_embeddings ......... 512\n", + " merge_file ...................... None\n", + " min_lr .......................... 0.0\n", + " min_scale ....................... 1\n", + " mmap_warmup ..................... False\n", + " model_parallel_size ............. 1\n", + " no_load_optim ................... False\n", + " no_load_rng ..................... False\n", + " no_save_optim ................... False\n", + " no_save_rng ..................... False\n", + " num_attention_heads ............. 8\n", + " num_layers ...................... 4\n", + " num_unique_layers ............... None\n", + " num_workers ..................... 2\n", + " onnx_safe ....................... None\n", + " openai_gelu ..................... False\n", + " override_lr_scheduler ........... False\n", + " param_sharing_style ............. grouped\n", + " params_dtype .................... torch.float32\n", + " partition_activations ........... False\n", + " pipe_parallel_size .............. 0\n", + " profile_backward ................ False\n", + " query_in_block_prob ............. 0.1\n", + " rank ............................ 0\n", + " report_topk_accuracies .......... []\n", + " reset_attention_mask ............ False\n", + " reset_position_ids .............. False\n", + " save ............................ None\n", + " save_interval ................... None\n", + " scaled_masked_softmax_fusion .... False\n", + " scaled_upper_triang_masked_softmax_fusion False\n", + " seed ............................ 1234\n", + " seq_length ...................... None\n", + " short_seq_prob .................. 0.1\n", + " split ........................... 969, 30, 1\n", + " synchronize_each_layer .......... False\n", + " tensorboard_dir ................. None\n", + " titles_data_path ................ None\n", + " tokenizer_type .................. GPT2BPETokenizer\n", + " train_iters ..................... None\n", + " use_checkpoint_lr_scheduler ..... False\n", + " use_cpu_initialization .......... False\n", + " use_one_sent_docs ............... False\n", + " vocab_file ...................... ../MoleculeSTM/bart_vocab.txt\n", + " warmup .......................... 0.01\n", + " weight_decay .................... 0.01\n", + " world_size ...................... 1\n", + " zero_allgather_bucket_size ...... 
0.0\n", + " zero_contigious_gradients ....... False\n", + " zero_reduce_bucket_size ......... 0.0\n", + " zero_reduce_scatter ............. False\n", + " zero_stage ...................... 1.0\n", + "---------------- end of arguments ----------------\n", + "> initializing torch distributed ...\n", + "> initializing model parallel with size 1\n", + "> setting random seeds to 1234 ...\n", + "> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234\n", + "Loading vocab from ../MoleculeSTM/bart_vocab.txt.\n", + "Loading from ../data/pretrained_MegaMolBART/checkpoints\n", + "global rank 0 is loading checkpoint ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n", + "could not find arguments in the checkpoint ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W ProcessGroupNCCL.cpp:1569] Rank 0 using best-guess GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " successfully loaded ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n", + "Loading from pretrained MegaMolBART (../data/pretrained_MegaMolBART/checkpoints).\n", + "Loading from demo_checkpoints_SMILES/text2latent_model.pth...\n", + "Loading from demo_checkpoints_SMILES/mol2latent_model.pth...\n", + "Loading from demo_checkpoints_SMILES/generation2foundation_model.pth...\n", + "Loading from demo_checkpoints_SMILES/foundation2generation_model.pth...\n" + ] + }, + { + "data": { + "text/plain": [ + "MLP(\n", + " (layers): ModuleList(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): Linear(in_features=256, out_features=256, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text_model, text_tokenizer, text_dim, molecule_model, MegaMolBART_wrapper, molecule_dim, \\\n", + " text2latent, mol2latent, generation2foundation, foundation2generation = load_language_molecule_and_edit_models(args)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "text_model = text_model.to(device)\n", + "molecule_model = molecule_model.to(device)\n", + "text2latent = text2latent.to(device)\n", + "mol2latent = mol2latent.to(device)\n", + "generation2foundation.to(device)\n", + "foundation2generation.to(device)\n", + "text_model.eval()\n", + "molecule_model.eval()\n", + "text2latent.eval()\n", + "mol2latent.eval()\n", + "generation2foundation.eval()\n", + "foundation2generation.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "c4c07337", + "metadata": {}, + "source": [ + "# Reset seed" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dd7b38a3", + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(args.seed)\n", + "torch.random.manual_seed(args.seed)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "c390a992", + "metadata": {}, + "source": [ + "## Define Support 
Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ddc9a3a3-ebb6-4806-9e49-314213ff4aef", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def evaluate_SMILES_list(SMILES_list, description):\n", + " print(\"SMILES_list:\", SMILES_list)\n", + " mol_list = []\n", + " for SMILES in SMILES_list:\n", + " mol = Chem.MolFromSmiles(SMILES)\n", + " if mol is None:\n", + " continue\n", + " mol_list.append(mol)\n", + "\n", + " if len(mol_list) < 3:\n", + " return [False]\n", + "\n", + " if \"soluble\" in description and \"insoluble\" not in description:\n", + " props = [\"MolLogP\"]\n", + " prop_pred = [(n, func) for n, func in Descriptors.descList if n.split(\"_\")[-1] in props]\n", + " value_list = []\n", + " for name, func in prop_pred:\n", + " for idx, (SMILES, mol) in enumerate(zip(SMILES_list, mol_list)):\n", + " if idx == 1:\n", + " continue\n", + " value = func(mol)\n", + " value_list.append(value)\n", + " print(\"SMILES: {}\\t\\t\\tlogP: {:.5f}\".format(SMILES, value))\n", + " if value_list[0] > value_list[-1]:\n", + " answer = [True]\n", + " else:\n", + " answer = [False]\n", + "\n", + " return answer\n", + "\n", + "\n", + "def check_edit(SMILES, text, device):\n", + " text_list = [text]\n", + " text_tokens_ids, text_masks = prepare_text_tokens(\n", + " device=device, description=text_list, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)\n", + " text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks)\n", + " text_repr = text_output[\"pooler_output\"]\n", + " text_repr = text2latent(text_repr)\n", + "\n", + " first_and_second_SMILES_list = []\n", + "\n", + " latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding([SMILES]) # [pad, B, d], [pad, B]\n", + " first_and_second_SMILES_list.append(SMILES)\n", + "\n", + " regenerated_mols = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True)\n", + " first_and_second_SMILES_list.append(regenerated_mols[0])\n", + "\n", + " l2_lambda_list = [1e0]\n", + " result_SMILES_list_one_pair, result_eval_list_one_pair = [], []\n", + " \n", + " if args.use_noise_for_init:\n", + " print(\"Use random noise for init\")\n", + " random_noise = torch.randn(latent_code_init.size()).to(device)\n", + " \n", + " for l2_lambda in l2_lambda_list:\n", + " print(\"l2 lambda: {}\".format(l2_lambda))\n", + " current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]]\n", + " if args.use_noise_for_init:\n", + " print(\"Use random noise for init\")\n", + " latent = latent_code_init.detach().clone() + random_noise\n", + " else:\n", + " print(\"No random noise for init\")\n", + " latent = latent_code_init.detach().clone()\n", + " pad_mask = pad_mask_init.detach().clone()\n", + " latent.requires_grad = True\n", + " optimizer = optim.Adam([latent], lr=args.lr)\n", + " \n", + " if args.verbose:\n", + " L = tqdm(range(args.epochs))\n", + " else:\n", + " L = range(args.epochs)\n", + "\n", + " for i in L:\n", + " t = i / args.epochs\n", + " lr = get_lr(t, args.lr)\n", + " optimizer.param_groups[0][\"lr\"] = lr\n", + "\n", + " molecule_repr_generation = mean_pooling(latent, pad_mask) # [B, d]\n", + " if args.normalize:\n", + " molecule_repr_generation = F.normalize(molecule_repr_generation, dim=-1)\n", + " molecule_repr_foundation = generation2foundation(molecule_repr_generation)\n", + "\n", + " clip_loss_ = clip_loss_for_edit(molecule_repr_foundation, text_repr)\n", + " l2_loss_ = l2_lambda * 
((latent_code_init - latent) ** 2).mean()\n", + "\n", + " loss = clip_loss_ + l2_loss_\n", + "\n", + " optimizer.zero_grad()\n", + " loss.backward(retain_graph=True)\n", + " optimizer.step()\n", + " print(\"clip loss: {:.5f}\\tL2 loss: {:.5f}\".format(clip_loss_.item(), l2_loss_.item()))\n", + "\n", + " generated_mols = MegaMolBART_wrapper.inverse_transform([latent], pad_mask.bool().cuda(), k=1, sanitize=True)\n", + " current_SMILES_list.append(generated_mols[0])\n", + " result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(l2_lambda)])\n", + "\n", + " current_result_list = evaluate_SMILES_list(current_SMILES_list, text)\n", + " result_eval_list_one_pair.append(current_result_list)\n", + " print()\n", + " \n", + " result_eval_list_one_pair = np.array(result_eval_list_one_pair)\n", + " result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True)\n", + " return result_SMILES_list_one_pair, result_eval_list_one_pair\n" + ] + }, + { + "cell_type": "markdown", + "id": "e4d7b4ac", + "metadata": {}, + "source": [ + "## Start Molecule Editing" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2491585e-cae4-4d36-b5df-764cf72e9115", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "start editing\n", + "\n", + "\n", + "\n", + "===== for text prompt: This molecule is soluble in water. =====\n", + "===== for SMILES OC1C2C1CC2 =====\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING: MOLECULE VALIDATION AND SANITIZATION CURRENTLY DISABLED\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Use random noise for init\n", + "l2 lambda: 1.0\n", + "Use random noise for init\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 38.36it/s]\n", + "WARNING: MOLECULE VALIDATION AND SANITIZATION CURRENTLY DISABLED\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "clip loss: -0.92124\tL2 loss: 0.33059\n", + "SMILES_list: ['OC1C2C1CC2', 'OC12CC1C2', 'OC1CC2CC(O)(C1)C2']\n", + "SMILES: OC1C2C1CC2\t\t\tlogP: 0.38710\n", + "SMILES: OC1CC2CC(O)(C1)C2\t\t\tlogP: 0.28220\n", + "\n" + ] + } + ], + "source": [ + "print(\"start editing\\n\\n\\n\")\n", + "\n", + "source_SMILES_list = get_SMILES_list(args)\n", + "\n", + "description = \"This molecule is soluble in water.\"\n", + "\n", + "\n", + "print(\"===== for text prompt: {} =====\".format(description))\n", + "result_SMILES_list, result_acc_list = [], []\n", + "\n", + "for SMILES in source_SMILES_list:\n", + " print(\"===== for SMILES {} =====\".format(SMILES))\n", + " result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description, device)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/demos/demo_pretrain_Graph.ipynb b/demos/demo_pretrain_Graph.ipynb new file mode 100644 index 0000000..92dbe21 --- /dev/null +++ b/demos/demo_pretrain_Graph.ipynb @@ -0,0 +1,521 @@ +{ + "cells": [ + { + 
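A note on the editing demo that ends above: check_edit performs text-guided molecule editing by gradient descent directly on the MegaMolBART latent code. Each step pools the latent into a molecule vector, maps it into the MoleculeSTM space with generation2foundation, pulls it toward the text representation through clip_loss_for_edit, and adds an L2 penalty that keeps the latent close to the starting molecule. The sketch below restates that objective in isolation; the masked pooling and the cosine form of the clip loss are assumptions, not the repository's exact implementations.

import torch
import torch.nn.functional as F

def masked_mean_pool(latent, pad_mask):
    # latent: [seq_len, B, d]; pad_mask: [seq_len, B], assumed True/1 at padded positions
    keep = (~pad_mask.bool()).float().unsqueeze(-1)
    return (latent * keep).sum(dim=0) / keep.sum(dim=0).clamp(min=1.0)

def editing_loss(latent, latent_init, pad_mask, text_repr, generation2foundation, l2_lambda=1.0):
    mol_repr = masked_mean_pool(latent, pad_mask)               # [B, d], generation space
    mol_repr = F.normalize(mol_repr, dim=-1)
    mol_repr = generation2foundation(mol_repr)                  # project into the MoleculeSTM space
    clip_term = -F.cosine_similarity(mol_repr, text_repr, dim=-1).mean()   # align with the text prompt
    l2_term = l2_lambda * ((latent_init - latent) ** 2).mean()             # stay near the input molecule
    return clip_term + l2_term

Only latent receives gradients: the demo builds its Adam optimizer over [latent] alone and keeps every encoder and projection module in eval mode.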
"cell_type": "markdown", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM pretraining\n", + "\n", + "All the scripts can be found in `MoleculeSTM/pretrain.py`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 Load and Customize Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arguments\t Namespace(CL_neg_samples=1, JK='last', SSL_emb_dim=256, SSL_loss='EBM_NCE', T=0.1, batch_size=4, dataset='PubChemSTM1K', dataspace_path='../data', decay=0, device=0, dropout_ratio=0.5, epochs=100, gnn_emb_dim=300, gnn_type='gin', graph_pooling='mean', max_seq_len=512, megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints', mol_lr=0.0001, mol_lr_scale=0.1, molecule_type='Graph', normalize=True, num_layer=5, num_workers=8, output_model_dir=None, pretrain_gnn_mode='GraphMVP_G', seed=42, text_lr=0.0001, text_lr_scale=0.1, text_type='SciBERT', verbose=1)\n" + ] + } + ], + "source": [ + "# Set-up the environment variable to ignore warnings\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "import os\n", + "os.environ['TOKENIZERS_PARALLELISM'] = 'False'\n", + "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n", + "\n", + "import argparse\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--dataset\", type=str, default=\"PubChemSTM1K\")\n", + "parser.add_argument(\"--text_type\", type=str, default=\"SciBERT\", choices=[\"SciBERT\"])\n", + "parser.add_argument(\"--molecule_type\", type=str, default=\"Graph\", choices=[\"SMILES\", \"Graph\"])\n", + "\n", + "parser.add_argument(\"--batch_size\", type=int, default=4)\n", + "parser.add_argument(\"--text_lr\", type=float, default=1e-4)\n", + "parser.add_argument(\"--mol_lr\", type=float, default=1e-4)\n", + "parser.add_argument(\"--text_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--mol_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--num_workers\", type=int, default=8)\n", + "parser.add_argument(\"--epochs\", type=int, default=100)\n", + "parser.add_argument(\"--decay\", type=float, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=1)\n", + "parser.add_argument(\"--output_model_dir\", type=str, default=None)\n", + "\n", + "########## for SciBERT ##########\n", + "parser.add_argument(\"--max_seq_len\", type=int, default=512)\n", + "\n", + "########## for MegaMolBART ##########\n", + "parser.add_argument(\"--megamolbart_input_dir\", type=str, default=\"../data/pretrained_MegaMolBART/checkpoints\")\n", + "\n", + "########## for 2D GNN ##########\n", + "parser.add_argument(\"--pretrain_gnn_mode\", type=str, default=\"GraphMVP_G\", choices=[\"GraphMVP_G\"])\n", + "parser.add_argument(\"--gnn_emb_dim\", type=int, default=300)\n", + "parser.add_argument(\"--num_layer\", type=int, default=5)\n", + "parser.add_argument('--JK', type=str, default='last')\n", + "parser.add_argument(\"--dropout_ratio\", type=float, default=0.5)\n", + "parser.add_argument(\"--gnn_type\", type=str, default=\"gin\")\n", + "parser.add_argument('--graph_pooling', type=str, default='mean')\n", + "\n", + "########## for contrastive SSL ##########\n", + "parser.add_argument(\"--SSL_loss\", type=str, default=\"EBM_NCE\", 
choices=[\"EBM_NCE\", \"InfoNCE\"])\n", + "parser.add_argument(\"--SSL_emb_dim\", type=int, default=256)\n", + "parser.add_argument(\"--CL_neg_samples\", type=int, default=1)\n", + "parser.add_argument(\"--T\", type=float, default=0.1)\n", + "parser.add_argument('--normalize', dest='normalize', action='store_true')\n", + "parser.add_argument('--no_normalize', dest='normalize', action='store_false')\n", + "parser.set_defaults(normalize=True)\n", + "\n", + "args = parser.parse_args(\"\")\n", + "print(\"arguments\\t\", args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader as torch_DataLoader\n", + "\n", + "from torch_geometric.loader import DataLoader as pyg_DataLoader\n", + "from transformers import AutoModel, AutoTokenizer\n", + "\n", + "from MoleculeSTM.datasets import (\n", + " PubChemSTM_Datasets_SMILES, PubChemSTM_SubDatasets_SMILES,\n", + " PubChemSTM_Datasets_Graph, PubChemSTM_SubDatasets_Graph,\n", + " PubChemSTM_Datasets_Raw_SMILES, PubChemSTM_SubDatasets_Raw_SMILES,\n", + " PubChemSTM_Datasets_Raw_Graph, PubChemSTM_SubDatasets_Raw_Graph\n", + ")\n", + "from MoleculeSTM.models import GNN, GNN_graphpred\n", + "from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3 Supporting Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def cycle_index(num, shift):\n", + " arr = torch.arange(num) + shift\n", + " arr[-shift:] = torch.arange(shift)\n", + " return arr\n", + "\n", + "\n", + "def do_CL(X, Y, args):\n", + " if args.normalize:\n", + " X = F.normalize(X, dim=-1)\n", + " Y = F.normalize(Y, dim=-1)\n", + "\n", + " if args.SSL_loss == 'EBM_NCE':\n", + " criterion = nn.BCEWithLogitsLoss()\n", + " neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0)\n", + " neg_X = X.repeat((args.CL_neg_samples, 1))\n", + "\n", + " pred_pos = torch.sum(X * Y, dim=1) / args.T\n", + " pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T\n", + "\n", + " loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device))\n", + " loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device))\n", + " CL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples)\n", + "\n", + " CL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \\\n", + " (len(pred_pos) + len(pred_neg))\n", + " CL_acc = CL_acc.detach().cpu().item()\n", + "\n", + " elif args.SSL_loss == 'InfoNCE':\n", + " criterion = nn.CrossEntropyLoss()\n", + " B = X.size()[0]\n", + " logits = torch.mm(X, Y.transpose(1, 0)) # B*B\n", + " logits = torch.div(logits, args.T)\n", + " labels = torch.arange(B).long().to(logits.device) # B*1\n", + "\n", + " CL_loss = criterion(logits, labels)\n", + " pred = logits.argmax(dim=1, keepdim=False)\n", + " CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B\n", + "\n", + " else:\n", + " raise Exception\n", + "\n", + " return CL_loss, CL_acc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4 Training Function" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def train(\n", + " epoch,\n", + " dataloader,\n", + " text_model, text_tokenizer,\n", + " molecule_model, MegaMolBART_wrapper=None):\n", + "\n", + " text_model.train()\n", + " molecule_model.train()\n", + " text2latent.train()\n", + " mol2latent.train()\n", + "\n", + " if args.verbose:\n", + " L = tqdm(dataloader)\n", + " else:\n", + " L = dataloader\n", + " \n", + " start_time = time.time()\n", + " accum_loss, accum_acc = 0, 0\n", + " for step, batch in enumerate(L):\n", + " description = batch[0]\n", + " molecule_data = batch[1]\n", + "\n", + " description_tokens_ids, description_masks = prepare_text_tokens(\n", + " device=device, description=description, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)\n", + " description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks)\n", + " description_repr = description_output[\"pooler_output\"]\n", + " description_repr = text2latent(description_repr)\n", + "\n", + " molecule_data = molecule_data.to(device)\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " molecule_data, mol2latent=mol2latent,\n", + " molecule_type=molecule_type, molecule_model=molecule_model)\n", + "\n", + " loss_01, acc_01 = do_CL(description_repr, molecule_repr, args)\n", + " loss_02, acc_02 = do_CL(molecule_repr, description_repr, args)\n", + " loss = (loss_01 + loss_02) / 2\n", + " acc = (acc_01 + acc_02) / 2\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " accum_loss += loss.item()\n", + " accum_acc += acc\n", + " \n", + " accum_loss /= len(L)\n", + " accum_acc /= len(L)\n", + " \n", + " global optimal_loss\n", + " temp_loss = accum_loss\n", + " if temp_loss < optimal_loss:\n", + " optimal_loss = temp_loss\n", + " print(\"CL Loss: {:.5f}\\tCL Acc: {:.5f}\\tTime: {:.5f}\".format(accum_loss, accum_acc, time.time() - start_time))\n", + " return" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5 Start Pretraining" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.1 Set seed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.2 Prepare Text Model " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Download SciBert to ../data/pretrained_SciBERT\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 
'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "kwargs = {}\n", + "\n", + "if args.text_type == \"SciBERT\":\n", + " pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')\n", + " print(\"Download SciBert to {}\".format(pretrained_SciBERT_folder))\n", + " text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)\n", + " text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)\n", + " kwargs[\"text_tokenizer\"] = text_tokenizer\n", + " kwargs[\"text_model\"] = text_model\n", + " text_dim = 768\n", + "else:\n", + " raise Exception" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.3 Start training MoleculeSTM-Graph\n", + "\n", + "#### 5.3.1 Prepare GraphMVP (Graph Model) and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loading from ../data/pretrained_GraphMVP/GraphMVP_G/model.pth ...\n" + ] + } + ], + "source": [ + "dataset_root = os.path.join(args.dataspace_path, \"PubChemSTM_data\")\n", + " \n", + "molecule_type = \"Graph\"\n", + "\n", + "# You need to first run the following for data preprocessing if you haven't done so.\n", + "# PubChemSTM_Datasets_Graph(dataset_root)\n", + "dataset = PubChemSTM_SubDatasets_Graph(dataset_root, size=1000)\n", + "\n", + "dataloader_class = pyg_DataLoader\n", + "\n", + "molecule_node_model = GNN(\n", + " num_layer=args.num_layer, emb_dim=args.gnn_emb_dim,\n", + " JK=args.JK, drop_ratio=args.dropout_ratio,\n", + " gnn_type=args.gnn_type)\n", + "molecule_model = GNN_graphpred(\n", + " num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling,\n", + " num_tasks=1, molecule_node_model=molecule_node_model)\n", + "pretrained_model_path = os.path.join(args.dataspace_path, \"pretrained_GraphMVP\", args.pretrain_gnn_mode, \"model.pth\")\n", + "molecule_model.from_pretrained(pretrained_model_path)\n", + "\n", + "molecule_model = molecule_model.to(device)\n", + "\n", + "kwargs[\"molecule_model\"] = molecule_model\n", + "molecule_dim = args.gnn_emb_dim\n", + "\n", + "dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 5.3.2 Prepare Two Projection Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "text2latent = nn.Linear(text_dim, args.SSL_emb_dim).to(device)\n", + "mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim).to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 5.3.3 Prepare Optimizers" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + 
"model_param_group = [\n", + " {\"params\": text_model.parameters(), \"lr\": args.text_lr},\n", + " {\"params\": molecule_model.parameters(), \"lr\": args.mol_lr},\n", + " {\"params\": text2latent.parameters(), \"lr\": args.text_lr * args.text_lr_scale},\n", + " {\"params\": mol2latent.parameters(), \"lr\": args.mol_lr * args.mol_lr_scale},\n", + "]\n", + "optimizer = optim.Adam(model_param_group, weight_decay=args.decay)\n", + "optimal_loss = 1e10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 5.3.4 Start Training" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:57<00:00, 4.35it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CL Loss: 0.71635\tCL Acc: 0.50225\tTime: 57.53959\n", + "Epoch 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:56<00:00, 4.44it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CL Loss: 0.70258\tCL Acc: 0.49950\tTime: 56.35668\n", + "Epoch 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:56<00:00, 4.39it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CL Loss: 0.69960\tCL Acc: 0.49900\tTime: 56.90493\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "for e in range(3):\n", + " print(\"Epoch {}\".format(e))\n", + " train(e, dataloader, **kwargs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/demo_pretrain_SMILES.ipynb b/demos/demo_pretrain_SMILES.ipynb new file mode 100644 index 0000000..c4247c4 --- /dev/null +++ b/demos/demo_pretrain_SMILES.ipynb @@ -0,0 +1,657 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Demo for MoleculeSTM pretraining\n", + "\n", + "All the scripts can be found in `MoleculeSTM/pretrain.py`." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1 Load and Customize Arguments" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "arguments\t Namespace(CL_neg_samples=1, SSL_emb_dim=256, SSL_loss='EBM_NCE', T=0.1, batch_size=4, dataset='PubChemSTM1K', dataspace_path='../data', decay=0, device=0, epochs=100, max_seq_len=512, megamolbart_input_dir='../data/pretrained_MegaMolBART/checkpoints', mol_lr=0.0001, mol_lr_scale=0.1, molecule_type='SMILES', normalize=True, num_workers=8, output_model_dir=None, seed=42, text_lr=0.0001, text_lr_scale=0.1, text_type='SciBERT', verbose=1)\n" + ] + } + ], + "source": [ + "# Set-up the environment variable to ignore warnings\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "import os\n", + "os.environ['TOKENIZERS_PARALLELISM'] = 'False'\n", + "\n", + "import argparse\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "parser.add_argument(\"--seed\", type=int, default=42)\n", + "parser.add_argument(\"--device\", type=int, default=0)\n", + "\n", + "parser.add_argument(\"--dataspace_path\", type=str, default=\"../data\")\n", + "parser.add_argument(\"--dataset\", type=str, default=\"PubChemSTM1K\")\n", + "parser.add_argument(\"--text_type\", type=str, default=\"SciBERT\", choices=[\"SciBERT\"])\n", + "parser.add_argument(\"--molecule_type\", type=str, default=\"SMILES\", choices=[\"SMILES\", \"Graph\"])\n", + "\n", + "parser.add_argument(\"--batch_size\", type=int, default=4)\n", + "parser.add_argument(\"--text_lr\", type=float, default=1e-4)\n", + "parser.add_argument(\"--mol_lr\", type=float, default=1e-4)\n", + "parser.add_argument(\"--text_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--mol_lr_scale\", type=float, default=0.1)\n", + "parser.add_argument(\"--num_workers\", type=int, default=8)\n", + "parser.add_argument(\"--epochs\", type=int, default=100)\n", + "parser.add_argument(\"--decay\", type=float, default=0)\n", + "parser.add_argument(\"--verbose\", type=int, default=1)\n", + "parser.add_argument(\"--output_model_dir\", type=str, default=None)\n", + "\n", + "########## for SciBERT ##########\n", + "parser.add_argument(\"--max_seq_len\", type=int, default=512)\n", + "\n", + "########## for MegaMolBART ##########\n", + "parser.add_argument(\"--megamolbart_input_dir\", type=str, default=\"../data/pretrained_MegaMolBART/checkpoints\")\n", + "\n", + "########## for contrastive SSL ##########\n", + "parser.add_argument(\"--SSL_loss\", type=str, default=\"EBM_NCE\", choices=[\"EBM_NCE\", \"InfoNCE\"])\n", + "parser.add_argument(\"--SSL_emb_dim\", type=int, default=256)\n", + "parser.add_argument(\"--CL_neg_samples\", type=int, default=1)\n", + "parser.add_argument(\"--T\", type=float, default=0.1)\n", + "parser.add_argument('--normalize', dest='normalize', action='store_true')\n", + "parser.add_argument('--no_normalize', dest='normalize', action='store_false')\n", + "parser.set_defaults(normalize=True)\n", + "\n", + "args = parser.parse_args(\"\")\n", + "print(\"arguments\\t\", args)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2 Load Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-08-30 12:36:54,712] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + 
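One pattern worth pulling out before the supporting functions below: in every script here (pretraining, retrieval, and editing), the text branch is SciBERT's pooled output followed by the text2latent linear projection into the shared 256-dimensional space. A minimal standalone version of that pattern, with the repository's prepare_text_tokens helper replaced by a plain Hugging Face tokenizer call:

import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
text_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
text2latent = nn.Linear(768, 256)   # SciBERT hidden size -> SSL_emb_dim

def get_text_repr(sentences, max_seq_len=512):
    tokens = tokenizer(sentences, truncation=True, max_length=max_seq_len,
                       padding=True, return_tensors="pt")
    output = text_model(input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"])
    return text2latent(output["pooler_output"])   # [B, 256]

During pretraining this projection is trained jointly with the encoders, but with the smaller learning rate text_lr * text_lr_scale assigned to it in the optimizer parameter groups.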
], + "source": [ + "import os\n", + "import time\n", + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader as torch_DataLoader\n", + "\n", + "from torch_geometric.loader import DataLoader as pyg_DataLoader\n", + "from transformers import AutoModel, AutoTokenizer\n", + "\n", + "from MoleculeSTM.datasets import (\n", + " PubChemSTM_Datasets_SMILES, PubChemSTM_SubDatasets_SMILES,\n", + " PubChemSTM_Datasets_Graph, PubChemSTM_SubDatasets_Graph,\n", + " PubChemSTM_Datasets_Raw_SMILES, PubChemSTM_SubDatasets_Raw_SMILES,\n", + " PubChemSTM_Datasets_Raw_Graph, PubChemSTM_SubDatasets_Raw_Graph\n", + ")\n", + "from MoleculeSTM.models import GNN, GNN_graphpred\n", + "from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network\n", + "from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3 Supporting Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def cycle_index(num, shift):\n", + " arr = torch.arange(num) + shift\n", + " arr[-shift:] = torch.arange(shift)\n", + " return arr\n", + "\n", + "\n", + "def do_CL(X, Y, args):\n", + " if args.normalize:\n", + " X = F.normalize(X, dim=-1)\n", + " Y = F.normalize(Y, dim=-1)\n", + "\n", + " if args.SSL_loss == 'EBM_NCE':\n", + " criterion = nn.BCEWithLogitsLoss()\n", + " neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0)\n", + " neg_X = X.repeat((args.CL_neg_samples, 1))\n", + "\n", + " pred_pos = torch.sum(X * Y, dim=1) / args.T\n", + " pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T\n", + "\n", + " loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device))\n", + " loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device))\n", + " CL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples)\n", + "\n", + " CL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \\\n", + " (len(pred_pos) + len(pred_neg))\n", + " CL_acc = CL_acc.detach().cpu().item()\n", + "\n", + " elif args.SSL_loss == 'InfoNCE':\n", + " criterion = nn.CrossEntropyLoss()\n", + " B = X.size()[0]\n", + " logits = torch.mm(X, Y.transpose(1, 0)) # B*B\n", + " logits = torch.div(logits, args.T)\n", + " labels = torch.arange(B).long().to(logits.device) # B*1\n", + "\n", + " CL_loss = criterion(logits, labels)\n", + " pred = logits.argmax(dim=1, keepdim=False)\n", + " CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B\n", + "\n", + " else:\n", + " raise Exception\n", + "\n", + " return CL_loss, CL_acc" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4 Training Function" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "def train(\n", + " epoch,\n", + " dataloader,\n", + " text_model, text_tokenizer,\n", + " molecule_model, MegaMolBART_wrapper=None):\n", + "\n", + " text_model.train()\n", + " molecule_model.train()\n", + " text2latent.train()\n", + " mol2latent.train()\n", + "\n", + " if args.verbose:\n", + " L = tqdm(dataloader)\n", + " else:\n", + " L = dataloader\n", + " \n", + " start_time = time.time()\n", + " accum_loss, accum_acc = 0, 0\n", + " for step, batch in enumerate(L):\n", + " description = batch[0]\n", + " molecule_data = batch[1]\n", + "\n", + " description_tokens_ids, description_masks = prepare_text_tokens(\n", + " device=device, description=description, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len)\n", + " description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks)\n", + " description_repr = description_output[\"pooler_output\"]\n", + " description_repr = text2latent(description_repr)\n", + "\n", + " molecule_data = list(molecule_data) # for SMILES_list\n", + " molecule_repr = get_molecule_repr_MoleculeSTM(\n", + " molecule_data, mol2latent=mol2latent,\n", + " molecule_type=molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper)\n", + "\n", + " loss_01, acc_01 = do_CL(description_repr, molecule_repr, args)\n", + " loss_02, acc_02 = do_CL(molecule_repr, description_repr, args)\n", + " loss = (loss_01 + loss_02) / 2\n", + " acc = (acc_01 + acc_02) / 2\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " accum_loss += loss.item()\n", + " accum_acc += acc\n", + " \n", + " accum_loss /= len(L)\n", + " accum_acc /= len(L)\n", + " \n", + " global optimal_loss\n", + " temp_loss = accum_loss\n", + " if temp_loss < optimal_loss:\n", + " optimal_loss = temp_loss \n", + " print(\"CL Loss: {:.5f}\\tCL Acc: {:.5f}\\tTime: {:.5f}\".format(accum_loss, accum_acc, time.time() - start_time))\n", + " return" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5 Start Pretraining" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.1 Set seed" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(args.seed)\n", + "np.random.seed(args.seed)\n", + "device = torch.device(\"cuda:\" + str(args.device)) \\\n", + " if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(args.seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.2 Prepare Text Model " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Download SciBert to ../data/pretrained_SciBERT\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 
'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']\n", + "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "kwargs = {}\n", + "\n", + "if args.text_type == \"SciBERT\":\n", + " pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT')\n", + " print(\"Download SciBert to {}\".format(pretrained_SciBERT_folder))\n", + " text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder)\n", + " text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device)\n", + " kwargs[\"text_tokenizer\"] = text_tokenizer\n", + " kwargs[\"text_model\"] = text_model\n", + " text_dim = 768\n", + "else:\n", + " raise Exception" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5.3 Start training MoleculeSTM-SMILES\n", + "\n", + "#### 5.3.1 Prepare MegaMolBART (SMILES Model) and Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "len of CID2text: 250962\n", + "len of CID2SMILES: 250950\n", + "len of text_list: 1000\n", + "using world size: 1 and model-parallel size: 1 \n", + "using torch.float32 for parameters ...\n", + "-------------------- arguments --------------------\n", + " adam_beta1 ...................... 0.9\n", + " adam_beta2 ...................... 0.999\n", + " adam_eps ........................ 1e-08\n", + " adlr_autoresume ................. False\n", + " adlr_autoresume_interval ........ 1000\n", + " apply_query_key_layer_scaling ... False\n", + " apply_residual_connection_post_layernorm False\n", + " attention_dropout ............... 0.1\n", + " attention_softmax_in_fp32 ....... False\n", + " batch_size ...................... None\n", + " bert_load ....................... None\n", + " bias_dropout_fusion ............. False\n", + " bias_gelu_fusion ................ False\n", + " block_data_path ................. None\n", + " checkpoint_activations .......... False\n", + " checkpoint_in_cpu ............... False\n", + " checkpoint_num_layers ........... 1\n", + " clip_grad ....................... 1.0\n", + " contigious_checkpointing ........ False\n", + " cpu_optimizer ................... False\n", + " cpu_torch_adam .................. False\n", + " data_impl ....................... infer\n", + " data_path ....................... None\n", + " dataset_path .................... None\n", + " DDP_impl ........................ local\n", + " deepscale ....................... False\n", + " deepscale_config ................ None\n", + " deepspeed ....................... False\n", + " deepspeed_activation_checkpointing False\n", + " deepspeed_config ................ None\n", + " deepspeed_mpi ................... False\n", + " distribute_checkpointed_activations False\n", + " distributed_backend ............. nccl\n", + " dynamic_loss_scale .............. True\n", + " eod_mask_loss ................... 
False\n", + " eval_interval ................... 1000\n", + " eval_iters ...................... 100\n", + " exit_interval ................... None\n", + " faiss_use_gpu ................... False\n", + " finetune ........................ False\n", + " fp16 ............................ False\n", + " fp16_lm_cross_entropy ........... False\n", + " fp32_allreduce .................. False\n", + " gas ............................. 1\n", + " hidden_dropout .................. 0.1\n", + " hidden_size ..................... 256\n", + " hysteresis ...................... 2\n", + " ict_head_size ................... None\n", + " ict_load ........................ None\n", + " indexer_batch_size .............. 128\n", + " indexer_log_interval ............ 1000\n", + " init_method_std ................. 0.02\n", + " layernorm_epsilon ............... 1e-05\n", + " lazy_mpu_init ................... None\n", + " load ............................ ../data/pretrained_MegaMolBART/checkpoints\n", + " local_rank ...................... None\n", + " log_interval .................... 100\n", + " loss_scale ...................... None\n", + " loss_scale_window ............... 1000\n", + " lr .............................. None\n", + " lr_decay_iters .................. None\n", + " lr_decay_style .................. linear\n", + " make_vocab_size_divisible_by .... 128\n", + " mask_prob ....................... 0.15\n", + " max_position_embeddings ......... 512\n", + " merge_file ...................... None\n", + " min_lr .......................... 0.0\n", + " min_scale ....................... 1\n", + " mmap_warmup ..................... False\n", + " model_parallel_size ............. 1\n", + " no_load_optim ................... False\n", + " no_load_rng ..................... False\n", + " no_save_optim ................... False\n", + " no_save_rng ..................... False\n", + " num_attention_heads ............. 8\n", + " num_layers ...................... 4\n", + " num_unique_layers ............... None\n", + " num_workers ..................... 2\n", + " onnx_safe ....................... None\n", + " openai_gelu ..................... False\n", + " override_lr_scheduler ........... False\n", + " param_sharing_style ............. grouped\n", + " params_dtype .................... torch.float32\n", + " partition_activations ........... False\n", + " pipe_parallel_size .............. 0\n", + " profile_backward ................ False\n", + " query_in_block_prob ............. 0.1\n", + " rank ............................ 0\n", + " report_topk_accuracies .......... []\n", + " reset_attention_mask ............ False\n", + " reset_position_ids .............. False\n", + " save ............................ None\n", + " save_interval ................... None\n", + " scaled_masked_softmax_fusion .... False\n", + " scaled_upper_triang_masked_softmax_fusion False\n", + " seed ............................ 1234\n", + " seq_length ...................... None\n", + " short_seq_prob .................. 0.1\n", + " split ........................... 969, 30, 1\n", + " synchronize_each_layer .......... False\n", + " tensorboard_dir ................. None\n", + " titles_data_path ................ None\n", + " tokenizer_type .................. GPT2BPETokenizer\n", + " train_iters ..................... None\n", + " use_checkpoint_lr_scheduler ..... False\n", + " use_cpu_initialization .......... False\n", + " use_one_sent_docs ............... False\n", + " vocab_file ...................... 
../MoleculeSTM/bart_vocab.txt\n", + " warmup .......................... 0.01\n", + " weight_decay .................... 0.01\n", + " world_size ...................... 1\n", + " zero_allgather_bucket_size ...... 0.0\n", + " zero_contigious_gradients ....... False\n", + " zero_reduce_bucket_size ......... 0.0\n", + " zero_reduce_scatter ............. False\n", + " zero_stage ...................... 1.0\n", + "---------------- end of arguments ----------------\n", + "> initializing torch distributed ...\n", + "> initializing model parallel with size 1\n", + "> setting random seeds to 1234 ...\n", + "> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 3952 and data parallel seed: 1234\n", + "Loading vocab from ../MoleculeSTM/bart_vocab.txt.\n", + "Loading from ../data/pretrained_MegaMolBART/checkpoints\n", + "global rank 0 is loading checkpoint ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n", + "could not find arguments in the checkpoint ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[W ProcessGroupNCCL.cpp:1569] Rank 0 using best-guess GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect.Specify device_ids in barrier() to force use of a particular device.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " successfully loaded ../data/pretrained_MegaMolBART/checkpoints/iter_0134000/mp_rank_00/model_optim_rng.pt\n" + ] + } + ], + "source": [ + "dataset_root = os.path.join(args.dataspace_path, \"PubChemSTM_data\")\n", + " \n", + "molecule_type = \"SMILES\"\n", + "\n", + "dataset = PubChemSTM_SubDatasets_SMILES(dataset_root, size=1000)\n", + "dataloader_class = torch_DataLoader\n", + "\n", + "if args.output_model_dir is not None:\n", + " MegaMolBART_dir = os.path.join(args.output_model_dir, \"SMILES\")\n", + "else:\n", + " MegaMolBART_dir = None\n", + "MegaMolBART_wrapper = MegaMolBART(\n", + " vocab_path=\"../MoleculeSTM/bart_vocab.txt\",\n", + " input_dir=args.megamolbart_input_dir,\n", + " output_dir=MegaMolBART_dir)\n", + "molecule_model = MegaMolBART_wrapper.model\n", + "kwargs[\"MegaMolBART_wrapper\"] = MegaMolBART_wrapper\n", + "kwargs[\"molecule_model\"] = molecule_model\n", + "molecule_dim = 256\n", + "\n", + "dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 5.3.2 Prepare Two Projection Layers" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "text2latent = nn.Linear(text_dim, args.SSL_emb_dim).to(device)\n", + "mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim).to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 5.3.3 Prepare Optimizers" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "model_param_group = [\n", + " {\"params\": text_model.parameters(), \"lr\": args.text_lr},\n", + " {\"params\": molecule_model.parameters(), \"lr\": args.mol_lr},\n", + " {\"params\": text2latent.parameters(), \"lr\": args.text_lr * args.text_lr_scale},\n", + " {\"params\": mol2latent.parameters(), \"lr\": args.mol_lr * args.mol_lr_scale},\n", + "]\n", + "optimizer = optim.Adam(model_param_group, 
weight_decay=args.decay)\n", + "optimal_loss = 1e10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 5.3.4 Start Training" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:49<00:00, 5.02it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CL Loss: 0.69800\tCL Acc: 0.50400\tTime: 49.79203\n", + "Epoch 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:48<00:00, 5.14it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CL Loss: 0.69504\tCL Acc: 0.50450\tTime: 48.61034\n", + "Epoch 2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:48<00:00, 5.13it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CL Loss: 0.69426\tCL Acc: 0.50175\tTime: 48.72926\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "for e in range(3):\n", + " print(\"Epoch {}\".format(e))\n", + " train(e, dataloader, **kwargs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/download.py b/demos/download.py new file mode 100644 index 0000000..b8c7193 --- /dev/null +++ b/demos/download.py @@ -0,0 +1,4 @@ +from huggingface_hub import HfApi, snapshot_download +api = HfApi() + +snapshot_download(repo_id="chao1224/MoleculeSTM", repo_type="model", local_dir='.', allow_patterns="*demo*") diff --git a/preprocessing/PubChemSTM/PubChem_utils.py b/preprocessing/PubChemSTM/PubChem_utils.py new file mode 100644 index 0000000..f7b9974 --- /dev/null +++ b/preprocessing/PubChemSTM/PubChem_utils.py @@ -0,0 +1,13 @@ +import os +from six.moves.urllib.request import urlretrieve + + +def download_and_extract_compound_file(PubChem_datasets_home_folder, compound_file_name): + compound_url = "https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/CURRENT-Full/SDF/{}".format(compound_file_name) + + zipped_compound_file_path = "{}/{}".format(PubChem_datasets_home_folder, compound_file_name) + + if not os.path.exists(zipped_compound_file_path): + print("Downloading {} to {} ...".format(compound_url, zipped_compound_file_path)) + urlretrieve(compound_url, zipped_compound_file_path) + return diff --git a/preprocessing/PubChemSTM/step_01_description_extraction.py b/preprocessing/PubChemSTM/step_01_description_extraction.py new file mode 100644 index 0000000..7f1d4d7 --- /dev/null +++ b/preprocessing/PubChemSTM/step_01_description_extraction.py @@ -0,0 +1,226 @@ +import requests +from tqdm import tqdm +from collections import defaultdict +import json + + +def 
clean_up_description(description): + description = description + " " + + ##### extra adj Pure ##### + if description.startswith("Pure "): + description = description.replace("Pure ", "") + ##### fix typo ##### + if description.startswith("Mercurycombines"): + description = description.replace("Mercurycombines", "Mercury combines") + + name_special_case_list = [ + '17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ', + '5-Thymidylic acid. ', + "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ", + "Guanosine 5'-(trihydrogen diphosphate), monoanhydride with phosphorothioic acid. ", + "5'-Uridylic acid. ", + "5'-Adenylic acid, ", + "Uridine 5'-(tetrahydrogen triphosphate). ", + "Inosine 5'-Monophosphate. ", + "Pivaloyloxymethyl butyrate (AN-9), ", + "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ", + "Cardamonin (also known as Dihydroxymethoxychalcone), ", + ] + + ##### a special case ##### + description = description.replace("17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ", "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is ") + + ##### a special case ##### + description = description.replace("5-Thymidylic acid. ", "5-Thymidylic acid. is ") + + ##### a special case ##### + description = description.replace("5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ", "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is ") + + ##### a special case ##### + description = description.replace("Guanosine 5'-(trihydrogen diphosphate), monoanhydride with phosphorothioic acid. ", "Guanosine 5'-(trihydrogen diphosphate), monoanhydride with phosphorothioic acid is ") + + ##### a special case ##### + description = description.replace("5'-Uridylic acid. ", "5'-Uridylic acid is ") + + ##### a special case ##### + description = description.replace("5'-Adenylic acid, ", "5'-Adenylic acid is ") + + ##### a special case ##### + description = description.replace("Uridine 5'-(tetrahydrogen triphosphate). ", "Uridine 5'-(tetrahydrogen triphosphate). is ") + + ##### a special case ##### + description = description.replace("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is ") + + ##### a special case ##### + description = description.replace("Pivaloyloxymethyl butyrate (AN-9), ", "Pivaloyloxymethyl butyrate (AN-9) is ") + + ##### a special case ##### + description = description.replace("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ", "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is ") + + ##### a special case ##### + description = description.replace("Cardamonin (also known as Dihydroxymethoxychalcone), ", "Cardamonin (also known as Dihydroxymethoxychalcone) is ") + + ##### a special case ##### + description = description.replace("Lithium has been used to treat ", "Lithium is ") + + ##### a special case ##### + description = description.replace("4,4'-Methylenebis ", "4,4'-Methylenebis is ") + + ##### a special case ##### + description = description.replace("2,3,7,8-Tetrachlorodibenzo-p-dioxin", "2,3,7,8-Tetrachlorodibenzo-p-dioxin is ") + + ##### a special case ##### + description = description.replace("Exposure to 2,4,5-trichlorophenol ", "2,4,5-Trichlorophenol exposure ") + + index = 0 + L = len(description) + if description.startswith('C.I. '): + start_index = len('C.I. ') + elif description.startswith('Nectriapyrone. D '): + start_index = len('Nectriapyrone. D ') + elif description.startswith('Salmonella enterica sv. Minnesota LPS core oligosaccharide'): + start_index = len('Salmonella enterica sv. 
Minnesota LPS core oligosaccharide') + else: + start_index = 0 + for index in range(start_index, L - 1): + if index < L-2: + if description[index] == '.' and description[index+1] == ' ' and 'A' <= description[index+2] <= 'Z': + break + elif index == L - 2: + break + + first_sentence = description[:index+1] + return first_sentence + + +def extract_name(name_raw, description): + first_sentence = clean_up_description(description) + + splitter = ' -- -- ' + if ' are ' in first_sentence or ' were ' in first_sentence: + replaced_words = 'These molecules' + else: + replaced_words = 'This molecule' + + first_sentence = first_sentence.replace(' is ', splitter) + first_sentence = first_sentence.replace(' are ', splitter) + first_sentence = first_sentence.replace(' was ', splitter) + first_sentence = first_sentence.replace(' were ', splitter) + first_sentence = first_sentence.replace(' appears ', splitter) + first_sentence = first_sentence.replace(' occurs ', splitter) + first_sentence = first_sentence.replace(' stands for ', splitter) + first_sentence = first_sentence.replace(' belongs to ', splitter) + first_sentence = first_sentence.replace(' exists ', splitter) # only for CID=11443 + first_sentence = first_sentence.replace(' has been used in trials ', splitter) + first_sentence = first_sentence.replace(' has been investigated ', splitter) + first_sentence = first_sentence.replace(' has many uses ', splitter) + + if splitter in first_sentence: + extracted_name = first_sentence.split(splitter, 1)[0] + elif first_sentence.startswith(name_raw): + extracted_name = name_raw + elif name_raw in first_sentence: + extracted_name = name_raw + extracted_name = None + print("=====", name_raw) + print("first sentence: ", first_sentence) + # print() + else: + extracted_name = None + + if extracted_name is not None: + extracted_description = description.replace(extracted_name, replaced_words) + else: + extracted_description = description + + return extracted_name, extracted_description, first_sentence + + +if __name__ == "__main__": + total_page_num = 290 + # Please put your own dataset path here + datasets_home_folder = "../../../Datasets" + + PubChemSTM_datasets_description_home_folder = "{}/step_01_PubChemSTM_description".format(datasets_home_folder) + valid_CID_list = set() + CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict(list) + CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict(list) + + for page_index in tqdm(range(total_page_num)): + page_num = page_index + 1 + compound_description_file_name = "Compound_description_{}.txt".format(page_num) + f_out = open("{}/{}".format(PubChemSTM_datasets_description_home_folder, compound_description_file_name), "w") + + description_url = "https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/heading/json?heading_type=Compound&heading=Record+Description&page={}".format(page_num) + description_data = requests.get(description_url).json() + + description_data = description_data["Annotations"] + assert description_data["Page"] == page_num + assert description_data["TotalPages"] == total_page_num + + record_list = description_data["Annotation"] + + for record in record_list: + try: + CID = record["LinkedRecords"]["CID"][0] + if "Name" in record: + name_raw = record["Name"] + CID2name_raw[CID].append(name_raw) + else: + name_raw = None + + data_list = record["Data"] + for data in data_list: + description = data["Value"]["StringWithMarkup"][0]["String"].strip() + + extracted_name, extracted_description, first_sentence = 
extract_name(name_raw, description) + if extracted_name is not None: + CID2name_extracted[CID].append(extracted_name) + + CID_special_case_list = [45266824, 11683, 3759, 9700, 439155, 135398675, 135563708, 6030, 10238, 6133, 135398640, 77918, 60748, 11824, 641785, 11125, 7543, 15625, 7271] + + ##### only for debugging ##### + if CID in CID_special_case_list: + print("page: {}\tCID: {}".format(page_index, CID)) + if "Name" in record: + print('yes-name') + name = record["Name"] + print('name:', name) + else: + print('no-name') + print('extracted name:', extracted_name) + print("first_sentence:", first_sentence) + print("extracted_description:", extracted_description) + print("description:", description) + print() + + CID2text_raw[CID].append(description) + CID2text_extracted[CID].append(extracted_description) + + valid_CID_list.add(CID) + f_out.write("{}\n".format(CID)) + f_out.write("{}\n\n".format(extracted_description)) + except: + # print("===\n", record) + # print("missing page: {}\tSourceName: {}\tSourceID: {}".format(page_index, record['SourceName'], record['SourceID'])) + continue + + valid_CID_list = list(set(valid_CID_list)) + valid_CID_list = sorted(valid_CID_list) + # print("valid CID list: {}".format(valid_CID_list)) + print("Total CID (with raw name) {}".format(len(CID2name_raw))) + print("Total CID (with extracted name) {}".format(len(CID2name_extracted))) + print("Total CID {}".format(len(valid_CID_list))) + + with open("{}/PubChemSTM_data/raw/CID2name_raw.json".format(datasets_home_folder), "w") as f: + json.dump(CID2name_raw, f) + + with open("{}/PubChemSTM_data/raw/CID2name.json".format(datasets_home_folder), "w") as f: + json.dump(CID2name_extracted, f) + + with open("{}/PubChemSTM_data/raw/CID2text_raw.json".format(datasets_home_folder), "w") as f: + json.dump(CID2text_raw, f) + + with open("{}/PubChemSTM_data/raw/CID2text.json".format(datasets_home_folder), "w") as f: + json.dump(CID2text_extracted, f) \ No newline at end of file diff --git a/preprocessing/PubChemSTM/step_02.sh b/preprocessing/PubChemSTM/step_02.sh new file mode 100644 index 0000000..52960eb --- /dev/null +++ b/preprocessing/PubChemSTM/step_02.sh @@ -0,0 +1,5 @@ + +for block_id in {0..325}; do + echo "$block_id" + python step_02_download_SDF.py --block_id="$block_id" +done diff --git a/preprocessing/PubChemSTM/step_02_download_SDF.py b/preprocessing/PubChemSTM/step_02_download_SDF.py new file mode 100644 index 0000000..77a787f --- /dev/null +++ b/preprocessing/PubChemSTM/step_02_download_SDF.py @@ -0,0 +1,20 @@ +import argparse +from PubChem_utils import download_and_extract_compound_file + + +parser = argparse.ArgumentParser() +parser.add_argument("--block_id", type=int, default=0) +args = parser.parse_args() + + +if __name__ == "__main__": + datasets_home_folder = "../../../Datasets" + + PubChemSTM_datasets_home_folder = "{}/step_02_PubChemSTM_SDF".format(datasets_home_folder) + block_id = args.block_id + block_size = 500000 + start_id = block_id * block_size + 1 + end_id = (block_id + 1) * block_size + + compound_file_name = "Compound_{:09d}_{:09d}.sdf.gz".format(start_id, end_id) + download_and_extract_compound_file(PubChemSTM_datasets_home_folder, compound_file_name) diff --git a/preprocessing/PubChemSTM/step_03_filter_out_SDF.py b/preprocessing/PubChemSTM/step_03_filter_out_SDF.py new file mode 100644 index 0000000..44aa79b --- /dev/null +++ b/preprocessing/PubChemSTM/step_03_filter_out_SDF.py @@ -0,0 +1,60 @@ +from tqdm import tqdm +import json +import gzip +import numpy as np + +from rdkit 
import Chem +from rdkit import RDLogger +RDLogger.DisableLog('rdApp.*') +import multiprocessing +from multiprocessing import Pool +import sys + + +if __name__ == "__main__": + datasets_home_folder = "../../../Datasets" + + PubChemSTM_datasets_description_home_folder = "{}/step_01_PubChemSTM_description".format(datasets_home_folder) + with open("{}/PubChemSTM_data/raw/CID2text.json".format(datasets_home_folder), "r") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + + PubChemSTM_datasets_input_folder = "{}/step_02_PubChemSTM_SDF".format(datasets_home_folder) + PubChemSTM_datasets_output_folder = "{}/step_03_PubChemSTM_filtered".format(datasets_home_folder) + block_size = 500000 + + def extract_one_SDF_file(block_id): + valid_mol_count = 0 + + writer = Chem.SDWriter('{}/filtered_{}.sdf'.format(PubChemSTM_datasets_output_folder, block_id)) + start_id = block_id * block_size + 1 + end_id = (block_id + 1) * block_size + + compound_file_name = "Compound_{:09d}_{:09d}.sdf.gz".format(start_id, end_id) + gzip_loader = gzip.open("{}/{}".format(PubChemSTM_datasets_input_folder, compound_file_name)) + suppl = Chem.ForwardSDMolSupplier(gzip_loader) + + for mol in tqdm(suppl): + if mol is None: + continue + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + + if cid not in target_CID_list: + continue + + writer.write(mol) + valid_mol_count += 1 + + print("block id: {}\nfound {}\n\n".format(block_id, valid_mol_count)) + sys.stdout.flush() + return + + num_process = multiprocessing.cpu_count() + print("{} CPUs".format(num_process)) + num_process = 8 + p = Pool(num_process) + + total_block_num = 325 + block_id_list = np.arange(total_block_num+1) + with p: + p.map(extract_one_SDF_file, block_id_list) diff --git a/preprocessing/PubChemSTM/step_04_merge_SDF.py b/preprocessing/PubChemSTM/step_04_merge_SDF.py new file mode 100644 index 0000000..e443898 --- /dev/null +++ b/preprocessing/PubChemSTM/step_04_merge_SDF.py @@ -0,0 +1,40 @@ +from tqdm import tqdm +import json + +from rdkit import Chem +from rdkit import RDLogger +RDLogger.DisableLog('rdApp.*') + + +if __name__ == "__main__": + datasets_home_folder = "../../../Datasets" + + PubChemSTM_datasets_description_home_folder = "{}/step_01_PubChemSTM_description".format(datasets_home_folder) + with open("{}/PubChemSTM_data/raw/CID2text.json".format(datasets_home_folder), "r") as f: + CID2text = json.load(f) + target_CID_list = set(CID2text.keys()) + print('The length of target_CID_list: {}'.format(len(target_CID_list))) + + PubChemSTM_datasets_folder = "{}/step_03_PubChemSTM_filtered".format(datasets_home_folder) + writer = Chem.SDWriter('{}/PubChemSTM_data/raw/molecules.sdf'.format(datasets_home_folder)) + + total_block_num = 325 + found_CID_set = set() + for block_id in range(total_block_num+1): + compound_file_path = "{}/filtered_{}.sdf".format(PubChemSTM_datasets_folder, block_id) + try: + suppl = Chem.SDMolSupplier(compound_file_path) + + for mol in tqdm(suppl): + writer.write(mol) + cid = mol.GetProp("PUBCHEM_COMPOUND_CID") + found_CID_set.add(cid) + except: + print("block id: {} with 0 valid SDF file".format(block_id)) + continue + + for CID in target_CID_list: + if CID not in found_CID_set: + print("CID: {} not found.".format(CID)) + + print("In total: {} molecules".format(len(found_CID_set))) \ No newline at end of file diff --git a/scripts/downstream_01_retrieval_ATC.py b/scripts/downstream_01_retrieval_ATC.py new file mode 100644 index 0000000..f9984dd --- /dev/null +++ b/scripts/downstream_01_retrieval_ATC.py @@ -0,0 +1,309 @@ 
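The script introduced by this hunk evaluates zero-shot retrieval on DrugBank ATC annotations in both directions (given text, rank molecules; given a molecule, rank texts): each query representation is scored against its true partner plus T - 1 sampled negatives, and accuracy is the fraction of queries whose true partner receives the highest similarity. A reduced sketch of the ranking accuracy that do_CL_eval computes below (the full function also returns the cross-entropy loss and per-query confidences):

import torch
import torch.nn.functional as F

def retrieval_accuracy(query_repr, pos_repr, neg_repr):
    # query_repr: [B, d]; pos_repr: [B, d] true partners; neg_repr: [T-1, B, d] negatives
    X = F.normalize(query_repr, dim=-1).unsqueeze(1)               # [B, 1, d]
    Y = torch.cat([pos_repr.unsqueeze(0), neg_repr], dim=0)        # [T, B, d]
    Y = F.normalize(Y.transpose(0, 1), dim=-1)                     # [B, T, d]
    logits = torch.bmm(X, Y.transpose(1, 2)).squeeze(1)            # [B, T] cosine scores
    pred = logits.argmax(dim=1)                                    # slot 0 holds the true partner
    return (pred == 0).float().mean().item()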
+import os +import time +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_ATC, DrugBank_Datasets_Graph_ATC +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred +from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network + + +def do_CL_eval(X, Y, neg_Y, args): + X = F.normalize(X, dim=-1) + X = X.unsqueeze(1) # B, 1, d + + Y = Y.unsqueeze(0) + Y = torch.cat([Y, neg_Y], dim=0) # T, B, d + Y = Y.transpose(0, 1) # B, T, d + Y = F.normalize(Y, dim=-1) + + logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T + B = X.size()[0] + labels = torch.zeros(B).long().to(logits.device) # B*1 + + criterion = nn.CrossEntropyLoss() + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + confidence = logits + CL_conf = confidence.max(dim=1)[0] + CL_conf = CL_conf.cpu().numpy() + + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B + return CL_loss, CL_conf, CL_acc + + +def get_text_repr(text): + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks) + text_repr = text_output["pooler_output"] + text_repr = text2latent(text_repr) + return text_repr + + +@torch.no_grad() +def eval_epoch(dataloader): + text_model.eval() + molecule_model.eval() + text2latent.eval() + mol2latent.eval() + + accum_acc_list = [0 for _ in args.T_list] + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + for batch in L: + text = batch[0] + molecule_data = batch[1] + neg_text = batch[2] + neg_molecule_data = batch[3] + + text_repr = get_text_repr(text) + if args.molecule_type == "SMILES": + SMILES_list = list(molecule_data) + molecule_repr = get_molecule_repr_MoleculeSTM( + SMILES_list, mol2latent=mol2latent, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=mol2latent, + molecule_type="Graph", molecule_model=molecule_model) + + if test_mode == "given_text": + if args.molecule_type == "SMILES": + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + list(neg_molecule_data[idx]), mol2latent=mol2latent, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + else: + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + neg_molecule_data[idx].to(device), mol2latent=mol2latent, + molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args) + accum_acc_list[T_idx] += acc + elif test_mode == "given_molecule": + neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)] + neg_text_repr = torch.stack(neg_text_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = 
do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args) + accum_acc_list[T_idx] += acc + else: + raise Exception + + accum_acc_list = np.array(accum_acc_list) + accum_acc_list /= len(dataloader) + return accum_acc_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"]) + parser.add_argument("--load_latent_projector", type=int, default=1) + parser.add_argument("--model_loading_mode", type=str, default="load_from_latest", choices=["load_from_latest", "load_mode_0", "load_mode_1"]) + parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"]) + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--task", type=str, default="molecule_ATC", + choices=["molecule_ATC", "molecule_ATC_overlap_PubChem"]) + parser.add_argument("--test_mode", type=str, default="given_text", + choices=["given_text", "given_molecule"]) + parser.add_argument("--ATC_level", type=int, default=5, choices=[1, 3, 4, 5]) + + ########## for optimization ########## + parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20]) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--text_lr", type=float, default=1e-5) + parser.add_argument("--mol_lr", type=float, default=1e-5) + parser.add_argument("--text_lr_scale", type=float, default=0.1) + parser.add_argument("--mol_lr_scale", type=float, default=0.1) + parser.add_argument("--decay", type=float, default=0) + + ########## for contrastive objective ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for molecule model ########## + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for MegaMolBART ########## + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_interval", type=int, default=5) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None) + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + 
print("arguments\t", args) + torch.multiprocessing.set_sharing_strategy('file_system') + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + ##### prepare text model ##### + ##### by default, this is load_mode_1 ##### + if args.text_type == "SciBERT": + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + # TODO: check https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1501 + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device) + text_dim = 768 + else: + raise Exception + + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + text_model.init_weights() + print("Random init for BERT.") + + ##### prepare molecule model ##### + if args.molecule_type == "SMILES": + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Random init for MegaMolBART.") + elif args.model_loading_mode == "load_mode_1": + # This is loading from the pretarined_MegaMolBART + # --input_model_dir=../data/pretrained_MegaMolBART/checkpoints + MegaMolBART_wrapper = MegaMolBART(input_dir="../data/pretrained_MegaMolBART/checkpoints", output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Loading from ../data/pretrained_MegaMolBART/checkpoint.") + molecule_dim = 256 + + else: + molecule_node_model = GNN( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, + JK=args.JK, drop_ratio=args.dropout_ratio, + gnn_type=args.gnn_type) + molecule_model = GNN_graphpred( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, + num_tasks=1, molecule_node_model=molecule_node_model) + molecule_dim = args.gnn_emb_dim + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + print("Random init for GNN.") + elif args.model_loading_mode == "load_mode_1": + print("Loading from ../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + molecule_model.from_pretrained("../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + + # Rewrite the seed by MegaMolBART + np.random.seed(args.seed) + 
torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + text2latent = nn.Linear(text_dim, args.SSL_emb_dim) + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "text2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text2latent.load_state_dict(state_dict) + + mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim) + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent.load_state_dict(state_dict) + + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + text2latent = text2latent.to(device) + mol2latent = mol2latent.to(device) + + T_max = max(args.T_list) - 1 + + # TODO: will tune more prompt_templates + prompt_template = "This molecule is for {}." + + initial_test_acc_list, optimal_test_acc_list = [], [] + test_mode = args.test_mode + dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data") + if args.molecule_type == "SMILES": + if args.task == "molecule_ATC": + full_file_name = "SMILES_ATC_{}_full.txt".format(args.ATC_level) + dataset_class = DrugBank_Datasets_SMILES_ATC + dataloader_class = torch_DataLoader + + full_dataset = dataset_class(dataset_folder, full_file_name, neg_sample_size=T_max, prompt_template=prompt_template) + + else: + if args.task == "molecule_ATC": + full_file_name = "SMILES_ATC_{}_full.txt".format(args.ATC_level) + full_processed_dir_prefix = "ATC_full_{}".format(args.ATC_level) + dataset_class = DrugBank_Datasets_Graph_ATC + dataloader_class = pyg_DataLoader + + full_dataset = dataset_class(dataset_folder, full_file_name, full_processed_dir_prefix, neg_sample_size=T_max, prompt_template=prompt_template) + + full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + + initial_test_acc_list = eval_epoch(full_dataloader) + print('Initial', initial_test_acc_list) + + row = ", ".join(["{:.4f}".format(x * 100) for x in initial_test_acc_list]) + print("initial results,", row) \ No newline at end of file diff --git a/scripts/downstream_01_retrieval_ATC_KV-PLM.py b/scripts/downstream_01_retrieval_ATC_KV-PLM.py new file mode 100644 index 0000000..023f13e --- /dev/null +++ b/scripts/downstream_01_retrieval_ATC_KV-PLM.py @@ -0,0 +1,225 @@ +from lib2to3.pgen2 import token +import os +import time +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader + +from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_ATC +from MoleculeSTM.utils import prepare_text_tokens +from transformers import BertTokenizer, BertForPreTraining + + +class BigModel(nn.Module): + def __init__(self, main_model): + super(BigModel, self).__init__() + self.main_model = main_model + self.dropout = nn.Dropout(0.1) + + def forward(self, tok, att, cud=True): + typ = torch.zeros(tok.shape).long() + if cud: + typ = typ.cuda() + pooled_output = self.main_model(tok, token_type_ids=typ, attention_mask=att)['pooler_output'] + logits = self.dropout(pooled_output) + return logits + + +def 
do_CL_eval(X, Y, neg_Y, args): + X = F.normalize(X, dim=-1) + X = X.unsqueeze(1) # B, 1, d + + Y = Y.unsqueeze(0) + Y = torch.cat([Y, neg_Y], dim=0) # T, B, d + Y = Y.transpose(0, 1) # B, T, d + Y = F.normalize(Y, dim=-1) + + logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T + B = X.size()[0] + labels = torch.zeros(B).long().to(logits.device) # B*1 + + criterion = nn.CrossEntropyLoss() + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + confidence = logits + CL_conf = confidence.max(dim=1)[0] + CL_conf = CL_conf.cpu().numpy() + + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B + return CL_loss, CL_conf, CL_acc + + +def save_model(save_best, epoch=None): + if args.output_model_dir is not None: + if save_best: + model_file = "model.pth" + + elif epoch is None: + model_file = "model_final.pth" + + else: + model_file = "model_{}.pth".format(epoch) + + saved_file_path = os.path.join(args.output_model_dir, model_file) + torch.save(model.state_dict(), saved_file_path) + return + + +def get_text_and_SMILES_repr(text): + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text, tokenizer=tokenizer, max_seq_len=args.max_seq_len) + text_repr = model(text_tokens_ids, text_masks) + return text_repr + + +@torch.no_grad() +def eval_epoch(dataloader): + model.eval() + + accum_acc_list = [0 for _ in args.T_list] + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + for batch in L: + text = batch[0] + molecule_data = batch[1] + neg_text = batch[2] + neg_molecule_data = batch[3] + + text_repr = get_text_and_SMILES_repr(text) + molecule_repr = get_text_and_SMILES_repr(molecule_data) + + if test_mode == "given_text": + if args.molecule_type == "SMILES": + neg_molecule_repr = [get_text_and_SMILES_repr(neg_molecule_data[idx]) for idx in range(T_max)] + neg_molecule_repr = torch.stack(neg_molecule_repr) + else: + neg_molecule_repr = [get_text_and_SMILES_repr(neg_molecule_data[idx]) for idx in range(T_max)] + neg_molecule_repr = torch.stack(neg_molecule_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args) + accum_acc_list[T_idx] += acc + elif test_mode == "given_molecule": + neg_text_repr = [get_text_and_SMILES_repr(neg_text[idx]) for idx in range(T_max)] + neg_text_repr = torch.stack(neg_text_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args) + accum_acc_list[T_idx] += acc + else: + raise Exception + + accum_acc_list = np.array(accum_acc_list) + accum_acc_list /= len(dataloader) + return accum_acc_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"]) + parser.add_argument("--load_latent_projector", type=int, default=1) + parser.add_argument("--model_loading_mode", type=str, default="load_from_latest", choices=["load_from_latest", "load_mode_0", "load_mode_1"]) + parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"]) + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--task", type=str, default="molecule_ATC", + choices=["molecule_ATC", 
"molecule_ATC_overlap_PubChem"]) + parser.add_argument("--test_mode", type=str, default="given_text", + choices=["given_text", "given_molecule"]) + parser.add_argument("--ATC_level", type=int, default=5, choices=[1, 3, 4, 5]) + + ########## for optimization ########## + parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20]) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--decay", type=float, default=0) + + ########## for contrastive objective ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for molecule model ########## + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_interval", type=int, default=5) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None) + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + print("arguments\t", args) + torch.multiprocessing.set_sharing_strategy('file_system') + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased') + + bert_model0 = BertForPreTraining.from_pretrained('allenai/scibert_scivocab_uncased') + model = BigModel(bert_model0.bert) + if torch.cuda.is_available(): + model.load_state_dict(torch.load('../data/pretrained_KV-PLM/ckpt_ret01.pt')) + model = model.cuda() + else: + model.load_state_dict(torch.load('../data/pretrained_KV-PLM/ckpt_ret01.pt', map_location=torch.device('cpu') )) + model.eval() + + T_max = max(args.T_list) - 1 + + # TODO: will tune more prompt_templates + prompt_template = "This molecule is for {}." 
+ + initial_test_acc_list, optimal_test_acc_list = [], [] + test_mode = args.test_mode + dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data") + + if args.task == "molecule_ATC": + full_file_name = "SMILES_ATC_{}_full.txt".format(args.ATC_level) + dataset_class = DrugBank_Datasets_SMILES_ATC + dataloader_class = torch_DataLoader + + full_dataset = dataset_class(dataset_folder, full_file_name, neg_sample_size=T_max, prompt_template=prompt_template) + + full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + + initial_test_acc_list = eval_epoch(full_dataloader) + print('Initial', initial_test_acc_list) + + row = ", ".join(["{:.4f}".format(x * 100) for x in initial_test_acc_list]) + print("initial results,", row) diff --git a/scripts/downstream_01_retrieval_ATC_Retrieval.py b/scripts/downstream_01_retrieval_ATC_Retrieval.py new file mode 100644 index 0000000..ef7bcc5 --- /dev/null +++ b/scripts/downstream_01_retrieval_ATC_Retrieval.py @@ -0,0 +1,392 @@ +import os +import argparse +import numpy as np +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.datasets import PubChem_Datasets_SMILES, PubChem_Datasets_Graph, DrugBank_Datasets_SMILES_ATC, DrugBank_Datasets_Graph_ATC +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred +from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM +from downstream_zero_shot_retrieval_DrugBank_Retrieval import RetrievalDataset +from torch.utils.data import DataLoader + + +def do_CL_eval(X, Y, neg_Y, args): + X = F.normalize(X, dim=-1) + X = X.unsqueeze(1) # B, 1, d + + Y = Y.unsqueeze(0) + Y = torch.cat([Y, neg_Y], dim=0) # T, B, d + Y = Y.transpose(0, 1) # B, T, d + Y = F.normalize(Y, dim=-1) + + logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T + B = X.size()[0] + labels = torch.zeros(B).long().to(logits.device) # B*1 + + criterion = nn.CrossEntropyLoss() + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + confidence = logits + CL_conf = confidence.max(dim=1)[0] + CL_conf = CL_conf.cpu().numpy() + + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B + return CL_loss, CL_conf, CL_acc + + +def get_text_repr(text): + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks) + text_repr = text_output["pooler_output"] + return text_repr + + +@torch.no_grad() +def extract_retrieval_representation(retrieval_dataloader): + if args.verbose: + L = tqdm(retrieval_dataloader) + else: + L = retrieval_dataloader + + retrieval_molecule_repr_list, retrieval_description_representation_list = [], [] + for step, batch in enumerate(L): + description = batch[0] + molecule_data = batch[1] + + try: + description_tokens_ids, description_masks = prepare_text_tokens( + device=device, description=description, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks) + description_repr = description_output["pooler_output"] + + if args.molecule_type == "SMILES": + molecule_data = list(molecule_data) # for SMILES_list + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type=args.molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type=args.molecule_type, molecule_model=molecule_model) + except: + continue + retrieval_description_representation_list.append(description_repr.detach().cpu().numpy()) + retrieval_molecule_repr_list.append(molecule_repr.detach().cpu().numpy()) + + retrieval_description_representation_array = np.concatenate(retrieval_description_representation_list) + retrieval_molecule_representation_array = np.concatenate(retrieval_molecule_repr_list) + + return retrieval_description_representation_array, retrieval_molecule_representation_array + + +def get_similarity_array(X, retrieval_loader): + sim_list = [] + if args.verbose: + L = tqdm(retrieval_loader) + else: + L = retrieval_loader + for batch in L: + batch = batch.to(device) + sim = torch.matmul(X, batch.transpose(1, 0)).detach().cpu().numpy() + sim_list.append(sim) + sim_array = np.concatenate(sim_list, axis=1) + return sim_array + + +@torch.no_grad() +def eval_epoch(dataloader): + text_model.eval() + molecule_model.eval() + + accum_acc_list = [0 for _ in args.T_list] + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + for batch in L: + text = batch[0] + molecule_data = batch[1] + neg_text = batch[2] + neg_molecule_data = batch[3] + + text_repr = get_text_repr(text) + if args.molecule_type == "SMILES": + SMILES_list = list(molecule_data) + molecule_repr = get_molecule_repr_MoleculeSTM( + SMILES_list, mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type="Graph", molecule_model=molecule_model) + + if test_mode == "given_text": + if args.molecule_type == "SMILES": + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + list(neg_molecule_data[idx]), mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + else: + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + neg_molecule_data[idx].to(device), mol2latent=None, + 
molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + + # Next we will do the retrieval: + # text_repr -> retrieval_description_representation_array -> retrieval_molecule_representation_array + similarity_array = get_similarity_array(text_repr, retrieval_description_representation_dataloader) + batch_size = similarity_array.shape[0] + retrieved_text_repr_list = [] + for batch_i in range(batch_size): + temp_similarity_array = similarity_array[batch_i] + sorted_index = np.argsort(temp_similarity_array)[::-1] + optimal_index = sorted_index[0] + retrieved_text_repr_list.append(retrieval_molecule_representation_array[optimal_index]) + retrieved_text_repr_list = np.array(retrieved_text_repr_list) + retrieved_text_repr = torch.Tensor(retrieved_text_repr_list).to(device) + + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(retrieved_text_repr, molecule_repr, neg_molecule_repr[:T-1], args) + accum_acc_list[T_idx] += acc + + elif test_mode == "given_molecule": + neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)] + neg_text_repr = torch.stack(neg_text_repr) + + # Next we will do the retrieval: + # molecule_repr -> retrieval_molecule_representation_array -> retrieval_description_representation_array + similarity_array = get_similarity_array(molecule_repr, retrieval_molecule_representation_dataloader) + batch_size = similarity_array.shape[0] + retrieved_mol_repr_list = [] + for batch_i in range(batch_size): + temp_similarity_array = similarity_array[batch_i] + sorted_index = np.argsort(temp_similarity_array)[::-1] + optimal_index = sorted_index[0] + retrieved_mol_repr_list.append(retrieval_description_representation_array[optimal_index]) + retrieved_mol_repr_list = np.array(retrieved_mol_repr_list) + retrieved_mol_repr = torch.Tensor(retrieved_mol_repr_list).to(device) + + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(retrieved_mol_repr, text_repr, neg_text_repr[:T-1], args) + accum_acc_list[T_idx] += acc + else: + raise Exception + + accum_acc_list = np.array(accum_acc_list) + accum_acc_list /= len(dataloader) + return accum_acc_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"]) + parser.add_argument("--load_latent_projector", type=int, default=1) + parser.add_argument("--model_loading_mode", type=str, default="load_from_latest", choices=["load_from_latest", "load_mode_0", "load_mode_1"]) + parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"]) + parser.add_argument("--retrieval_folder", type=str, default="retrieval_similarity") + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--task", type=str, default="molecule_ATC", + choices=["molecule_ATC", "molecule_ATC_overlap_PubChem"]) + parser.add_argument("--test_mode", type=str, default="given_text", + choices=["given_text", "given_molecule"]) + parser.add_argument("--ATC_level", type=int, default=5, choices=[1, 3, 4, 5]) + + ########## for optimization ########## + parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20]) + parser.add_argument("--batch_size", type=int, default=32) + 
parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--text_lr", type=float, default=1e-5) + parser.add_argument("--mol_lr", type=float, default=1e-5) + parser.add_argument("--text_lr_scale", type=float, default=0.1) + parser.add_argument("--mol_lr_scale", type=float, default=0.1) + parser.add_argument("--decay", type=float, default=0) + + ########## for contrastive objective ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for molecule model ########## + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for MegaMolBART ########## + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_interval", type=int, default=5) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None) + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + print("arguments\t", args) + torch.multiprocessing.set_sharing_strategy('file_system') + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + ##### prepare text model ##### + ##### by default, this is load_mode_1 ##### + if args.text_type == "SciBERT": + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + # TODO: check https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1501 + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device) + text_dim = 768 + else: + raise Exception + + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + text_model.init_weights() + print("Random init for BERT.") + + ##### prepare molecule model ##### + if args.molecule_type == "SMILES": + if args.model_loading_mode == "load_from_latest": + input_model_path = 
os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Random init for MegaMolBART.") + elif args.model_loading_mode == "load_mode_1": + # This is loading from the pretrained_MegaMolBART + # --input_model_dir=../data/pretrained_MegaMolBART/checkpoints + MegaMolBART_wrapper = MegaMolBART(input_dir="../data/pretrained_MegaMolBART/checkpoints", output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Loading from ../data/pretrained_MegaMolBART/checkpoints.") + molecule_dim = 256 + + else: + molecule_node_model = GNN( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, + JK=args.JK, drop_ratio=args.dropout_ratio, + gnn_type=args.gnn_type) + molecule_model = GNN_graphpred( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, + num_tasks=1, molecule_node_model=molecule_node_model) + molecule_dim = args.gnn_emb_dim + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + print("Random init for GNN.") + elif args.model_loading_mode == "load_mode_1": + print("Loading from ../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + molecule_model.from_pretrained("../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + + # Re-seed after MegaMolBART initialization, which may change the global random state + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + + T_max = max(args.T_list) - 1 + + # TODO: will tune more prompt_templates + prompt_template = "This molecule is for {}."
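+    # Sketch of the retrieval-augmented evaluation used below (added comment, not in the original
+    # code): in "given_text" mode each DrugBank prompt is embedded with SciBERT, matched against the
+    # cached PubChem description embeddings, and the molecule embedding of its nearest PubChem entry
+    # then stands in for the text when ranking the 1 positive and T - 1 negative candidate molecules;
+    # "given_molecule" mode mirrors this by retrieving the nearest PubChem molecule's description embedding.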
+ + initial_test_acc_list, optimal_test_acc_list = [], [] + test_mode = args.test_mode + dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data") + if args.molecule_type == "SMILES": + if args.task == "molecule_ATC": + full_file_name = "SMILES_ATC_{}_full.txt".format(args.ATC_level) + dataset_class = DrugBank_Datasets_SMILES_ATC + dataloader_class = torch_DataLoader + + full_dataset = dataset_class(dataset_folder, full_file_name, neg_sample_size=T_max, prompt_template=prompt_template) + + dataset_root = os.path.join(args.dataspace_path, "PubChem_data") + retrieval_dataset = PubChem_Datasets_SMILES(dataset_root) + + else: + if args.task == "molecule_ATC": + full_file_name = "SMILES_ATC_{}_full.txt".format(args.ATC_level) + full_processed_dir_prefix = "ATC_full_{}".format(args.ATC_level) + dataset_class = DrugBank_Datasets_Graph_ATC + dataloader_class = pyg_DataLoader + + full_dataset = dataset_class(dataset_folder, full_file_name, full_processed_dir_prefix, neg_sample_size=T_max, prompt_template=prompt_template) + + dataset_root = os.path.join(args.dataspace_path, "PubChem_data") + retrieval_dataset = PubChem_Datasets_Graph(dataset_root) + + full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blcoked with none-zero num_workers + retrieval_dataloader = dataloader_class(retrieval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + os.makedirs(args.retrieval_folder, exist_ok=True) + retrieval_datapath = "{}/{}_{}_{}".format(args.retrieval_folder, args.molecule_type, args.task, args.ATC_level) + if os.path.exists(retrieval_datapath+".npz"): + data = np.load(retrieval_datapath+".npz") + retrieval_description_representation_array = data["retrieval_description_representation_array"] + retrieval_molecule_representation_array = data["retrieval_molecule_representation_array"] + else: + retrieval_description_representation_array, retrieval_molecule_representation_array = extract_retrieval_representation(retrieval_dataloader) + np.savez(retrieval_datapath, retrieval_description_representation_array=retrieval_description_representation_array, retrieval_molecule_representation_array=retrieval_molecule_representation_array) + retrieval_description_representation_dataset = RetrievalDataset(retrieval_description_representation_array) + retrieval_description_representation_dataloader = DataLoader(retrieval_description_representation_dataset, batch_size=512, shuffle=False, num_workers=args.num_workers) + retrieval_molecule_representation_dataset = RetrievalDataset(retrieval_molecule_representation_array) + retrieval_molecule_representation_dataloader = DataLoader(retrieval_molecule_representation_dataset, batch_size=512, shuffle=False, num_workers=args.num_workers) + + initial_test_acc_list = eval_epoch(full_dataloader) + print('Initial', initial_test_acc_list) + + row = ", ".join(["{:.4f}".format(x * 100) for x in initial_test_acc_list]) + print("initial results,", row) \ No newline at end of file diff --git a/scripts/downstream_01_retrieval_Description_Pharmacodynamics.py b/scripts/downstream_01_retrieval_Description_Pharmacodynamics.py new file mode 100644 index 0000000..08db3ad --- /dev/null +++ b/scripts/downstream_01_retrieval_Description_Pharmacodynamics.py @@ -0,0 +1,341 @@ +import os +import time +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn as nn +import 
torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred +from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network + + +def do_CL_eval(X, Y, neg_Y, args): + X = F.normalize(X, dim=-1) + X = X.unsqueeze(1) # B, 1, d + + Y = Y.unsqueeze(0) + Y = torch.cat([Y, neg_Y], dim=0) # T, B, d + Y = Y.transpose(0, 1) # B, T, d + Y = F.normalize(Y, dim=-1) + + logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T + B = X.size()[0] + labels = torch.zeros(B).long().to(logits.device) # B*1 + + criterion = nn.CrossEntropyLoss() + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + confidence = logits + CL_conf = confidence.max(dim=1)[0] + CL_conf = CL_conf.cpu().numpy() + + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. / B + return CL_loss, CL_conf, CL_acc + + +def get_text_repr(text): + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks) + text_repr = text_output["pooler_output"] + text_repr = text2latent(text_repr) + return text_repr + + +@torch.no_grad() +def eval_epoch(dataloader): + text_model.eval() + molecule_model.eval() + text2latent.eval() + mol2latent.eval() + + accum_acc_list = [0 for _ in args.T_list] + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + for batch in L: + text = batch[0] + molecule_data = batch[1] + neg_text = batch[2] + neg_molecule_data = batch[3] + + text_repr = get_text_repr(text) + if args.molecule_type == "SMILES": + molecule_data = list(molecule_data) # for SMILES_list + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=mol2latent, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=mol2latent, + molecule_type="Graph", molecule_model=molecule_model + ) + + if test_mode == "given_text": + if args.molecule_type == "SMILES": + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + list(neg_molecule_data[idx]), mol2latent=mol2latent, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + else: + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + neg_molecule_data[idx].to(device), mol2latent=mol2latent, + molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args) + accum_acc_list[T_idx] += acc + + elif test_mode == "given_molecule": + neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)] + neg_text_repr = torch.stack(neg_text_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args) + accum_acc_list[T_idx] += acc + else: + raise Exception + + accum_acc_list = 
np.array(accum_acc_list) + accum_acc_list /= len(dataloader) + return accum_acc_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"]) + parser.add_argument("--load_latent_projector", type=int, default=1) + parser.add_argument("--model_loading_mode", type=str, default="load_from_latest", choices=["load_from_latest", "load_mode_0", "load_mode_1"]) + parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"]) + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--task", type=str, default="molecule_description", + choices=[ + "molecule_description", "molecule_description_Raw", + "molecule_description_removed_PubChem", "molecule_description_removed_PubChem_Raw", + "molecule_pharmacodynamics", "molecule_pharmacodynamics_Raw", + "molecule_pharmacodynamics_removed_PubChem", "molecule_pharmacodynamics_removed_PubChem_Raw"]) + parser.add_argument("--test_mode", type=str, default="given_text", choices=["given_text", "given_molecule"]) + + ########## for optimization ########## + parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20]) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--text_lr", type=float, default=1e-5) + parser.add_argument("--mol_lr", type=float, default=1e-5) + parser.add_argument("--text_lr_scale", type=float, default=0.1) + parser.add_argument("--mol_lr_scale", type=float, default=0.1) + parser.add_argument("--decay", type=float, default=0) + + ########## for contrastive objective ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for molecule model ########## + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for MegaMolBART ########## + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_train", type=int, default=0) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None) + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + 
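+    # Clarifying note (added, not in the original script): --task picks which DrugBank text field is
+    # paired with each SMILES ("description" vs. "pharmacodynamics"); the "_removed_PubChem" variants
+    # appear to exclude molecules that also occur in the PubChemSTM pretraining data, and the "_Raw"
+    # variants appear to use the unprocessed text files.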
args = parser.parse_args() + print("arguments\t", args) + torch.multiprocessing.set_sharing_strategy('file_system') + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + ##### prepare text model ##### + ##### by default, this is load_mode_1 ##### + if args.text_type == "SciBERT": + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + # TODO: check https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1501 + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device) + text_dim = 768 + else: + raise Exception + + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + text_model.init_weights() + print("Random init for BERT.") + + ##### prepare molecule model ##### + if args.molecule_type == "SMILES": + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Random init for MegaMolBART.") + elif args.model_loading_mode == "load_mode_1": + # This is loading from the pretarined_MegaMolBART + # --input_model_dir=../data/pretrained_MegaMolBART/checkpoints + MegaMolBART_wrapper = MegaMolBART(input_dir="../data/pretrained_MegaMolBART/checkpoints", output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Loading from ../data/pretrained_MegaMolBART/checkpoint.") + molecule_dim = 256 + + else: + molecule_node_model = GNN( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, + JK=args.JK, drop_ratio=args.dropout_ratio, + gnn_type=args.gnn_type) + molecule_model = GNN_graphpred( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, + num_tasks=1, molecule_node_model=molecule_node_model) + molecule_dim = args.gnn_emb_dim + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + print("Random init for GNN.") + elif args.model_loading_mode == "load_mode_1": + print("Loading from ../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + molecule_model.from_pretrained("../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + + # Rewrite the seed by MegaMolBART + 
np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + text2latent = nn.Linear(text_dim, args.SSL_emb_dim) + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "text2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text2latent.load_state_dict(state_dict) + + mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim) + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "mol2latent_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + mol2latent.load_state_dict(state_dict) + + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + text2latent = text2latent.to(device) + mol2latent = mol2latent.to(device) + + T_max = max(args.T_list) - 1 + + initial_test_acc_list = [] + test_mode = args.test_mode + dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data") + if args.molecule_type == "SMILES": + dataset_class = DrugBank_Datasets_SMILES_retrieval + dataloader_class = torch_DataLoader + + if args.task == "molecule_description": + template = "SMILES_description_{}.txt" + elif args.task == "molecule_description_removed_PubChem": + template = "SMILES_description_removed_from_PubChem_{}.txt" + elif args.task == "molecule_description_Raw": + template = "SMILES_description_{}_Raw.txt" + elif args.task == "molecule_description_removed_PubChem_Raw": + template = "SMILES_description_removed_from_PubChem_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics": + template = "SMILES_pharmacodynamics_{}.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt" + elif args.task == "molecule_pharmacodynamics_Raw": + template = "SMILES_pharmacodynamics_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt" + + full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, template=template) + + else: + dataset_class = DrugBank_Datasets_Graph_retrieval + dataloader_class = pyg_DataLoader + processed_dir_prefix = args.task + + if args.task == "molecule_description": + template = "SMILES_description_{}.txt" + elif args.task == "molecule_description_removed_PubChem": + template = "SMILES_description_removed_from_PubChem_{}.txt" + elif args.task == "molecule_description_Raw": + template = "SMILES_description_{}_Raw.txt" + elif args.task == "molecule_description_removed_PubChem_Raw": + template = "SMILES_description_removed_from_PubChem_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics": + template = "SMILES_pharmacodynamics_{}.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt" + elif args.task == "molecule_pharmacodynamics_Raw": + template = "SMILES_pharmacodynamics_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt" + + full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, processed_dir_prefix=processed_dir_prefix, template=template) + + full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, 
shuffle=True, num_workers=args.num_workers) # The program will get blocked with non-zero num_workers + + initial_test_acc_list = eval_epoch(full_dataloader) + print('Initial', initial_test_acc_list) + + row = ", ".join(["{:.4f}".format(x * 100) for x in initial_test_acc_list]) + print("initial results,", row) + \ No newline at end of file diff --git a/scripts/downstream_01_retrieval_Description_Pharmacodynamics_KV-PLM.py b/scripts/downstream_01_retrieval_Description_Pharmacodynamics_KV-PLM.py new file mode 100644 index 0000000..3ada83e --- /dev/null +++ b/scripts/downstream_01_retrieval_Description_Pharmacodynamics_KV-PLM.py @@ -0,0 +1,227 @@ +import os +import time +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.datasets import DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred +from MoleculeSTM.utils import prepare_text_tokens +from transformers import BertTokenizer, BertForPreTraining + + +class BigModel(nn.Module): + def __init__(self, main_model): + super(BigModel, self).__init__() + self.main_model = main_model + self.dropout = nn.Dropout(0.1) + + def forward(self, tok, att, cud=True): + typ = torch.zeros(tok.shape).long() + if cud: + typ = typ.cuda() + pooled_output = self.main_model(tok, token_type_ids=typ, attention_mask=att)['pooler_output'] + logits = self.dropout(pooled_output) + return logits + + +def do_CL_eval(X, Y, neg_Y, args): + X = F.normalize(X, dim=-1) + X = X.unsqueeze(1) # B, 1, d + + Y = Y.unsqueeze(0) + Y = torch.cat([Y, neg_Y], dim=0) # T, B, d + Y = Y.transpose(0, 1) # B, T, d + Y = F.normalize(Y, dim=-1) + + logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T + B = X.size()[0] + labels = torch.zeros(B).long().to(logits.device) # B*1 + + criterion = nn.CrossEntropyLoss() + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + confidence = logits + CL_conf = confidence.max(dim=1)[0] + CL_conf = CL_conf.cpu().numpy() + + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B + return CL_loss, CL_conf, CL_acc + + +def get_text_and_SMILES_repr(text): + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text, tokenizer=tokenizer, max_seq_len=args.max_seq_len) + text_repr = model(text_tokens_ids, text_masks) + return text_repr + + +@torch.no_grad() +def eval_epoch(dataloader): + model.eval() + + accum_acc_list = [0 for _ in args.T_list] + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + for batch in L: + text = batch[0] + molecule_data = batch[1] + neg_text = batch[2] + neg_molecule_data = batch[3] + + text_repr = get_text_and_SMILES_repr(text) + molecule_repr = get_text_and_SMILES_repr(molecule_data) + + if test_mode == "given_text": + if args.molecule_type == "SMILES": + neg_molecule_repr = [get_text_and_SMILES_repr(neg_molecule_data[idx]) for idx in range(T_max)] + neg_molecule_repr = torch.stack(neg_molecule_repr) + else: + neg_molecule_repr = [get_text_and_SMILES_repr(neg_molecule_data[idx]) for idx in range(T_max)] + neg_molecule_repr = torch.stack(neg_molecule_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(text_repr, molecule_repr, neg_molecule_repr[:T-1], args) + accum_acc_list[T_idx] += acc + elif test_mode == "given_molecule": + neg_text_repr = [get_text_and_SMILES_repr(neg_text[idx]) for idx in range(T_max)] + neg_text_repr = torch.stack(neg_text_repr) + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(molecule_repr, text_repr, neg_text_repr[:T-1], args) + accum_acc_list[T_idx] += acc + else: + raise Exception + + accum_acc_list = np.array(accum_acc_list) + accum_acc_list /= len(dataloader) + return accum_acc_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"]) + parser.add_argument("--load_latent_projector", type=int, default=1) + parser.add_argument("--model_loading_mode", type=str, default="load_from_latest", choices=["load_from_latest", "load_mode_0", "load_mode_1"]) + parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"]) + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--task", type=str, default="molecule_description", + choices=[ + "molecule_description", "molecule_description_Raw", + "molecule_description_removed_PubChem", "molecule_description_removed_PubChem_Raw", + "molecule_pharmacodynamics", "molecule_pharmacodynamics_Raw", + "molecule_pharmacodynamics_removed_PubChem", "molecule_pharmacodynamics_removed_PubChem_Raw"]) + parser.add_argument("--test_mode", type=str, default="given_text", choices=["given_text", "given_molecule"]) + + ########## for optimization ########## + parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20]) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--lr", type=float, default=1e-5) + parser.add_argument("--decay", type=float, default=0) + + ########## for contrastive objective ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + 
parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for molecule model ########## + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_train", type=int, default=0) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None) + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + print("arguments\t", args) + torch.multiprocessing.set_sharing_strategy('file_system') + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased') + + bert_model0 = BertForPreTraining.from_pretrained('allenai/scibert_scivocab_uncased') + model = BigModel(bert_model0.bert) + if torch.cuda.is_available(): + model.load_state_dict(torch.load('../data/pretrained_KV-PLM/ckpt_ret01.pt')) + model = model.cuda() + else: + model.load_state_dict(torch.load('../data/pretrained_KV-PLM/ckpt_ret01.pt', map_location=torch.device('cpu') )) + model.eval() + + T_max = max(args.T_list) - 1 + + initial_test_acc_list = [] + test_mode = args.test_mode + dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data") + + dataset_class = DrugBank_Datasets_SMILES_retrieval + dataloader_class = torch_DataLoader + + if args.task == "molecule_description": + template = "SMILES_description_{}.txt" + elif args.task == "molecule_description_removed_PubChem": + template = "SMILES_description_removed_from_PubChem_{}.txt" + elif args.task == "molecule_description_Raw": + template = "SMILES_description_{}_Raw.txt" + elif args.task == "molecule_description_removed_PubChem_Raw": + template = "SMILES_description_removed_from_PubChem_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics": + template = "SMILES_pharmacodynamics_{}.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt" + elif args.task == "molecule_pharmacodynamics_Raw": + template = "SMILES_pharmacodynamics_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt" + + full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, template=template) + + full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blcoked with none-zero num_workers + + 
initial_test_acc_list = eval_epoch(full_dataloader) + print('Initial', initial_test_acc_list) + + row = ", ".join(["{:.4f}".format(x * 100) for x in initial_test_acc_list]) + print("initial results,", row) + \ No newline at end of file diff --git a/scripts/downstream_01_retrieval_Description_Pharmacodynamics_Retrieval.py b/scripts/downstream_01_retrieval_Description_Pharmacodynamics_Retrieval.py new file mode 100644 index 0000000..814c4f4 --- /dev/null +++ b/scripts/downstream_01_retrieval_Description_Pharmacodynamics_Retrieval.py @@ -0,0 +1,435 @@ +import os +import time +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from transformers import AutoModel, AutoTokenizer +from MoleculeSTM.datasets import PubChem_Datasets_SMILES, PubChem_Datasets_Graph, DrugBank_Datasets_SMILES_retrieval, DrugBank_Datasets_Graph_retrieval +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred +from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM +from torch.utils.data import Dataset, DataLoader + + +class RetrievalDataset(Dataset): + def __init__(self, repr_array): + self.repr_array = repr_array + + def __len__(self): + return len(self.repr_array) + + def __getitem__(self, idx): + return torch.Tensor(self.repr_array[idx]) + + +def do_CL_eval(X, Y, neg_Y, args): + X = F.normalize(X, dim=-1) + X = X.unsqueeze(1) # B, 1, d + + Y = Y.unsqueeze(0) + Y = torch.cat([Y, neg_Y], dim=0) # T, B, d + Y = Y.transpose(0, 1) # B, T, d + Y = F.normalize(Y, dim=-1) + + logits = torch.bmm(X, Y.transpose(1, 2)).squeeze() # B*T + B = X.size()[0] + labels = torch.zeros(B).long().to(logits.device) # B*1 + + criterion = nn.CrossEntropyLoss() + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + confidence = logits + CL_conf = confidence.max(dim=1)[0] + CL_conf = CL_conf.cpu().numpy() + + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B + return CL_loss, CL_conf, CL_acc + + +def get_text_repr(text): + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks) + text_repr = text_output["pooler_output"] + return text_repr + + +@torch.no_grad() +def extract_retrieval_representation(retrieval_dataloader): + if args.verbose: + L = tqdm(retrieval_dataloader) + else: + L = retrieval_dataloader + + retrieval_molecule_repr_list, retrieval_description_representation_list = [], [] + for step, batch in enumerate(L): + description = batch[0] + molecule_data = batch[1] + + try: + description_tokens_ids, description_masks = prepare_text_tokens( + device=device, description=description, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks) + description_repr = description_output["pooler_output"] + + if args.molecule_type == "SMILES": + molecule_data = list(molecule_data) # for SMILES_list + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type=args.molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type=args.molecule_type, molecule_model=molecule_model) + except: + continue + retrieval_description_representation_list.append(description_repr.detach().cpu().numpy()) + retrieval_molecule_repr_list.append(molecule_repr.detach().cpu().numpy()) + + retrieval_description_representation_array = np.concatenate(retrieval_description_representation_list) + retrieval_molecule_representation_array = np.concatenate(retrieval_molecule_repr_list) + + return retrieval_description_representation_array, retrieval_molecule_representation_array + + +def get_similarity_array(X, retrieval_loader): + sim_list = [] + if args.verbose: + L = tqdm(retrieval_loader) + else: + L = retrieval_loader + for batch in L: + batch = batch.to(device) + sim = torch.matmul(X, batch.transpose(1, 0)).detach().cpu().numpy() + sim_list.append(sim) + sim_array = np.concatenate(sim_list, axis=1) + return sim_array + + +@torch.no_grad() +def eval_epoch(dataloader): + text_model.eval() + molecule_model.eval() + + accum_acc_list = [0 for _ in args.T_list] + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + for batch in L: + text = batch[0] + molecule_data = batch[1] + neg_text = batch[2] + neg_molecule_data = batch[3] + + text_repr = get_text_repr(text) + if args.molecule_type == "SMILES": + molecule_data = list(molecule_data) # for SMILES_list + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=None, + molecule_type="Graph", molecule_model=molecule_model + ) + + if test_mode == "given_text": + if args.molecule_type == "SMILES": + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + list(neg_molecule_data[idx]), mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + else: + neg_molecule_repr = [ + get_molecule_repr_MoleculeSTM( + neg_molecule_data[idx].to(device), 
mol2latent=None, + molecule_type="Graph", molecule_model=molecule_model) for idx in range(T_max) + ] + neg_molecule_repr = torch.stack(neg_molecule_repr) + + # Next we will do the retrieval: + # text_repr -> retrieval_description_representation_array -> retrieval_molecule_representation_array + similarity_array = get_similarity_array(text_repr, retrieval_description_representation_dataloader) + batch_size = similarity_array.shape[0] + retrieved_text_repr_list = [] + for batch_i in range(batch_size): + temp_similarity_array = similarity_array[batch_i] + sorted_index = np.argsort(temp_similarity_array)[::-1] + optimal_index = sorted_index[0] + retrieved_text_repr_list.append(retrieval_molecule_representation_array[optimal_index]) + retrieved_text_repr_list = np.array(retrieved_text_repr_list) + retrieved_text_repr = torch.Tensor(retrieved_text_repr_list).to(device) + + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(retrieved_text_repr, molecule_repr, neg_molecule_repr[:T-1], args) + accum_acc_list[T_idx] += acc + + elif test_mode == "given_molecule": + neg_text_repr = [get_text_repr(neg_text[idx]) for idx in range(T_max)] + neg_text_repr = torch.stack(neg_text_repr) + + # Next we will do the retrieval: + # molecule_repr -> retrieval_molecule_representation_array -> retrieval_description_representation_array + similarity_array = get_similarity_array(molecule_repr, retrieval_molecule_representation_dataloader) + batch_size = similarity_array.shape[0] + retrieved_mol_repr_list = [] + for batch_i in range(batch_size): + temp_similarity_array = similarity_array[batch_i] + sorted_index = np.argsort(temp_similarity_array)[::-1] + optimal_index = sorted_index[0] + retrieved_mol_repr_list.append(retrieval_description_representation_array[optimal_index]) + retrieved_mol_repr_list = np.array(retrieved_mol_repr_list) + retrieved_mol_repr = torch.Tensor(retrieved_mol_repr_list).to(device) + + for T_idx, T in enumerate(args.T_list): + _, _, acc = do_CL_eval(retrieved_mol_repr, text_repr, neg_text_repr[:T-1], args) + accum_acc_list[T_idx] += acc + else: + raise Exception + + accum_acc_list = np.array(accum_acc_list) + accum_acc_list /= len(dataloader) + return accum_acc_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT", "BioBERT"]) + parser.add_argument("--load_latent_projector", type=int, default=1) + parser.add_argument("--model_loading_mode", type=str, default="load_from_latest", choices=["load_from_latest", "load_mode_0", "load_mode_1"]) + parser.add_argument("--training_mode", type=str, default="zero_shot", choices=["zero_shot"]) + parser.add_argument("--retrieval_folder", type=str, default="retrieval_similarity") + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--task", type=str, default="molecule_description", + choices=[ + "molecule_description", "molecule_description_Raw", + "molecule_description_removed_PubChem", "molecule_description_removed_PubChem_Raw", + "molecule_pharmacodynamics", "molecule_pharmacodynamics_Raw", + "molecule_pharmacodynamics_removed_PubChem", "molecule_pharmacodynamics_removed_PubChem_Raw"]) + parser.add_argument("--test_mode", type=str, default="given_text", choices=["given_text", 
"given_molecule"]) + + ########## for optimization ########## + parser.add_argument("--T_list", type=int, nargs="+", default=[4, 10, 20]) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=1) + parser.add_argument("--text_lr", type=float, default=1e-5) + parser.add_argument("--mol_lr", type=float, default=1e-5) + parser.add_argument("--text_lr_scale", type=float, default=0.1) + parser.add_argument("--mol_lr_scale", type=float, default=0.1) + parser.add_argument("--decay", type=float, default=0) + + ########## for contrastive objective ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for molecule model ########## + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for MegaMolBART ########## + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_train", type=int, default=0) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None) + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + print("arguments\t", args) + torch.multiprocessing.set_sharing_strategy('file_system') + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + ##### prepare text model ##### + ##### by default, this is load_mode_1 ##### + if args.text_type == "SciBERT": + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + # TODO: check https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1501 + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device) + text_dim = 768 + else: + raise Exception + + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "text_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + text_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + 
text_model.init_weights() + print("Random init for BERT.") + + ##### prepare molecule model ##### + if args.molecule_type == "SMILES": + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Random init for MegaMolBART.") + elif args.model_loading_mode == "load_mode_1": + # This is loading from the pretrained MegaMolBART + # --input_model_dir=../data/pretrained_MegaMolBART/checkpoints + MegaMolBART_wrapper = MegaMolBART(input_dir="../data/pretrained_MegaMolBART/checkpoints", output_dir=None) + molecule_model = MegaMolBART_wrapper.model + print("Loading from ../data/pretrained_MegaMolBART/checkpoints.") + molecule_dim = 256 + + else: + molecule_node_model = GNN( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, + JK=args.JK, drop_ratio=args.dropout_ratio, + gnn_type=args.gnn_type) + molecule_model = GNN_graphpred( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, + num_tasks=1, molecule_node_model=molecule_node_model) + molecule_dim = args.gnn_emb_dim + if args.model_loading_mode == "load_from_latest": + input_model_path = os.path.join(args.input_model_dir, "molecule_model.pth") + print("Loading from {}...".format(input_model_path)) + state_dict = torch.load(input_model_path, map_location='cpu') + molecule_model.load_state_dict(state_dict) + elif args.model_loading_mode == "load_mode_0": + print("Random init for GNN.") + elif args.model_loading_mode == "load_mode_1": + print("Loading from ../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + molecule_model.from_pretrained("../data/pretrained_GraphMVP/GraphMVP_G/model.pth") + + # Reset the seed after MegaMolBART initialization + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + + T_max = max(args.T_list) - 1 + + initial_test_acc_list = [] + test_mode = args.test_mode + dataset_folder = os.path.join(args.dataspace_path, "DrugBank_data") + if args.molecule_type == "SMILES": + dataset_class = DrugBank_Datasets_SMILES_retrieval + dataloader_class = torch_DataLoader + + if args.task == "molecule_description": + template = "SMILES_description_{}.txt" + elif args.task == "molecule_description_removed_PubChem": + template = "SMILES_description_removed_from_PubChem_{}.txt" + elif args.task == "molecule_description_Raw": + template = "SMILES_description_{}_Raw.txt" + elif args.task == "molecule_description_removed_PubChem_Raw": + template = "SMILES_description_removed_from_PubChem_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics": + template = "SMILES_pharmacodynamics_{}.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt" + elif args.task == "molecule_pharmacodynamics_Raw": + template = "SMILES_pharmacodynamics_{}_Raw.txt" + elif args.task == 
"molecule_pharmacodynamics_removed_PubChem_Raw": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt" + full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, template=template) + + dataset_root = os.path.join(args.dataspace_path, "PubChem_data") + retrieval_dataset = PubChem_Datasets_SMILES(dataset_root) + + else: + dataset_class = DrugBank_Datasets_Graph_retrieval + dataloader_class = pyg_DataLoader + processed_dir_prefix = args.task + + if args.task == "molecule_description": + template = "SMILES_description_{}.txt" + elif args.task == "molecule_description_removed_PubChem": + template = "SMILES_description_removed_from_PubChem_{}.txt" + elif args.task == "molecule_description_Raw": + template = "SMILES_description_{}_Raw.txt" + elif args.task == "molecule_description_removed_PubChem_Raw": + template = "SMILES_description_removed_from_PubChem_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics": + template = "SMILES_pharmacodynamics_{}.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}.txt" + elif args.task == "molecule_pharmacodynamics_Raw": + template = "SMILES_pharmacodynamics_{}_Raw.txt" + elif args.task == "molecule_pharmacodynamics_removed_PubChem_Raw": + template = "SMILES_pharmacodynamics_removed_from_PubChem_{}_Raw.txt" + full_dataset = dataset_class(dataset_folder, 'full', neg_sample_size=T_max, processed_dir_prefix=processed_dir_prefix, template=template) + + dataset_root = os.path.join(args.dataspace_path, "PubChem_data") + retrieval_dataset = PubChem_Datasets_Graph(dataset_root) + + full_dataloader = dataloader_class(full_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) # The program will get blcoked with none-zero num_workers + retrieval_dataloader = dataloader_class(retrieval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + os.makedirs(args.retrieval_folder, exist_ok=True) + retrieval_datapath = "{}/{}_{}".format(args.retrieval_folder, args.molecule_type, args.task) + if os.path.exists(retrieval_datapath+".npz"): + data = np.load(retrieval_datapath+".npz") + retrieval_description_representation_array = data["retrieval_description_representation_array"] + retrieval_molecule_representation_array = data["retrieval_molecule_representation_array"] + else: + retrieval_description_representation_array, retrieval_molecule_representation_array = extract_retrieval_representation(retrieval_dataloader) + np.savez(retrieval_datapath, retrieval_description_representation_array=retrieval_description_representation_array, retrieval_molecule_representation_array=retrieval_molecule_representation_array) + retrieval_description_representation_dataset = RetrievalDataset(retrieval_description_representation_array) + retrieval_description_representation_dataloader = DataLoader(retrieval_description_representation_dataset, batch_size=512, shuffle=False, num_workers=args.num_workers) + retrieval_molecule_representation_dataset = RetrievalDataset(retrieval_molecule_representation_array) + retrieval_molecule_representation_dataloader = DataLoader(retrieval_molecule_representation_dataset, batch_size=512, shuffle=False, num_workers=args.num_workers) + + initial_test_acc_list = eval_epoch(full_dataloader) + print('Initial', initial_test_acc_list) + + row = ", ".join(["{:.4f}".format(x * 100) for x in initial_test_acc_list]) + print("initial results,", row) + \ No newline at end of file diff --git 
a/scripts/downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py b/scripts/downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py new file mode 100644 index 0000000..906be3c --- /dev/null +++ b/scripts/downstream_02_molecule_edit_step_01_MoleculeSTM_Space_Alignment.py @@ -0,0 +1,281 @@ +import argparse +import os +import numpy as np +from tqdm import tqdm +import time + +import torch +import torch.nn as nn +from torch import optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from MoleculeSTM.utils import get_molecule_repr_MoleculeSTM +from MoleculeSTM.models import MLP +from MoleculeSTM.downstream_molecule_edit_utils import load_molecule_models +from MoleculeSTM.utils import freeze_network +from MoleculeSTM.datasets import ZINC250K_Dataset_SMILES, ZINC250K_Dataset_Graph + + +def cycle_index(num, shift): + arr = torch.arange(num) + shift + arr[-shift:] = torch.arange(shift) + return arr + + +def do_CL(X, Y, args): + if args.normalize: + X = F.normalize(X, dim=-1) + Y = F.normalize(Y, dim=-1) + + if args.SSL_loss == 'EBM_NCE': + criterion = nn.BCEWithLogitsLoss() + neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0) + neg_X = X.repeat((args.CL_neg_samples, 1)) + + pred_pos = torch.sum(X * Y, dim=1) / args.T + pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T + + loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device)) + loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device)) + SSL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples) + + SSL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \ + (len(pred_pos) + len(pred_neg)) + SSL_acc = SSL_acc.detach().cpu().item() + + elif args.SSL_loss == 'InfoNCE': + criterion = nn.CrossEntropyLoss() + B = X.size()[0] + logits = torch.mm(X, Y.transpose(1, 0)) # B*B + logits = torch.div(logits, args.T) + labels = torch.arange(B).long().to(logits.device) # B*1 + + SSL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + SSL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B + + elif args.SSL_loss == 'RR': + criterion = nn.MSELoss() + SSL_loss = criterion(X, Y) + SSL_acc = 0 + + else: + raise Exception + + return SSL_loss, SSL_acc + +def mean_pooling(token_embeddings, attention_mask): + attention_mask = ~attention_mask + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def get_molecule_repr_generation(molecule_data, molecule_model, molecule_type="MegaMolBART", MegaMolBART_wrapper=None): + if molecule_type == "MegaMolBART": + embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model, molecule_data) # [pad, B, d], [pad, B] + molecule_repr = mean_pooling(embedding, pad_mask) + else: + molecule_repr, _ = molecule_model(molecule_data) + return molecule_repr + + +def save_model(save_best, epoch=None): + if args.output_model_dir is not None: + if save_best: + global optimal_loss + print("save model with loss: {:.5f}".format(optimal_loss)) + model_file = "model.pth" + + elif epoch is None: + model_file = "model_final.pth" + + else: + model_file = "model_{}.pth".format(epoch) + + saved_file_path = os.path.join(args.output_model_dir, "generation2MoleculeSTM_{}".format(model_file)) + torch.save(generation2MoleculeSTM.state_dict(), saved_file_path) + + saved_file_path = os.path.join(args.output_model_dir, "MoleculeSTM2generation_{}".format(model_file)) + torch.save(MoleculeSTM2generation.state_dict(), saved_file_path) + return + + +def train(epoch): + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + + start_time = time.time() + accum_loss, accum_acc = 0, 0 + for batch in L: + if args.MoleculeSTM_molecule_type == "SMILES": + SMILES_list = batch + else: + SMILES_list, graph = batch + graph = graph.to(device) + + if args.MoleculeSTM_molecule_type == "SMILES": + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + SMILES_list, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type=args.MoleculeSTM_molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper + ) + molecule_repr_MoleculeSTM2generation = MoleculeSTM2generation(molecule_repr_MoleculeSTM) + + else: + molecule_repr_MoleculeSTM = get_molecule_repr_MoleculeSTM( + graph, molecule_model=molecule_model_MoleculeSTM, mol2latent=mol2latent_MoleculeSTM, + molecule_type=args.MoleculeSTM_molecule_type, MegaMolBART_wrapper=None + ) + molecule_repr_MoleculeSTM2generation = MoleculeSTM2generation(molecule_repr_MoleculeSTM) + + if args.generation_model == "MegaMolBART": + molecule_repr_generation = get_molecule_repr_generation( + SMILES_list, molecule_model=molecule_model_generation, + molecule_type="MegaMolBART", MegaMolBART_wrapper=MegaMolBART_wrapper + ) + else: # for HierVAE + hiervae_data_list = MolGraph.tensorize(SMILES_list, vocab, avocab) + molecule_repr_generation = molecule_model_generation.forward_MoleculeSTM(hiervae_data_list) + molecule_repr_generation2MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + + loss_01, acc_01 = do_CL(molecule_repr_generation, molecule_repr_MoleculeSTM2generation, args) + loss_02, acc_02 = do_CL(molecule_repr_MoleculeSTM, molecule_repr_generation2MoleculeSTM, args) + loss = (loss_01 + loss_02) / 2 + acc = (acc_01 + acc_02) / 2 + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + accum_loss += loss.item() + accum_acc += acc + 
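# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original diff: a minimal, self-contained
# version of the symmetric contrastive step used in train() above. Two frozen
# spaces (the generation latent and the MoleculeSTM latent) are projected into
# each other's space and scored with the InfoNCE branch of do_CL (temperature T,
# matching pairs on the diagonal). Dimensions and tensors are toy values.
import torch
import torch.nn as nn

def info_nce(X, Y, temperature=0.1):
    logits = X @ Y.t() / temperature                       # [B, B], positives on the diagonal
    labels = torch.arange(X.size(0), device=X.device)
    return nn.functional.cross_entropy(logits, labels)

B, d_gen, d_stm = 4, 256, 256
gen_repr, stm_repr = torch.randn(B, d_gen), torch.randn(B, d_stm)
gen2stm, stm2gen = nn.Linear(d_gen, d_stm), nn.Linear(d_stm, d_gen)   # stand-ins for the two MLP projectors
loss = (info_nce(gen_repr, stm2gen(stm_repr)) + info_nce(stm_repr, gen2stm(gen_repr))) / 2
print("toy symmetric alignment loss:", loss.item())
# ---------------------------------------------------------------------------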
+ accum_loss /= len(L) + accum_acc /= len(L) + + global optimal_loss + temp_loss = accum_loss + if temp_loss < optimal_loss: + optimal_loss = temp_loss + save_model(save_best=True, epoch=epoch) + print("SSL Loss: {:.5f}\tSSL Acc: {:.5f}\tTime: {:.5f}".format(accum_loss, accum_acc, time.time() - start_time)) + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--dataset", type=str, default="ZINC250K") + parser.add_argument("--MoleculeSTM_molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + parser.add_argument("--output_model_dir", type=str, default=None) + + ########## for MoleculeSTM ########## + parser.add_argument("--MoleculeSTM_model_dir", type=str, default="../../pretrained_model") + parser.add_argument("--SSL_emb_dim", type=int, default=256) + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for generation ########## + parser.add_argument('--generation_model', type=str, default="MegaMolBART", choices=["MegaMolBART"]) + + ######### for MegaMolBART ########## + parser.add_argument("--MegaMolBART_generation_model_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for optimization ########## + parser.add_argument("--batch_size", type=int, default=256) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--decay", type=float, default=0) + parser.add_argument("--generation_lr", type=float, default=1e-2) + parser.add_argument("--MoleculeSTM_lr", type=float, default=1e-2) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE", "RR"]) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument('--use_normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + args = parser.parse_args() + print(args) + + if args.generation_model == "MegaMolBART": + if args.MoleculeSTM_molecule_type == "SMILES": + if args.dataset == "ZINC250K": + dataset_root = os.path.join(args.dataspace_path, "ZINC250K_data") + dataset = ZINC250K_Dataset_SMILES(dataset_root) + elif args.dataset == "ZINC250K1K": + dataset_root = os.path.join(args.dataspace_path, "ZINC250K_data") + dataset = ZINC250K_Dataset_SMILES(dataset_root, 1000) + elif args.dataset == "ZINC250K10K": + dataset_root = os.path.join(args.dataspace_path, "ZINC250K_data") + dataset = ZINC250K_Dataset_SMILES(dataset_root, 10000) + else: + raise Exception + dataloader_class = torch_DataLoader + else: + if args.dataset == "ZINC250K": + dataset_root = os.path.join(args.dataspace_path, "ZINC250K_data") + dataset = ZINC250K_Dataset_Graph(dataset_root) + elif args.dataset 
== "ZINC250K1K": + dataset_root = os.path.join(args.dataspace_path, "ZINC250K_data") + dataset = ZINC250K_Dataset_Graph(dataset_root, 1000) + elif args.dataset == "ZINC250K10K": + dataset_root = os.path.join(args.dataspace_path, "ZINC250K_data") + dataset = ZINC250K_Dataset_Graph(dataset_root, 10000) + else: + raise Exception + dataloader_class = pyg_DataLoader + else: + raise NotImplementedError + + MegaMolBART_wrapper, molecule_model_generation, molecule_dim_generation, \ + molecule_model_MoleculeSTM, mol2latent_MoleculeSTM, molecule_dim_MoleculeSTM = load_molecule_models(args) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + molecule_model_generation = molecule_model_generation.to(device) + molecule_model_MoleculeSTM = molecule_model_MoleculeSTM.to(device) + mol2latent_MoleculeSTM = mol2latent_MoleculeSTM.to(device) + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + freeze_network(molecule_model_generation) + freeze_network(mol2latent_MoleculeSTM) + freeze_network(molecule_model_MoleculeSTM) + molecule_model_generation.eval() + mol2latent_MoleculeSTM.eval() + molecule_model_MoleculeSTM.eval() + + dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + + generation2MoleculeSTM = MLP(molecule_dim_generation, [molecule_dim_MoleculeSTM, molecule_dim_MoleculeSTM]).to(device) + MoleculeSTM2generation = MLP(molecule_dim_MoleculeSTM, [molecule_dim_generation, molecule_dim_generation]).to(device) + + model_param_group = [ + {"params": generation2MoleculeSTM.parameters(), "lr": args.generation_lr}, + {"params": MoleculeSTM2generation.parameters(), "lr": args.MoleculeSTM_lr}, + ] + optimizer = optim.Adam(model_param_group, weight_decay=args.decay) + optimal_loss = 1e10 + + for e in range(1, args.epochs+1): + print("Epoch {}".format(e)) + train(e) diff --git a/scripts/downstream_02_molecule_edit_step_02_GA.py b/scripts/downstream_02_molecule_edit_step_02_GA.py new file mode 100644 index 0000000..8742a0b --- /dev/null +++ b/scripts/downstream_02_molecule_edit_step_02_GA.py @@ -0,0 +1,111 @@ +import argparse +from curses import tparm +import numpy as np +import random +import os + +import torch +from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, evaluate_SMILES_list +from rdkit import Chem +import MoleculeSTM.models.GA.mutate as mu + + +def check_edit(SMILES, text): + first_and_second_SMILES_list = [] + + first_and_second_SMILES_list.append(SMILES) + first_and_second_SMILES_list.append(SMILES) + + alpha_list = [1] + result_SMILES_list_one_pair, result_eval_list_one_pair = [], [] + mol = Chem.MolFromSmiles(SMILES) + + for alpha in alpha_list: + print("alpha: {}".format(alpha)) + current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]] + + mutated_mol = mol + for _ in range(args.mutation_step): + try: + while True: + mutated_mol = mu.mutate(mutated_mol, args.mutation_rate) + if mutated_mol is not None: + break + except: + mutated_mol = mol + + generated_SMILES = Chem.MolToSmiles(mutated_mol) + current_SMILES_list.append(generated_SMILES) + result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(alpha)]) + + current_result_list = evaluate_SMILES_list(current_SMILES_list, text) + 
result_eval_list_one_pair.append(current_result_list) + print() + + result_eval_list_one_pair = np.array(result_eval_list_one_pair) + result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True) + print("result_eval_list_one_pair\n", result_eval_list_one_pair) + return result_SMILES_list_one_pair, result_eval_list_one_pair + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--input_description", type=str, default=None) + parser.add_argument("--input_description_id", type=int, default=None) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--input_SMILES_file", type=str, default="../data/Editing_data/single_multi_property_SMILES.txt") + parser.add_argument("--output_model_dir", type=str, default=None) + parser.add_argument("--variance", type=float, default=1) + parser.add_argument("--mutation_rate", type=float, default=1) + parser.add_argument("--mutation_step", type=int, default=1) + + args = parser.parse_args() + + print(args) + + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + random.seed(args.seed) + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + print("\n\n\nstart editing\n\n\n") + + source_SMILES_list = get_SMILES_list(args) + description_list = get_description_list(args) + + for description in description_list: + print("===== for description {} =====".format(description)) + result_SMILES_list, result_acc_list = [], [] + for SMILES in source_SMILES_list: + print("===== for SMILES {} =====".format(SMILES)) + result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description) + result_SMILES_list.extend(result_SMILES_list_) + result_acc_list.append(result_acc_list_) + print("\n\n\n") + + result_acc_list = np.concatenate(result_acc_list, axis=0) + result_acc_list = np.sum(result_acc_list, axis=0) + result_acc_list = 100. 
* result_acc_list / len(source_SMILES_list) + result_acc_row = '\t'.join(['{}'.format(x) for x in result_acc_list]) + print("===== Accuracy =====\t{}".format(result_acc_row)) + + if args.output_model_dir is not None: + saver_file = os.path.join(args.output_model_dir, "edited_SMILES.tsv") + with open(saver_file, 'a') as f: + for row in result_SMILES_list: + row = "\t".join(row) + print(row, file=f) + + saver_file = os.path.join(args.output_model_dir, "accuracy") + np.savez(saver_file, result_acc_list) diff --git a/scripts/downstream_02_molecule_edit_step_02_High_Variance.py b/scripts/downstream_02_molecule_edit_step_02_High_Variance.py new file mode 100644 index 0000000..9b87d81 --- /dev/null +++ b/scripts/downstream_02_molecule_edit_step_02_High_Variance.py @@ -0,0 +1,148 @@ +import argparse +import math +import numpy as np +import os + +import torch +from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, evaluate_SMILES_list +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + return initial_lr * lr_ramp + + +def mean_pooling(token_embeddings, attention_mask): + attention_mask = ~attention_mask + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def get_molecule_repr_generation(molecule_data, molecule_model, molecule_type="SMILES", MegaMolBART_wrapper=None): + if molecule_type == "SMILES": + embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model, molecule_data) # [pad, B, d], [pad, B] + molecule_repr = mean_pooling(embedding, pad_mask) + else: + molecule_repr, _ = molecule_model(molecule_data) + return molecule_repr + + +def check_edit(SMILES, text): + first_and_second_SMILES_list = [] + + SMILES_list = [SMILES] + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding(SMILES_list) # [pad, B, d], [pad, B] + first_and_second_SMILES_list.append(SMILES) + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + first_and_second_SMILES_list.append(generated_mols[0]) + + alpha_list = [ + 1.0, 1.5, 2.0, 2.5, 3.0 + ] + result_SMILES_list_one_pair, result_eval_list_one_pair = [], [] + + for alpha in alpha_list: + print("alpha: {}".format(alpha)) + current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]] + + latent = latent_code_init + alpha * direction / len(latent_code_init) + pad_mask = pad_mask_init + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent], pad_mask.bool().cuda(), k=1, sanitize=True) + current_SMILES_list.append(generated_mols[0]) + result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(alpha)]) + + current_result_list = evaluate_SMILES_list(current_SMILES_list, text) + result_eval_list_one_pair.append(current_result_list) + print() + + result_eval_list_one_pair = np.array(result_eval_list_one_pair) + result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True) + print("result_eval_list_one_pair\n", result_eval_list_one_pair) + return result_SMILES_list_one_pair, 
result_eval_list_one_pair + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--input_description", type=str, default=None) + parser.add_argument("--input_description_id", type=int, default=None) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--input_SMILES_file", type=str, default="../data/Editing_data/single_multi_property_SMILES.txt") + parser.add_argument("--output_model_dir", type=str, default=None) + + ########## for generation ########## + parser.add_argument("--generation_model_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + args = parser.parse_args() + + print(args) + MegaMolBART_wrapper = MegaMolBART(input_dir=args.generation_model_dir, output_dir=None) + molecule_model_generation = MegaMolBART_wrapper.model + print("Loading from pretrained MegaMolBART ({}).".format(args.generation_model_dir)) + molecule_dim_generation = 256 + + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + print("\n\n\nstart editing\n\n\n") + + source_SMILES_list = get_SMILES_list(args) + description_list = get_description_list(args) + + mol_repr = get_molecule_repr_generation( + source_SMILES_list, molecule_model=molecule_model_generation, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper + ) + mol_repr = mol_repr.detach().cpu().numpy() + + # calculate variance + var_array = np.var(mol_repr, axis=0) + assert len(var_array) == mol_repr.shape[1] + var_index = np.argsort(var_array) + sampled_direction_idx = var_index[0] + direction = torch.zeros((mol_repr.shape[1])) + direction[sampled_direction_idx] = 1 + direction = direction.to(device) + + for description in description_list: + print("===== for description {} =====".format(description)) + result_SMILES_list, result_acc_list = [], [] + for SMILES in source_SMILES_list: + print("===== for SMILES {} =====".format(SMILES)) + result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description) + result_SMILES_list.extend(result_SMILES_list_) + result_acc_list.append(result_acc_list_) + print("\n\n\n") + + result_acc_list = np.concatenate(result_acc_list, axis=0) + result_acc_list = np.sum(result_acc_list, axis=0) + result_acc_list = 100. 
* result_acc_list / len(source_SMILES_list) + result_acc_row = '\t'.join(['{}'.format(x) for x in result_acc_list]) + print("===== Accuracy =====\t{}".format(result_acc_row)) + + if args.output_model_dir is not None: + saver_file = os.path.join(args.output_model_dir, "edited_SMILES.tsv") + with open(saver_file, 'a') as f: + for row in result_SMILES_list: + row = "\t".join(row) + print(row, file=f) + + saver_file = os.path.join(args.output_model_dir, "accuracy") + np.savez(saver_file, result_acc_list) diff --git a/scripts/downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py b/scripts/downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py new file mode 100644 index 0000000..fa3a703 --- /dev/null +++ b/scripts/downstream_02_molecule_edit_step_02_MoleculeSTM_Latent_Optimization.py @@ -0,0 +1,201 @@ +import argparse +import math +import numpy as np +import os + +import torch +from torch import optim +import torch.nn.functional as F +from tqdm import tqdm +from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, load_language_molecule_and_edit_models, clip_loss_for_edit, evaluate_SMILES_list +from MoleculeSTM.utils import prepare_text_tokens + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + return initial_lr * lr_ramp + + +def mean_pooling(token_embeddings, attention_mask): + attention_mask = ~attention_mask + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def check_edit(SMILES, text): + text_list = [text] + text_tokens_ids, text_masks = prepare_text_tokens( + device=device, description=text_list, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + text_output = text_model(input_ids=text_tokens_ids, attention_mask=text_masks) + text_repr = text_output["pooler_output"] + text_repr = text2latent(text_repr) + + first_and_second_SMILES_list = [] + + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding([SMILES]) # [pad, B, d], [pad, B] + first_and_second_SMILES_list.append(SMILES) + + regenerated_mols = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + first_and_second_SMILES_list.append(regenerated_mols[0]) + + l2_lambda_list = [ + 1e1, 1e0, 1e-1, 1e-2, 1e-3 + ] + result_SMILES_list_one_pair, result_eval_list_one_pair = [], [] + + if args.use_noise_for_init: + print("Use random noise for init") + random_noise = torch.randn(latent_code_init.size()).to(device) + + for l2_lambda in l2_lambda_list: + print("l2 lambda: {}".format(l2_lambda)) + current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]] + if args.use_noise_for_init: + print("Use random noise for init") + latent = latent_code_init.detach().clone() + random_noise + else: + print("No random noise for init") + latent = latent_code_init.detach().clone() + pad_mask = pad_mask_init.detach().clone() + latent.requires_grad = True + optimizer = optim.Adam([latent], lr=args.lr) + + if args.verbose: + L = tqdm(range(args.epochs)) + else: + L = range(args.epochs) + + for i in L: + t = i / args.epochs + lr = get_lr(t, args.lr) + optimizer.param_groups[0]["lr"] = lr + + 
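# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original diff: get_lr above (copied here
# verbatim so this snippet runs on its own) implements a short warm-up followed
# by a cosine ramp-down. The loop just prints the learning rate at a few points
# of a toy 100-step run to show the shape of the schedule.
import math

def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp

for step, epochs in [(0, 100), (5, 100), (50, 100), (90, 100), (99, 100)]:
    print("t={:.2f} lr={:.4f}".format(step / epochs, get_lr(step / epochs, 0.1)))
# ---------------------------------------------------------------------------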
molecule_repr_generation = mean_pooling(latent, pad_mask) # [B, d] + if args.normalize: + molecule_repr_generation = F.normalize(molecule_repr_generation, dim=-1) + molecule_repr_MoleculeSTM = generation2MoleculeSTM(molecule_repr_generation) + + clip_loss_ = clip_loss_for_edit(molecule_repr_MoleculeSTM, text_repr) + l2_loss_ = l2_lambda * ((latent_code_init - latent) ** 2).mean() + + loss = clip_loss_ + l2_loss_ + + optimizer.zero_grad() + loss.backward(retain_graph=True) + optimizer.step() + print("clip loss: {:.5f}\tL2 loss: {:.5f}".format(clip_loss_.item(), l2_loss_.item())) + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent], pad_mask.bool().cuda(), k=1, sanitize=True) + current_SMILES_list.append(generated_mols[0]) + result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(l2_lambda)]) + + current_result_list = evaluate_SMILES_list(current_SMILES_list, text) + result_eval_list_one_pair.append(current_result_list) + print() + + result_eval_list_one_pair = np.array(result_eval_list_one_pair) + result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True) + print("result_eval_list_one_pair\n", result_eval_list_one_pair) + return result_SMILES_list_one_pair, result_eval_list_one_pair + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--input_description", type=str, default=None) + parser.add_argument("--input_description_id", type=int, default=None) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--input_SMILES_file", type=str, default="../data/Editing_data/single_multi_property_SMILES.txt") + parser.add_argument("--output_model_dir", type=str, default=None) + parser.add_argument("--use_noise_for_init", dest="use_noise_for_init", action="store_true") + parser.add_argument("--no_noise_for_init", dest="use_noise_for_init", action="store_false") + parser.set_defaults(use_noise_for_init=False) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for MoleculeSTM ########## + parser.add_argument("--MoleculeSTM_model_dir", type=str, default="../../pretrained_model_Raw") + parser.add_argument("--MoleculeSTM_molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for MegaMolBART ########## + parser.add_argument("--MegaMolBART_generation_model_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for MoleculeSTM and generation projection ########## + parser.add_argument("--language_edit_model_dir", type=str, default="edit_temp/EBM_NCE") + + ########## for editing ########## + parser.add_argument("--lr_rampup", type=float, default=0.05) + parser.add_argument("--lr", type=float, default=0.1) + parser.add_argument("--epochs", type=int, default=100) + args = parser.parse_args() + + print(args) + + text_model, text_tokenizer, text_dim, molecule_model, 
MegaMolBART_wrapper, molecule_dim, \ + text2latent, mol2latent, generation2MoleculeSTM, MoleculeSTM2generation = load_language_molecule_and_edit_models(args) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + text_model = text_model.to(device) + molecule_model = molecule_model.to(device) + text2latent = text2latent.to(device) + mol2latent = mol2latent.to(device) + generation2MoleculeSTM.to(device) + MoleculeSTM2generation.to(device) + text_model.eval() + molecule_model.eval() + text2latent.eval() + mol2latent.eval() + generation2MoleculeSTM.eval() + MoleculeSTM2generation.eval() + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + print("\n\n\nstart editing\n\n\n") + + source_SMILES_list = get_SMILES_list(args) + description_list = get_description_list(args) + + for description in description_list: + print("===== for description {} =====".format(description)) + result_SMILES_list, result_acc_list = [], [] + for SMILES in source_SMILES_list: + print("===== for SMILES {} =====".format(SMILES)) + result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description) + result_SMILES_list.extend(result_SMILES_list_) + result_acc_list.append(result_acc_list_) + print("\n\n\n") + + result_acc_list = np.concatenate(result_acc_list, axis=0) + result_acc_list = np.sum(result_acc_list, axis=0) + result_acc_list = 100. * result_acc_list / len(source_SMILES_list) + result_acc_row = '\t'.join(['{}'.format(x) for x in result_acc_list]) + print("===== Accuracy =====\t{}".format(result_acc_row)) + + if args.output_model_dir is not None: + saver_file = os.path.join(args.output_model_dir, "edited_SMILES.tsv") + with open(saver_file, 'a') as f: + for row in result_SMILES_list: + row = "\t".join(row) + print(row, file=f) + + saver_file = os.path.join(args.output_model_dir, "accuracy") + np.savez(saver_file, result_acc_list) diff --git a/scripts/downstream_02_molecule_edit_step_02_PCA.py b/scripts/downstream_02_molecule_edit_step_02_PCA.py new file mode 100644 index 0000000..ee769fc --- /dev/null +++ b/scripts/downstream_02_molecule_edit_step_02_PCA.py @@ -0,0 +1,148 @@ +import argparse +import math +import numpy as np +import os + +import torch +import torch.nn.functional as F +from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, evaluate_SMILES_list +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART + +from sklearn.decomposition import PCA + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + return initial_lr * lr_ramp + + +def mean_pooling(token_embeddings, attention_mask): + attention_mask = ~attention_mask + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def get_molecule_repr_generation(molecule_data, molecule_model, molecule_type="SMILES", MegaMolBART_wrapper=None): + if molecule_type == "SMILES": + embedding, pad_mask = MegaMolBART_wrapper.smileslist2embedding_model_given(molecule_model, molecule_data) 
# [pad, B, d], [pad, B] + molecule_repr = mean_pooling(embedding, pad_mask) + else: + molecule_repr, _ = molecule_model(molecule_data) + return molecule_repr + + +def check_edit(SMILES, text): + first_and_second_SMILES_list = [] + + SMILES_list = [SMILES] + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding(SMILES_list) # [pad, B, d], [pad, B] + first_and_second_SMILES_list.append(SMILES) + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + first_and_second_SMILES_list.append(generated_mols[0]) + + alpha_list = [ + 1.0, 1.5, 2.0, 2.5, 3.0 + ] + result_SMILES_list_one_pair, result_eval_list_one_pair = [], [] + + for alpha in alpha_list: + print("alpha: {}".format(alpha)) + current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]] + + latent = latent_code_init + alpha * latent_repr_pca / len(latent_code_init) + pad_mask = pad_mask_init + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent], pad_mask.bool().cuda(), k=1, sanitize=True) + current_SMILES_list.append(generated_mols[0]) + result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(alpha)]) + + print("current_SMILES_list", current_SMILES_list) + print("result_SMILES_list_one_pair", result_SMILES_list_one_pair) + + current_result_list = evaluate_SMILES_list(current_SMILES_list, text) + result_eval_list_one_pair.append(current_result_list) + print() + + result_eval_list_one_pair = np.array(result_eval_list_one_pair) + result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True) + print("result_eval_list_one_pair\n", result_eval_list_one_pair) + return result_SMILES_list_one_pair, result_eval_list_one_pair + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--input_description", type=str, default=None) + parser.add_argument("--input_description_id", type=int, default=None) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--input_SMILES_file", type=str, default="../data/Editing_data/single_multi_property_SMILES.txt") + parser.add_argument("--output_model_dir", type=str, default=None) + + ########## for generation ########## + parser.add_argument("--generation_model_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + args = parser.parse_args() + + print(args) + MegaMolBART_wrapper = MegaMolBART(input_dir=args.generation_model_dir, output_dir=None) + molecule_model_generation = MegaMolBART_wrapper.model + print("Loading from pretrained MegaMolBART ({}).".format(args.generation_model_dir)) + molecule_dim_generation = 256 + + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + print("\n\n\nstart editing\n\n\n") + + source_SMILES_list = get_SMILES_list(args) + description_list = get_description_list(args) + + mol_repr = get_molecule_repr_generation( + source_SMILES_list, molecule_model=molecule_model_generation, + molecule_type="SMILES", 
MegaMolBART_wrapper=MegaMolBART_wrapper + ) + mol_repr = mol_repr.detach().cpu().numpy() + latent_repr_pca = PCA(n_components=1).fit_transform(mol_repr.transpose())[:, 0] + latent_repr_pca = torch.tensor(latent_repr_pca).to(device) + latent_repr_pca = F.normalize(latent_repr_pca, dim=-1) + + for description in description_list: + print("===== for description {} =====".format(description)) + result_SMILES_list, result_acc_list = [], [] + for SMILES in source_SMILES_list: + print("===== for SMILES {} =====".format(SMILES)) + result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description) + result_SMILES_list.extend(result_SMILES_list_) + result_acc_list.append(result_acc_list_) + print("\n\n\n") + + result_acc_list = np.concatenate(result_acc_list, axis=0) + result_acc_list = np.sum(result_acc_list, axis=0) + result_acc_list = 100. * result_acc_list / len(source_SMILES_list) + result_acc_row = '\t'.join(['{}'.format(x) for x in result_acc_list]) + print("===== Accuracy =====\t{}".format(result_acc_row)) + + if args.output_model_dir is not None: + saver_file = os.path.join(args.output_model_dir, "edited_SMILES.tsv") + with open(saver_file, 'a') as f: + for row in result_SMILES_list: + row = "\t".join(row) + print(row, file=f) + + saver_file = os.path.join(args.output_model_dir, "accuracy") + np.savez(saver_file, result_acc_list) diff --git a/scripts/downstream_02_molecule_edit_step_02_Random_Perturbation.py b/scripts/downstream_02_molecule_edit_step_02_Random_Perturbation.py new file mode 100644 index 0000000..916bc58 --- /dev/null +++ b/scripts/downstream_02_molecule_edit_step_02_Random_Perturbation.py @@ -0,0 +1,131 @@ +import argparse +import math +import numpy as np +import os + +import torch +import torch.nn.functional as F +from MoleculeSTM.downstream_molecule_edit_utils import get_SMILES_list, get_description_list, evaluate_SMILES_list +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART + + +def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): + lr_ramp = min(1, (1 - t) / rampdown) + lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) + lr_ramp = lr_ramp * min(1, t / rampup) + return initial_lr * lr_ramp + + +def mean_pooling(token_embeddings, attention_mask): + attention_mask = ~attention_mask + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() # [pad, B, d] + sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 0) # [B, d] + sum_mask = torch.clamp(input_mask_expanded.sum(0), min=1e-9) # [B, d] + return sum_embeddings / sum_mask + + +def check_edit(SMILES, text): + first_and_second_SMILES_list = [] + + SMILES_list = [SMILES] + latent_code_init, pad_mask_init = MegaMolBART_wrapper.smileslist2embedding(SMILES_list) # [pad, B, d], [pad, B] + first_and_second_SMILES_list.append(SMILES) + + generated_mols = MegaMolBART_wrapper.inverse_transform([latent_code_init], pad_mask_init.bool().cuda(), k=1, sanitize=True) + first_and_second_SMILES_list.append(generated_mols[0]) + + alpha_list = [ + 1.0, 1.5, 2.0, 2.5, 3.0 + ] + result_SMILES_list_one_pair, result_eval_list_one_pair = [], [] + + print("Use random noise for init") + random_noise = args.variance * torch.randn(latent_code_init.size()).to(device) + random_noise = F.normalize(random_noise, dim=-1) + + for alpha in alpha_list: + print("alpha: {}".format(alpha)) + current_SMILES_list = [first_and_second_SMILES_list[0]] + [first_and_second_SMILES_list[1]] + + latent = latent_code_init + alpha * random_noise / len(latent_code_init) + pad_mask = pad_mask_init + + 
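# Decode the randomly perturbed latent back into a SMILES string (top-1 decode, with RDKit sanitization). +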
generated_mols = MegaMolBART_wrapper.inverse_transform([latent], pad_mask.bool().cuda(), k=1, sanitize=True) + current_SMILES_list.append(generated_mols[0]) + result_SMILES_list_one_pair.append([text] + current_SMILES_list + ['{}'.format(alpha)]) + + current_result_list = evaluate_SMILES_list(current_SMILES_list, text) + result_eval_list_one_pair.append(current_result_list) + print() + + result_eval_list_one_pair = np.array(result_eval_list_one_pair) + result_eval_list_one_pair = np.any(result_eval_list_one_pair, axis=0, keepdims=True) + print("result_eval_list_one_pair\n", result_eval_list_one_pair) + return result_SMILES_list_one_pair, result_eval_list_one_pair + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--verbose", type=int, default=1) + + ########## for editing ########## + parser.add_argument("--input_description", type=str, default=None) + parser.add_argument("--input_description_id", type=int, default=None) + parser.add_argument("--input_SMILES", type=str, default=None) + parser.add_argument("--input_SMILES_file", type=str, default="../data/Editing_data/single_multi_property_SMILES.txt") + parser.add_argument("--output_model_dir", type=str, default=None) + parser.add_argument("--variance", type=float, default=1) + + ########## for generation ########## + parser.add_argument("--generation_model_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + args = parser.parse_args() + + print(args) + MegaMolBART_wrapper = MegaMolBART(input_dir=args.generation_model_dir, output_dir=None) + molecule_model_generation = MegaMolBART_wrapper.model + print("Loading from pretrained MegaMolBART ({}).".format(args.generation_model_dir)) + molecule_dim_generation = 256 + + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + np.random.seed(args.seed) + torch.random.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + print("\n\n\nstart editing\n\n\n") + + source_SMILES_list = get_SMILES_list(args) + description_list = get_description_list(args) + + for description in description_list: + print("===== for description {} =====".format(description)) + result_SMILES_list, result_acc_list = [], [] + for SMILES in source_SMILES_list: + print("===== for SMILES {} =====".format(SMILES)) + result_SMILES_list_, result_acc_list_ = check_edit(SMILES, description) + result_SMILES_list.extend(result_SMILES_list_) + result_acc_list.append(result_acc_list_) + print("\n\n\n") + + result_acc_list = np.concatenate(result_acc_list, axis=0) + result_acc_list = np.sum(result_acc_list, axis=0, keepdims=True) + result_acc_list = 100. 
* result_acc_list / len(source_SMILES_list) + print(description, result_acc_list) + result_acc_row = '\t'.join(['{}'.format(x) for x in result_acc_list]) + print("===== Accuracy =====\t{}".format(result_acc_row)) + + if args.output_model_dir is not None: + saver_file = os.path.join(args.output_model_dir, "edited_SMILES.tsv") + with open(saver_file, 'a') as f: + for row in result_SMILES_list: + row = "\t".join(row) + print(row, file=f) + + saver_file = os.path.join(args.output_model_dir, "accuracy") + np.savez(saver_file, result_acc_list) diff --git a/scripts/downstream_03_property_prediction.py b/scripts/downstream_03_property_prediction.py new file mode 100644 index 0000000..56c1687 --- /dev/null +++ b/scripts/downstream_03_property_prediction.py @@ -0,0 +1,432 @@ +import os +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from sklearn.metrics import accuracy_score, average_precision_score, roc_auc_score, mean_absolute_error, mean_squared_error + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader as torch_DataLoader +from torch_geometric.loader import DataLoader as pyg_DataLoader + +from MoleculeSTM.datasets import MoleculeNetSMILESDataset, MoleculeNetGraphDataset +from MoleculeSTM.splitters import scaffold_split +from MoleculeSTM.utils import get_num_task_and_type, get_molecule_repr_MoleculeSTM +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART +from MoleculeSTM.models import GNN, GNN_graphpred + + +def train_classification(model, device, loader, optimizer): + if args.training_mode == "fine_tuning": + model.train() + else: + model.eval() + linear_model.train() + total_loss = 0 + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.molecule_type == "SMILES": + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_molecule_repr_MoleculeSTM( + SMILES_list, mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + else: + batch = batch.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + batch, mol2latent=None, + molecule_type="Graph", molecule_model=model) + pred = linear_model(molecule_repr) + pred = pred.float() + y = batch.y.view(pred.shape).to(device).float() + + is_valid = y ** 2 > 0 + loss_mat = criterion(pred, (y + 1) / 2) + loss_mat = torch.where( + is_valid, loss_mat, + torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype)) + + optimizer.zero_grad() + loss = torch.sum(loss_mat) / torch.sum(is_valid) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + + return total_loss / len(loader) + + +@torch.no_grad() +def eval_classification(model, device, loader): + model.eval() + linear_model.eval() + y_true, y_scores = [], [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.molecule_type == "SMILES": + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_molecule_repr_MoleculeSTM( + SMILES_list, mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + else: + batch = batch.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + batch, mol2latent=None, + molecule_type="Graph", molecule_model=model) + pred = linear_model(molecule_repr) + pred = pred.float() + y = 
batch.y.view(pred.shape).to(device).float() + + y_true.append(y) + y_scores.append(pred) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_scores = torch.cat(y_scores, dim=0).cpu().numpy() + + roc_list = [] + for i in range(y_true.shape[1]): + # AUC is only defined when there is at least one positive data. + if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0: + is_valid = y_true[:, i] ** 2 > 0 + roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i])) + else: + print("{} is invalid".format(i)) + + if len(roc_list) < y_true.shape[1]: + print(len(roc_list)) + print("Some target is missing!") + print("Missing ratio: %f" %(1 - float(len(roc_list)) / y_true.shape[1])) + + return sum(roc_list) / len(roc_list), 0, y_true, y_scores + + +def train_regression(model, device, loader, optimizer): + if args.training_mode == "fine_tuning": + model.train() + else: + model.eval() + linear_model.train() + total_loss = 0 + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.molecule_type == "SMILES": + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_molecule_repr_MoleculeSTM( + SMILES_list, mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + else: + batch = batch.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + batch, mol2latent=None, + molecule_type="Graph", molecule_model=model) + pred = linear_model(molecule_repr) + pred = pred.float() + y = batch.y.view(pred.shape).to(device).float() + + loss = criterion(pred, y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + + return total_loss / len(loader) + + +@torch.no_grad() +def eval_regression(model, device, loader): + model.eval() + y_true, y_pred = [], [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.molecule_type == "SMILES": + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_molecule_repr_MoleculeSTM( + SMILES_list, mol2latent=None, + molecule_type="SMILES", MegaMolBART_wrapper=MegaMolBART_wrapper) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + else: + batch = batch.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + batch, mol2latent=None, + molecule_type="Graph", molecule_model=model) + pred = linear_model(molecule_repr) + pred = pred.float() + y = batch.y.view(pred.shape).to(device).float() + + y_true.append(y) + y_pred.append(pred) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_pred = torch.cat(y_pred, dim=0).cpu().numpy() + rmse = mean_squared_error(y_true, y_pred, squared=False) + mae = mean_absolute_error(y_true, y_pred) + return {'RMSE': rmse, 'MAE': mae}, y_true, y_pred + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--training_mode", type=str, default="fine_tuning", choices=["fine_tuning", "linear_probing"]) + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--dataset", type=str, default="bace") + parser.add_argument("--split", type=str, 
default="scaffold") + + ########## for optimization ########## + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--lr_scale", type=float, default=1) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--weight_decay", type=float, default=0) + parser.add_argument("--schedule", type=str, default="cycle") + parser.add_argument("--warm_up_steps", type=int, default=10) + + ########## for MegaMolBART ########## + parser.add_argument("--megamolbart_input_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for 2D GNN ########## + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for saver ########## + parser.add_argument("--eval_train", type=int, default=0) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + print("arguments\t", args) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + + num_tasks, task_mode = get_num_task_and_type(args.dataset) + dataset_folder = os.path.join(args.dataspace_path, "MoleculeNet_data", args.dataset) + + if args.molecule_type == "SMILES": + dataset = MoleculeNetSMILESDataset(dataset_folder) + dataloader_class = torch_DataLoader + use_pyg_dataset = False + else: + dataset = MoleculeNetGraphDataset(dataset_folder, args.dataset) + dataloader_class = pyg_DataLoader + use_pyg_dataset = True + + assert args.split == "scaffold" + print("split via scaffold") + smiles_list = pd.read_csv( + dataset_folder + "/processed/smiles.csv", header=None)[0].tolist() + train_dataset, valid_dataset, test_dataset = scaffold_split( + dataset, smiles_list, null_value=0, frac_train=0.8, + frac_valid=0.1, frac_test=0.1, pyg_dataset=use_pyg_dataset) + + train_loader = dataloader_class(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + val_loader = dataloader_class(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + test_loader = dataloader_class(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + if args.molecule_type == "SMILES": + if args.megamolbart_input_dir is not None: + # This is loading from the pretarined_MegaMolBART + # --megamolbart_input_dir=../../Datasets/pretrained_MegaMolBART/checkpoints + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=args.megamolbart_input_dir, output_dir=None) + print("Start from pretrained MegaMolBART using MLM.") + else: + # This is starting from scratch + MegaMolBART_wrapper = MegaMolBART(vocab_path=args.vocab_path, input_dir=None, output_dir=None) + print("Start from randomly initialized MegaMolBART.") + model = MegaMolBART_wrapper.model 
+ if args.input_model_path is not None: + print("Update MegaMolBART with pretrained MoleculeSTM. Loading from {}...".format(args.input_model_path)) + state_dict = torch.load(args.input_model_path, map_location='cpu') + model.load_state_dict(state_dict) + molecule_dim = 256 + else: + molecule_node_model = GNN( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, + JK=args.JK, drop_ratio=args.dropout_ratio, + gnn_type=args.gnn_type) + model = GNN_graphpred( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, + num_tasks=1, molecule_node_model=molecule_node_model) + molecule_dim = args.gnn_emb_dim + if args.input_model_path is not None: + if "GraphMVP" in args.input_model_path: + print("Start from pretrained model (GraphMVP) in {}.".format(args.input_model_path)) + model.from_pretrained(args.input_model_path) + else: + print("Start from pretrained model (MoleculeSTM) in {}.".format(args.input_model_path)) + state_dict = torch.load(args.input_model_path, map_location='cpu') + model.load_state_dict(state_dict) + else: + print("Start from randomly initialized GNN.") + + # Rewrite the seed by MegaMolBART + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + model = model.to(device) + linear_model = nn.Linear(molecule_dim, num_tasks).to(device) + + # set up optimizer + if args.training_mode == "fine_tuning": + model_param_group = [ + {"params": model.parameters()}, + {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale} + ] + else: + model_param_group = [ + {"params": linear_model.parameters(), 'lr': args.lr * args.lr_scale} + ] + optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay) + + if task_mode == "classification": + train_func = train_classification + eval_func = eval_classification + + train_roc_list, val_roc_list, test_roc_list = [], [], [] + train_acc_list, val_acc_list, test_acc_list = [], [], [] + best_val_roc, best_val_idx = -1, 0 + criterion = nn.BCEWithLogitsLoss(reduction="none") + + for epoch in range(1, args.epochs + 1): + loss_acc = train_func(model, device, train_loader, optimizer) + print("Epoch: {}\nLoss: {}".format(epoch, loss_acc)) + + if args.eval_train: + train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader) + else: + train_roc = train_acc = 0 + val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader) + test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader) + + train_roc_list.append(train_roc) + train_acc_list.append(train_acc) + val_roc_list.append(val_roc) + val_acc_list.append(val_acc) + test_roc_list.append(test_roc) + test_acc_list.append(test_acc) + print("train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc, val_roc, test_roc)) + print() + + if val_roc > best_val_roc: + best_val_roc = val_roc + best_val_idx = epoch - 1 + if args.output_model_dir is not None: + ##### save best model ##### + output_model_path = os.path.join(args.output_model_dir, "{}_model_best.pth".format(args.dataset)) + saved_model_dict = { + "model": model.state_dict() + } + torch.save(saved_model_dict, output_model_path) + + filename = os.path.join(args.output_model_dir, "{}_evaluation_best.pth".format(args.dataset)) + np.savez( + filename, val_target=val_target, val_pred=val_pred, + test_target=test_target, test_pred=test_pred) + + print("best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc_list[best_val_idx], 
val_roc_list[best_val_idx], test_roc_list[best_val_idx])) + + else: + train_func = train_regression + eval_func = eval_regression + criterion = torch.nn.MSELoss() + + train_result_list, val_result_list, test_result_list = [], [], [] + metric_list = ['RMSE', 'MAE'] + best_val_rmse, best_val_idx = 1e10, 0 + + for epoch in range(1, args.epochs + 1): + loss_acc = train_func(model, device, train_loader, optimizer) + print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc)) + + if args.eval_train: + train_result, train_target, train_pred = eval_func(model, device, train_loader) + else: + train_result = {'RMSE': 0, 'MAE': 0, 'R2': 0} + val_result, val_target, val_pred = eval_func(model, device, val_loader) + test_result, test_target, test_pred = eval_func(model, device, test_loader) + + train_result_list.append(train_result) + val_result_list.append(val_result) + test_result_list.append(test_result) + + for metric in metric_list: + print('{} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format(metric, train_result[metric], val_result[metric], test_result[metric])) + print() + + if val_result['RMSE'] < best_val_rmse: + best_val_rmse = val_result['RMSE'] + best_val_idx = epoch - 1 + if args.output_model_dir is not None: + ##### save best model ##### + output_model_path = os.path.join(args.output_model_dir, "{}_model_best.pth".format(args.dataset)) + saved_model_dict = { + 'model': model.state_dict() + } + torch.save(saved_model_dict, output_model_path) + + filename = os.path.join(args.output_model_dir, "{}_evaluation_best.pth".format(args.dataset)) + np.savez( + filename, val_target=val_target, val_pred=val_pred, + test_target=test_target, test_pred=test_pred) + + for metric in metric_list: + print('Best (RMSE), {} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format( + metric, train_result_list[best_val_idx][metric], val_result_list[best_val_idx][metric], test_result_list[best_val_idx][metric])) + + ##### save final model ##### + if args.output_model_dir is not None: + output_model_path = os.path.join(args.output_model_dir, '{}_model_final.pth'.format(args.dataset)) + saved_model_dict = { + 'model': model.state_dict() + } + torch.save(saved_model_dict, output_model_path) + \ No newline at end of file diff --git a/scripts/downstream_03_property_prediction_KV-PLM.py b/scripts/downstream_03_property_prediction_KV-PLM.py new file mode 100644 index 0000000..4c8bd6c --- /dev/null +++ b/scripts/downstream_03_property_prediction_KV-PLM.py @@ -0,0 +1,373 @@ +import os +import argparse +import numpy as np +import pandas as pd +from tqdm import tqdm +from sklearn.metrics import roc_auc_score, mean_absolute_error, mean_squared_error + +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader as torch_DataLoader + +from MoleculeSTM.datasets import MoleculeNetSMILESDataset +from MoleculeSTM.splitters import scaffold_split +from MoleculeSTM.utils import get_num_task_and_type +from MoleculeSTM.utils import prepare_text_tokens +from transformers import BertTokenizer, BertForPreTraining + + +class BigModel(nn.Module): + def __init__(self, main_model): + super(BigModel, self).__init__() + self.main_model = main_model + self.dropout = nn.Dropout(0.1) + + def forward(self, tok, att, cud=True): + typ = torch.zeros(tok.shape).long() + if cud: + typ = typ.cuda() + pooled_output = self.main_model(tok, token_type_ids=typ, attention_mask=att)['pooler_output'] + logits = self.dropout(pooled_output) + return logits + + +def get_text_and_SMILES_repr(text): + text_tokens_ids, text_masks = 
prepare_text_tokens( + device=device, description=text, tokenizer=tokenizer, max_seq_len=args.max_seq_len) + text_repr = model(text_tokens_ids, text_masks) + return text_repr + + +def train_classification(model, device, loader, optimizer): + if args.training_mode == "fine_tuning": + model.train() + else: + model.eval() + linear_model.train() + total_loss = 0 + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + + molecule_repr = get_text_and_SMILES_repr(SMILES_list) + + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + + is_valid = y ** 2 > 0 + loss_mat = criterion(pred, (y + 1) / 2) + loss_mat = torch.where( + is_valid, loss_mat, + torch.zeros(loss_mat.shape).to(device).to(loss_mat.dtype)) + + optimizer.zero_grad() + loss = torch.sum(loss_mat) / torch.sum(is_valid) + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + + return total_loss / len(loader) + + +@torch.no_grad() +def eval_classification(model, device, loader): + model.eval() + linear_model.eval() + y_true, y_scores = [], [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_text_and_SMILES_repr(SMILES_list) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + + y_true.append(y) + y_scores.append(pred) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_scores = torch.cat(y_scores, dim=0).cpu().numpy() + + roc_list = [] + for i in range(y_true.shape[1]): + # AUC is only defined when there is at least one positive data. + if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == -1) > 0: + is_valid = y_true[:, i] ** 2 > 0 + roc_list.append(roc_auc_score((y_true[is_valid, i] + 1) / 2, y_scores[is_valid, i])) + else: + print("{} is invalid".format(i)) + + if len(roc_list) < y_true.shape[1]: + print(len(roc_list)) + print("Some target is missing!") + print("Missing ratio: %f" %(1 - float(len(roc_list)) / y_true.shape[1])) + + return sum(roc_list) / len(roc_list), 0, y_true, y_scores + + +def train_regression(model, device, loader, optimizer): + if args.training_mode == "fine_tuning": + model.train() + else: + model.eval() + linear_model.train() + total_loss = 0 + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_text_and_SMILES_repr(SMILES_list) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + + loss = criterion(pred, y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.detach().item() + + return total_loss / len(loader) + + +@torch.no_grad() +def eval_regression(model, device, loader): + model.eval() + y_true, y_pred = [], [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + SMILES_list, y = batch + SMILES_list = list(SMILES_list) + molecule_repr = get_text_and_SMILES_repr(SMILES_list) + pred = linear_model(molecule_repr) + pred = pred.float() + y = y.to(device).float() + + y_true.append(y) + y_pred.append(pred) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_pred = torch.cat(y_pred, dim=0).cpu().numpy() + rmse = mean_squared_error(y_true, y_pred, squared=False) + mae = mean_absolute_error(y_true, y_pred) + return {'RMSE': rmse, 'MAE': mae}, y_true, y_pred + + 
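+# The entry point below loads SciBERT with the pretrained KV-PLM checkpoint (ckpt_ret01.pt), treats its 768-d pooled output as the molecule representation, and trains a linear head on it (or fine-tunes end-to-end) for the chosen MoleculeNet task.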
+if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--training_mode", type=str, default="fine_tuning", choices=["fine_tuning", "linear_probing"]) + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + + ########## for dataset and split ########## + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--dataset", type=str, default="bace") + parser.add_argument("--split", type=str, default="scaffold") + + ########## for optimization ########## + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--lr_scale", type=float, default=1) + parser.add_argument("--num_workers", type=int, default=1) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--weight_decay", type=int, default=0) + parser.add_argument("--schedule", type=str, default="cycle") + parser.add_argument("--warm_up_steps", type=int, default=10) + + ########## for BERT model ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for saver ########## + parser.add_argument("--eval_train", type=int, default=0) + parser.add_argument("--verbose", type=int, default=0) + + parser.add_argument("--input_model_dir", type=str, default=None, help="This is only for MegaMolBART.") + parser.add_argument("--input_model_path", type=str, default=None) + parser.add_argument("--output_model_dir", type=str, default=None) + + args = parser.parse_args() + print("arguments\t", args) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + print("using device", device) + + num_tasks, task_mode = get_num_task_and_type(args.dataset) + dataset_folder = os.path.join(args.dataspace_path, "MoleculeNet_data", args.dataset) + + dataset = MoleculeNetSMILESDataset(dataset_folder) + dataloader_class = torch_DataLoader + use_pyg_dataset = False + + assert args.split == "scaffold" + print("split via scaffold") + smiles_list = pd.read_csv( + dataset_folder + "/processed/smiles.csv", header=None)[0].tolist() + train_dataset, valid_dataset, test_dataset = scaffold_split( + dataset, smiles_list, null_value=0, frac_train=0.8, + frac_valid=0.1, frac_test=0.1, pyg_dataset=use_pyg_dataset) + + train_loader = dataloader_class(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + val_loader = dataloader_class(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + test_loader = dataloader_class(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) + + tokenizer = BertTokenizer.from_pretrained('allenai/scibert_scivocab_uncased') + + bert_model0 = BertForPreTraining.from_pretrained('allenai/scibert_scivocab_uncased') + model = BigModel(bert_model0.bert) + if torch.cuda.is_available(): + model.load_state_dict(torch.load('../data/pretrained_KV-PLM/ckpt_ret01.pt')) + model = model.cuda() + else: + model.load_state_dict(torch.load('../data/pretrained_KV-PLM/ckpt_ret01.pt', map_location=torch.device('cpu') )) + model.eval() + molecule_dim = 768 + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + if torch.cuda.is_available(): + 
torch.cuda.manual_seed_all(args.seed) + + model = model.to(device) + linear_model = nn.Linear(molecule_dim, num_tasks).to(device) + + # set up optimizer + if args.training_mode == "fine_tuning": + model_param_group = [ + {"params": model.parameters()}, + {"params": linear_model.parameters()} + ] + else: + model_param_group = [ + {"params": linear_model.parameters()} + ] + optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.weight_decay) + + if task_mode == "classification": + train_func = train_classification + eval_func = eval_classification + + train_roc_list, val_roc_list, test_roc_list = [], [], [] + train_acc_list, val_acc_list, test_acc_list = [], [], [] + best_val_roc, best_val_idx = -1, 0 + criterion = nn.BCEWithLogitsLoss(reduction="none") + + for epoch in range(1, args.epochs + 1): + loss_acc = train_func(model, device, train_loader, optimizer) + print("Epoch: {}\nLoss: {}".format(epoch, loss_acc)) + + if args.eval_train: + train_roc, train_acc, train_target, train_pred = eval_func(model, device, train_loader) + else: + train_roc = train_acc = 0 + val_roc, val_acc, val_target, val_pred = eval_func(model, device, val_loader) + test_roc, test_acc, test_target, test_pred = eval_func(model, device, test_loader) + + train_roc_list.append(train_roc) + train_acc_list.append(train_acc) + val_roc_list.append(val_roc) + val_acc_list.append(val_acc) + test_roc_list.append(test_roc) + test_acc_list.append(test_acc) + print("train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc, val_roc, test_roc)) + print() + + if val_roc > best_val_roc: + best_val_roc = val_roc + best_val_idx = epoch - 1 + if args.output_model_dir is not None: + ##### save best model ##### + output_model_path = os.path.join(args.output_model_dir, "{}_model_best.pth".format(args.dataset)) + saved_model_dict = { + "model": model.state_dict() + } + torch.save(saved_model_dict, output_model_path) + + filename = os.path.join(args.output_model_dir, "{}_evaluation_best.pth".format(args.dataset)) + np.savez( + filename, val_target=val_target, val_pred=val_pred, + test_target=test_target, test_pred=test_pred) + + print("best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format(train_roc_list[best_val_idx], val_roc_list[best_val_idx], test_roc_list[best_val_idx])) + + else: + train_func = train_regression + eval_func = eval_regression + criterion = torch.nn.MSELoss() + + train_result_list, val_result_list, test_result_list = [], [], [] + metric_list = ['RMSE', 'MAE'] + best_val_rmse, best_val_idx = 1e10, 0 + + for epoch in range(1, args.epochs + 1): + loss_acc = train_func(model, device, train_loader, optimizer) + print('Epoch: {}\nLoss: {}'.format(epoch, loss_acc)) + + if args.eval_train: + train_result, train_target, train_pred = eval_func(model, device, train_loader) + else: + train_result = {'RMSE': 0, 'MAE': 0, 'R2': 0} + val_result, val_target, val_pred = eval_func(model, device, val_loader) + test_result, test_target, test_pred = eval_func(model, device, test_loader) + + train_result_list.append(train_result) + val_result_list.append(val_result) + test_result_list.append(test_result) + + for metric in metric_list: + print('{} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format(metric, train_result[metric], val_result[metric], test_result[metric])) + print() + + if val_result['RMSE'] < best_val_rmse: + best_val_rmse = val_result['RMSE'] + best_val_idx = epoch - 1 + if args.output_model_dir is not None: + ##### save best model ##### + output_model_path = os.path.join(args.output_model_dir, 
"{}_model_best.pth".format(args.dataset)) + saved_model_dict = { + 'model': model.state_dict() + } + torch.save(saved_model_dict, output_model_path) + + filename = os.path.join(args.output_model_dir, "{}_evaluation_best.pth".format(args.dataset)) + np.savez( + filename, val_target=val_target, val_pred=val_pred, + test_target=test_target, test_pred=test_pred) + + for metric in metric_list: + print('Best (RMSE), {} train: {:.6f}\tval: {:.6f}\ttest: {:.6f}'.format( + metric, train_result_list[best_val_idx][metric], val_result_list[best_val_idx][metric], test_result_list[best_val_idx][metric])) + + ##### save final model ##### + if args.output_model_dir is not None: + output_model_path = os.path.join(args.output_model_dir, '{}_model_final.pth'.format(args.dataset)) + saved_model_dict = { + 'model': model.state_dict() + } + torch.save(saved_model_dict, output_model_path) + \ No newline at end of file diff --git a/scripts/pretrain.py b/scripts/pretrain.py new file mode 100644 index 0000000..9cc93fe --- /dev/null +++ b/scripts/pretrain.py @@ -0,0 +1,333 @@ +import os +import time +import numpy as np +from tqdm import tqdm +import argparse + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader as torch_DataLoader + +from torch_geometric.loader import DataLoader as pyg_DataLoader +from transformers import AutoModel, AutoTokenizer + +from MoleculeSTM.datasets import ( + PubChemSTM_Datasets_SMILES, PubChemSTM_SubDatasets_SMILES, + PubChemSTM_Datasets_Graph, PubChemSTM_SubDatasets_Graph, + PubChemSTM_Datasets_Raw_SMILES, PubChemSTM_SubDatasets_Raw_SMILES, + PubChemSTM_Datasets_Raw_Graph, PubChemSTM_SubDatasets_Raw_Graph +) +from MoleculeSTM.models import GNN, GNN_graphpred +from MoleculeSTM.utils import prepare_text_tokens, get_molecule_repr_MoleculeSTM, freeze_network +from MoleculeSTM.models.mega_molbart.mega_mol_bart import MegaMolBART + + +def cycle_index(num, shift): + arr = torch.arange(num) + shift + arr[-shift:] = torch.arange(shift) + return arr + + +def do_CL(X, Y, args): + if args.normalize: + X = F.normalize(X, dim=-1) + Y = F.normalize(Y, dim=-1) + + if args.SSL_loss == 'EBM_NCE': + criterion = nn.BCEWithLogitsLoss() + neg_Y = torch.cat([Y[cycle_index(len(Y), i + 1)] for i in range(args.CL_neg_samples)], dim=0) + neg_X = X.repeat((args.CL_neg_samples, 1)) + + pred_pos = torch.sum(X * Y, dim=1) / args.T + pred_neg = torch.sum(neg_X * neg_Y, dim=1) / args.T + + loss_pos = criterion(pred_pos, torch.ones(len(pred_pos)).to(pred_pos.device)) + loss_neg = criterion(pred_neg, torch.zeros(len(pred_neg)).to(pred_neg.device)) + CL_loss = (loss_pos + args.CL_neg_samples * loss_neg) / (1 + args.CL_neg_samples) + + CL_acc = (torch.sum(pred_pos > 0).float() + torch.sum(pred_neg < 0).float()) / \ + (len(pred_pos) + len(pred_neg)) + CL_acc = CL_acc.detach().cpu().item() + + elif args.SSL_loss == 'InfoNCE': + criterion = nn.CrossEntropyLoss() + B = X.size()[0] + logits = torch.mm(X, Y.transpose(1, 0)) # B*B + logits = torch.div(logits, args.T) + labels = torch.arange(B).long().to(logits.device) # B*1 + + CL_loss = criterion(logits, labels) + pred = logits.argmax(dim=1, keepdim=False) + CL_acc = pred.eq(labels).sum().detach().cpu().item() * 1. 
/ B + + else: + raise Exception + + return CL_loss, CL_acc + + +def save_model(save_best, epoch=None): + if args.output_model_dir is not None: + if save_best: + global optimal_loss + print("save model with loss: {:.5f}".format(optimal_loss)) + model_file = "model.pth" + + elif epoch is None: + model_file = "model_final.pth" + + else: + model_file = "model_{}.pth".format(epoch) + + saved_file_path = os.path.join(args.output_model_dir, "text_{}".format(model_file)) + torch.save(text_model.state_dict(), saved_file_path) + + saved_file_path = os.path.join(args.output_model_dir, "molecule_{}".format(model_file)) + torch.save(molecule_model.state_dict(), saved_file_path) + + saved_file_path = os.path.join(args.output_model_dir, "text2latent_{}".format(model_file)) + torch.save(text2latent.state_dict(), saved_file_path) + + saved_file_path = os.path.join(args.output_model_dir, "mol2latent_{}".format(model_file)) + torch.save(mol2latent.state_dict(), saved_file_path) + return + + +def train( + epoch, + dataloader, + text_model, text_tokenizer, + molecule_model, MegaMolBART_wrapper=None): + + if args.representation_frozen: + text_model.eval() + molecule_model.eval() + else: + text_model.train() + molecule_model.train() + text2latent.train() + mol2latent.train() + + if args.verbose: + L = tqdm(dataloader) + else: + L = dataloader + + start_time = time.time() + accum_loss, accum_acc = 0, 0 + for step, batch in enumerate(L): + description = batch[0] + molecule_data = batch[1] + + description_tokens_ids, description_masks = prepare_text_tokens( + device=device, description=description, tokenizer=text_tokenizer, max_seq_len=args.max_seq_len) + description_output = text_model(input_ids=description_tokens_ids, attention_mask=description_masks) + description_repr = description_output["pooler_output"] + description_repr = text2latent(description_repr) + + if molecule_type == "SMILES": + molecule_data = list(molecule_data) # for SMILES_list + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=mol2latent, + molecule_type=molecule_type, MegaMolBART_wrapper=MegaMolBART_wrapper) + else: + molecule_data = molecule_data.to(device) + molecule_repr = get_molecule_repr_MoleculeSTM( + molecule_data, mol2latent=mol2latent, + molecule_type=molecule_type, molecule_model=molecule_model) + + loss_01, acc_01 = do_CL(description_repr, molecule_repr, args) + loss_02, acc_02 = do_CL(molecule_repr, description_repr, args) + loss = (loss_01 + loss_02) / 2 + acc = (acc_01 + acc_02) / 2 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + accum_loss += loss.item() + accum_acc += acc + + accum_loss /= len(L) + accum_acc /= len(L) + + global optimal_loss + temp_loss = accum_loss + if temp_loss < optimal_loss: + optimal_loss = temp_loss + save_model(save_best=True, epoch=epoch) + print("CL Loss: {:.5f}\tCL Acc: {:.5f}\tTime: {:.5f}".format(accum_loss, accum_acc, time.time() - start_time)) + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--device", type=int, default=0) + + parser.add_argument("--dataspace_path", type=str, default="../data") + parser.add_argument("--dataset", type=str, default="PubChemSTM") + parser.add_argument("--text_type", type=str, default="SciBERT", choices=["SciBERT"]) + parser.add_argument("--molecule_type", type=str, default="SMILES", choices=["SMILES", "Graph"]) + parser.add_argument("--representation_frozen", dest='representation_frozen', action='store_true') + 
parser.add_argument('--no_representation_frozen', dest='representation_frozen', action='store_false') + parser.set_defaults(representation_frozen=False) + + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--text_lr", type=float, default=1e-4) + parser.add_argument("--mol_lr", type=float, default=1e-4) + parser.add_argument("--text_lr_scale", type=float, default=0.1) + parser.add_argument("--mol_lr_scale", type=float, default=0.1) + parser.add_argument("--num_workers", type=int, default=8) + parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--decay", type=float, default=0) + parser.add_argument('--verbose', dest='verbose', action='store_true') + parser.set_defaults(verbose=False) + parser.add_argument("--output_model_dir", type=str, default=None) + + ########## for SciBERT ########## + parser.add_argument("--max_seq_len", type=int, default=512) + + ########## for MegaMolBART ########## + parser.add_argument("--megamolbart_input_dir", type=str, default="../data/pretrained_MegaMolBART/checkpoints") + parser.add_argument("--vocab_path", type=str, default="../MoleculeSTM/bart_vocab.txt") + + ########## for 2D GNN ########## + parser.add_argument("--pretrain_gnn_mode", type=str, default="GraphMVP_G", choices=["GraphMVP_G"]) + parser.add_argument("--gnn_emb_dim", type=int, default=300) + parser.add_argument("--num_layer", type=int, default=5) + parser.add_argument('--JK', type=str, default='last') + parser.add_argument("--dropout_ratio", type=float, default=0.5) + parser.add_argument("--gnn_type", type=str, default="gin") + parser.add_argument('--graph_pooling', type=str, default='mean') + + ########## for contrastive SSL ########## + parser.add_argument("--SSL_loss", type=str, default="EBM_NCE", choices=["EBM_NCE", "InfoNCE"]) + parser.add_argument("--SSL_emb_dim", type=int, default=256) + parser.add_argument("--CL_neg_samples", type=int, default=1) + parser.add_argument("--T", type=float, default=0.1) + parser.add_argument('--normalize', dest='normalize', action='store_true') + parser.add_argument('--no_normalize', dest='normalize', action='store_false') + parser.set_defaults(normalize=True) + + args = parser.parse_args() + print("arguments\t", args) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + device = torch.device("cuda:" + str(args.device)) \ + if torch.cuda.is_available() else torch.device("cpu") + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + if "PubChemSTM" in args.dataset: + dataset_root = os.path.join(args.dataspace_path, "PubChemSTM_data") + else: + raise Exception + + kwargs = {} + + # ##### prepare text model ##### + if args.text_type == "SciBERT": + pretrained_SciBERT_folder = os.path.join(args.dataspace_path, 'pretrained_SciBERT') + text_tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder) + text_model = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased', cache_dir=pretrained_SciBERT_folder).to(device) + kwargs["text_tokenizer"] = text_tokenizer + kwargs["text_model"] = text_model + text_dim = 768 + else: + raise Exception + + ##### prepare molecule model ##### + molecule_type = args.molecule_type + if molecule_type == "SMILES": + if args.dataset == "PubChemSTM": + dataset = PubChemSTM_Datasets_SMILES(dataset_root) + elif args.dataset == "PubChemSTM1K": + dataset = PubChemSTM_SubDatasets_SMILES(dataset_root, size=1000) + elif args.dataset == "PubChemSTM10K": + dataset = 
PubChemSTM_SubDatasets_SMILES(dataset_root, size=10000) + elif args.dataset == "PubChemSTM_Raw": + dataset = PubChemSTM_Datasets_Raw_SMILES(dataset_root) + elif args.dataset == "PubChemSTM1K_Raw": + dataset = PubChemSTM_SubDatasets_Raw_SMILES(dataset_root, size=1000) + elif args.dataset == "PubChemSTM10K_Raw": + dataset = PubChemSTM_SubDatasets_Raw_SMILES(dataset_root, size=10000) + else: + raise Exception + dataloader_class = torch_DataLoader + + if args.output_model_dir is not None: + MegaMolBART_dir = os.path.join(args.output_model_dir, "MegaMolBART") + else: + MegaMolBART_dir = None + MegaMolBART_wrapper = MegaMolBART( + vocab_path=args.vocab_path, + input_dir=args.megamolbart_input_dir, + output_dir=MegaMolBART_dir) + molecule_model = MegaMolBART_wrapper.model + kwargs["MegaMolBART_wrapper"] = MegaMolBART_wrapper + kwargs["molecule_model"] = molecule_model + molecule_dim = 256 + + elif molecule_type == "Graph": + if args.dataset == "PubChemSTM": + dataset = PubChemSTM_Datasets_Graph(dataset_root) + elif args.dataset == "PubChemSTM1K": + dataset = PubChemSTM_SubDatasets_Graph(dataset_root, size=1000) + elif args.dataset == "PubChemSTM10K": + dataset = PubChemSTM_SubDatasets_Graph(dataset_root, size=10000) + elif args.dataset == "PubChemSTM_Raw": + dataset = PubChemSTM_Datasets_Raw_Graph(dataset_root) + elif args.dataset == "PubChemSTM1K_Raw": + dataset = PubChemSTM_SubDatasets_Raw_Graph(dataset_root, size=1000) + elif args.dataset == "PubChemSTM10K_Raw": + dataset = PubChemSTM_SubDatasets_Raw_Graph(dataset_root, size=10000) + dataloader_class = pyg_DataLoader + molecule_node_model = GNN( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, + JK=args.JK, drop_ratio=args.dropout_ratio, + gnn_type=args.gnn_type) + molecule_model = GNN_graphpred( + num_layer=args.num_layer, emb_dim=args.gnn_emb_dim, JK=args.JK, graph_pooling=args.graph_pooling, + num_tasks=1, molecule_node_model=molecule_node_model) + pretrained_model_path = os.path.join(args.dataspace_path, "pretrained_GraphMVP", args.pretrain_gnn_mode, "model.pth") + molecule_model.from_pretrained(pretrained_model_path) + + molecule_model = molecule_model.to(device) + + kwargs["molecule_model"] = molecule_model + molecule_dim = args.gnn_emb_dim + + else: + raise Exception + dataloader = dataloader_class(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) + + text2latent = nn.Linear(text_dim, args.SSL_emb_dim).to(device) + mol2latent = nn.Linear(molecule_dim, args.SSL_emb_dim).to(device) + + if args.representation_frozen: + print("Representation is fronzen during pretraining.") + freeze_network(text_model) + freeze_network(molecule_model) + model_param_group = [ + {"params": text2latent.parameters(), "lr": args.text_lr * args.text_lr_scale}, + {"params": mol2latent.parameters(), "lr": args.mol_lr * args.mol_lr_scale}, + ] + else: + model_param_group = [ + {"params": text_model.parameters(), "lr": args.text_lr}, + {"params": molecule_model.parameters(), "lr": args.mol_lr}, + {"params": text2latent.parameters(), "lr": args.text_lr * args.text_lr_scale}, + {"params": mol2latent.parameters(), "lr": args.mol_lr * args.mol_lr_scale}, + ] + optimizer = optim.Adam(model_param_group, weight_decay=args.decay) + optimal_loss = 1e10 + + for e in range(1, args.epochs+1): + print("Epoch {}".format(e)) + train(e, dataloader, **kwargs) + + save_model(save_best=False) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b7d9f85 --- /dev/null +++ b/setup.py @@ -0,0 +1,8 @@ +from setuptools import setup, 
find_packages + +setup(name='MoleculeSTM', + description='', + author='Shengchao Liu', + author_email='liusheng@mila.quebec', + license='MIT', + packages=find_packages()) \ No newline at end of file
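A minimal usage sketch (not part of the patch; the flags come from the argparse definitions in scripts/pretrain.py above, and the data and checkpoint paths are assumed to follow the script defaults, e.g. ../data/pretrained_MegaMolBART/checkpoints):

    pip install -e .
    cd scripts
    python pretrain.py --dataset PubChemSTM1K --molecule_type SMILES --batch_size 8 --epochs 10 --verbose --output_model_dir ../output/pretrain_demo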