From eab92560caede046ec44b0cedc5fe1e5dfeabdcb Mon Sep 17 00:00:00 2001 From: ulmer Date: Mon, 8 Aug 2022 10:56:32 -0400 Subject: [PATCH 001/302] delete build strings from `environment.yml` --- environment.yml | 188 ++++++++++++++++++++++++------------------------ 1 file changed, 94 insertions(+), 94 deletions(-) diff --git a/environment.yml b/environment.yml index d4e5fcc7..19ac7f65 100644 --- a/environment.yml +++ b/environment.yml @@ -5,100 +5,100 @@ channels: - conda-forge - defaults dependencies: - - _libgcc_mutex=0.1=conda_forge - - _openmp_mutex=4.5=1_gnu - - blas=1.0=mkl - - boost=1.74.0=py39h5472131_3 - - boost-cpp=1.74.0=h312852a_4 - - bzip2=1.0.8=h7f98852_4 - - ca-certificates=2021.5.30=ha878542_0 - - cairo=1.16.0=h6cf1ce9_1008 - - certifi=2021.5.30=py39hf3d152e_0 - - cudatoolkit=11.1.74=h6bb024c_0 - - cycler=0.10.0=py_2 - - fontconfig=2.13.1=hba837de_1005 - - freetype=2.10.4=h0708190_1 - - gettext=0.19.8.1=h0b5b191_1005 - - greenlet=1.1.0=py39he80948d_0 - - icu=68.1=h58526e2_0 - - intel-openmp=2021.3.0=h06a4308_3350 - - jbig=2.1=h7f98852_2003 - - jpeg=9d=h36c2ea0_0 - - kiwisolver=1.3.1=py39h1a9c180_1 - - lcms2=2.12=hddcbb42_0 - - ld_impl_linux-64=2.36.1=hea4e1c9_1 - - lerc=2.2.1=h9c3ff4c_0 - - libdeflate=1.7=h7f98852_5 - - libffi=3.3=h58526e2_2 - - libgcc-ng=9.3.0=h2828fa1_19 - - libgfortran-ng=9.3.0=hff62375_19 - - libgfortran5=9.3.0=hff62375_19 - - libglib=2.68.3=h3e27bee_0 - - libgomp=9.3.0=h2828fa1_19 - - libiconv=1.16=h516909a_0 - - libopenblas=0.3.15=pthreads_h8fe5266_1 - - libpng=1.6.37=h21135ba_2 - - libstdcxx-ng=9.3.0=h6de172a_19 - - libtiff=4.3.0=hf544144_1 - - libuuid=2.32.1=h7f98852_1000 - - libuv=1.40.0=h7b6447c_0 - - libwebp-base=1.2.0=h7f98852_2 - - libxcb=1.13=h7f98852_1003 - - libxml2=2.9.12=h72842e0_0 - - lz4-c=1.9.3=h9c3ff4c_0 - - matplotlib-base=3.4.2=py39h2fa2bec_0 - - mkl=2021.3.0=h06a4308_520 - - mkl-service=2.4.0=py39h7f8727e_0 - - mkl_fft=1.3.0=py39h42c9631_2 - - mkl_random=1.2.2=py39h51133e4_0 - - ncurses=6.2=h58526e2_4 - - ninja=1.10.2=hff7bd54_1 - - numpy=1.20.3=py39hf144106_0 - - numpy-base=1.20.3=py39h74d4b33_0 - - olefile=0.46=pyh9f0ad1d_1 - - openjpeg=2.4.0=hb52868f_1 - - openssl=1.1.1k=h7f98852_0 - - pandas=1.3.0=py39hde0f152_0 - - pcre=8.45=h9c3ff4c_0 - - pillow=8.3.1=py39ha612740_0 - - pip=21.1.3=pyhd8ed1ab_0 - - pixman=0.40.0=h36c2ea0_0 - - pthread-stubs=0.4=h36c2ea0_1001 - - pycairo=1.20.1=py39hedcb9fc_0 - - pyparsing=2.4.7=pyh9f0ad1d_0 - - python=3.9.6=h49503c6_1_cpython - - python-dateutil=2.8.2=pyhd8ed1ab_0 - - python_abi=3.9=2_cp39 - - pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 - - pytz=2021.1=pyhd8ed1ab_0 - - rdkit=2021.03.4=py39hccf6a74_0 - - readline=8.1=h46c0cb4_0 - - reportlab=3.5.68=py39he59360d_0 - - setuptools=49.6.0=py39hf3d152e_3 - - six=1.16.0=pyh6c4a22f_0 - - sqlalchemy=1.4.21=py39h3811e60_0 - - sqlite=3.36.0=h9cd32fc_0 - - tk=8.6.10=h21135ba_1 - - torchaudio=0.9.0=py39 - - torchvision=0.2.2=py_3 - - tornado=6.1=py39h3811e60_1 - - typing_extensions=3.10.0.0=pyh06a4308_0 - - tzdata=2021a=he74cb21_1 - - wheel=0.36.2=pyhd3deb0d_0 - - xorg-kbproto=1.0.7=h7f98852_1002 - - xorg-libice=1.0.10=h7f98852_0 - - xorg-libsm=1.2.3=hd9c2040_1000 - - xorg-libx11=1.7.2=h7f98852_0 - - xorg-libxau=1.0.9=h7f98852_0 - - xorg-libxdmcp=1.1.3=h7f98852_0 - - xorg-libxext=1.3.4=h7f98852_1 - - xorg-libxrender=0.9.10=h7f98852_1003 - - xorg-renderproto=0.11.1=h7f98852_1002 - - xorg-xextproto=7.3.0=h7f98852_1002 - - xorg-xproto=7.0.31=h7f98852_1007 - - xz=5.2.5=h516909a_1 - - zlib=1.2.11=h516909a_1010 - - zstd=1.5.0=ha95c52a_0 + - _libgcc_mutex=0.1 + - _openmp_mutex=4.5 + - blas=1.0 + - boost=1.74.0 + - boost-cpp=1.74.0 + - bzip2=1.0.8 + - ca-certificates=2021.5.30 + - cairo=1.16.0 + - certifi=2021.5.30 + - cudatoolkit=11.1.74 + - cycler=0.10.0 + - fontconfig=2.13.1 + - freetype=2.10.4 + - gettext=0.19.8.1 + - greenlet=1.1.0 + - icu=68.1 + - intel-openmp + - jbig=2.1 + - jpeg=9d + - kiwisolver=1.3.1 + - lcms2=2.12 + - ld_impl_linux-64=2.36.1 + - lerc=2.2.1 + - libdeflate=1.7 + - libffi=3.3 + - libgcc-ng=9.3.0 + - libgfortran-ng=9.3.0 + - libgfortran5=9.3.0 + - libglib=2.68.3 + - libgomp=9.3.0 + - libiconv=1.16 + - libopenblas=0.3.15 + - libpng=1.6.37 + - libstdcxx-ng=9.3.0 + - libtiff=4.3.0 + - libuuid=2.32.1 + - libuv=1.40.0 + - libwebp-base=1.2.0 + - libxcb=1.13 + - libxml2=2.9.12 + - lz4-c=1.9.3 + - matplotlib-base=3.4.2 + - mkl=2021.3.0 + - mkl-service=2.4.0 + - mkl_fft=1.3.0 + - mkl_random=1.2.2 + - ncurses=6.2 + - ninja=1.10.2 + - numpy=1.20.3 + - numpy-base=1.20.3 + - olefile=0.46 + - openjpeg=2.4.0 + - openssl=1.1.1k + - pandas=1.3.0 + - pcre=8.45 + - pillow=8.3.1 + - pip=21.1.3 + - pixman=0.40.0 + - pthread-stubs=0.4 + - pycairo=1.20.1 + - pyparsing=2.4.7 + - python=3.9.6 + - python-dateutil=2.8.2 + - python_abi=3.9 + - pytorch=1.9.0 + - pytz=2021.1 + - rdkit=2021.03.4 + - readline=8.1 + - reportlab=3.5.68 + - setuptools=49.6.0 + - six=1.16.0 + - sqlalchemy=1.4.21 + - sqlite=3.36.0 + - tk=8.6.10 + - torchaudio=0.9.0 + - torchvision=0.2.2 + - tornado=6.1 + - typing_extensions=3.10.0.0 + - tzdata=2021a + - wheel=0.36.2 + - xorg-kbproto=1.0.7 + - xorg-libice=1.0.10 + - xorg-libsm=1.2.3 + - xorg-libx11=1.7.2 + - xorg-libxau=1.0.9 + - xorg-libxdmcp=1.1.3 + - xorg-libxext=1.3.4 + - xorg-libxrender=0.9.10 + - xorg-renderproto=0.11.1 + - xorg-xextproto=7.3.0 + - xorg-xproto=7.0.31 + - xz=5.2.5 + - zlib=1.2.11 + - zstd=1.5.0 - pip: - absl-py==0.13.0 - aiohttp==3.7.4.post0 From 21ce6e022f0b70da1e6f45db7151e9b62681c6c8 Mon Sep 17 00:00:00 2001 From: ulmer Date: Mon, 8 Aug 2022 10:59:00 -0400 Subject: [PATCH 002/302] uncomment packages that prevent building conda env --- environment.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/environment.yml b/environment.yml index 19ac7f65..7ad42d28 100644 --- a/environment.yml +++ b/environment.yml @@ -41,7 +41,7 @@ dependencies: - libstdcxx-ng=9.3.0 - libtiff=4.3.0 - libuuid=2.32.1 - - libuv=1.40.0 + # - libuv=1.40.0 - libwebp-base=1.2.0 - libxcb=1.13 - libxml2=2.9.12 @@ -49,12 +49,12 @@ dependencies: - matplotlib-base=3.4.2 - mkl=2021.3.0 - mkl-service=2.4.0 - - mkl_fft=1.3.0 - - mkl_random=1.2.2 + # - mkl_fft=1.3.0 + # - mkl_random=1.2.2 - ncurses=6.2 - ninja=1.10.2 - numpy=1.20.3 - - numpy-base=1.20.3 + # - numpy-base=1.20.3 - olefile=0.46 - openjpeg=2.4.0 - openssl=1.1.1k From c18d30c5364d6efb7685377af2f1f0629a77719c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 12 Aug 2022 16:09:53 -0400 Subject: [PATCH 003/302] move static test files to `tests/assets/` --- .../{data => assets}/building_blocks_matched.csv.gz | Bin tests/{data => assets}/rxn_set_hb_test.txt | 0 tests/test_DataPreparation.py | 8 ++++---- tests/test_Predict.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) rename tests/{data => assets}/building_blocks_matched.csv.gz (100%) rename tests/{data => assets}/rxn_set_hb_test.txt (100%) diff --git a/tests/data/building_blocks_matched.csv.gz b/tests/assets/building_blocks_matched.csv.gz similarity index 100% rename from tests/data/building_blocks_matched.csv.gz rename to tests/assets/building_blocks_matched.csv.gz diff --git a/tests/data/rxn_set_hb_test.txt b/tests/assets/rxn_set_hb_test.txt similarity index 100% rename from tests/data/rxn_set_hb_test.txt rename to tests/assets/rxn_set_hb_test.txt diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index be9af8d9..f3e40993 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -31,10 +31,10 @@ def test_process_rxn_templates(self): """ # the following file contains the three templates at the top of # 'SynNet/data/rxn_set_hb.txt' - path_to_rxn_templates = f"{TEST_DIR}/data/rxn_set_hb_test.txt" + path_to_rxn_templates = f"{TEST_DIR}/assets/rxn_set_hb_test.txt" # load the reference building blocks (100 here) - path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz" + path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz" building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")[ "SMILES" ].tolist() @@ -76,7 +76,7 @@ def test_synthetic_tree_prep(self): rxns = r_ref.rxns # load the reference building blocks (100 here) - path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz" + path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz" building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")[ "SMILES" ].tolist() @@ -255,7 +255,7 @@ def test_bb_emb(self): model.eval() # load the building blocks - path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz" + path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz" building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")[ "SMILES" ].tolist() diff --git a/tests/test_Predict.py b/tests/test_Predict.py index a2b5454a..6ca8856a 100644 --- a/tests/test_Predict.py +++ b/tests/test_Predict.py @@ -42,7 +42,7 @@ def test_predict(self): # define path to the reaction templates and purchasable building blocks path_to_reaction_file = f"{ref_dir}rxns_hb.json.gz" - path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz" + path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz" # define paths to pretrained modules path_to_act = f"{ref_dir}act.ckpt" From 9d7f144cbbd2a485f887a8a6b3ae98374e11e0f5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 12 Aug 2022 16:14:04 -0400 Subject: [PATCH 004/302] adds test for # of reactions in reaction template file --- tests/test_Training.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_Training.py b/tests/test_Training.py index 75ff83fd..0f9e245c 100644 --- a/tests/test_Training.py +++ b/tests/test_Training.py @@ -15,6 +15,16 @@ TEST_DIR = Path(__file__).parent +REACTION_TEMPLATES_FILE = f"{TEST_DIR}/assets/rxn_set_hb_test.txt" + +class TestReactionTemplateFile(unittest.TestCase): + + def test_number_of_reaction_templates(self): + """ Count number of lines in file, i.e. the number of reaction templates.""" + with open(REACTION_TEMPLATES_FILE,"r") as f: + nReactionTemplates = sum(1 for _ in f) + self.assertEqual(nReactionTemplates,3) + class TestTraining(unittest.TestCase): """ @@ -144,7 +154,7 @@ def test_reaction_network(self): batch_size = 10 epochs = 2 ncpu = 2 - n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt' + n_templates = 3 # num templates in `REACTION_TEMPLATES_FILE` validation_option = "accuracy" ref_dir = f"{TEST_DIR}/data/ref/" @@ -203,7 +213,7 @@ def test_reactant2_network(self): batch_size = 10 epochs = 2 ncpu = 2 - n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt' + n_templates = 3 # num templates in `REACTION_TEMPLATES_FILE` validation_option = "nn_accuracy" ref_dir = f"{TEST_DIR}/data/ref/" From a60db1b4c1d420def07595536897b4c880a74106 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 12 Aug 2022 16:15:24 -0400 Subject: [PATCH 005/302] add `assert`statements for the shape of `X,y` data --- tests/test_Training.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_Training.py b/tests/test_Training.py index 0f9e245c..74a5a34e 100644 --- a/tests/test_Training.py +++ b/tests/test_Training.py @@ -46,7 +46,9 @@ def test_action_network(self): ref_dir = f"{TEST_DIR}/data/ref/" X = sparse.load_npz(ref_dir + "X_act_train.npz") + assert X.shape==(4,3*nbits) # (4,12288) y = sparse.load_npz(ref_dir + "y_act_train.npz") + assert y.shape==(4,1) # (4,1) X = torch.Tensor(X.A) y = torch.LongTensor( y.A.reshape( @@ -105,8 +107,10 @@ def test_reactant1_network(self): # load the reaction data X = sparse.load_npz(ref_dir + "X_rt1_train.npz") + assert X.shape==(2,3*nbits) # (4,12288) X = torch.Tensor(X.A) y = sparse.load_npz(ref_dir + "y_rt1_train.npz") + assert y.shape==(2,300) # (2,300) y = torch.Tensor(y.A) train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) @@ -159,7 +163,9 @@ def test_reaction_network(self): ref_dir = f"{TEST_DIR}/data/ref/" X = sparse.load_npz(ref_dir + "X_rxn_train.npz") + assert X.shape==(2,4*nbits) # (2,16384) y = sparse.load_npz(ref_dir + "y_rxn_train.npz") + assert y.shape==(2, 1) # (2, 1) X = torch.Tensor(X.A) y = torch.LongTensor( y.A.reshape( @@ -218,7 +224,9 @@ def test_reactant2_network(self): ref_dir = f"{TEST_DIR}/data/ref/" X = sparse.load_npz(ref_dir + "X_rt2_train.npz") + assert X.shape==(2,4*nbits+n_templates) # (2,16387) y = sparse.load_npz(ref_dir + "y_rt2_train.npz") + assert y.shape==(2,300) # (2,300) X = torch.Tensor(X.A) y = torch.Tensor(y.A) train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) From 0e48da45ee45ae6839f0b13bd000f299b0204019 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 12 Aug 2022 16:20:51 -0400 Subject: [PATCH 006/302] update test values and relax tests to `AlmostEqual. Note: The unittests fail back on 0c27ced4as well. I do not know how to check the (old) test values, so hoping the new ones are sane for now. --- tests/test_Training.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/test_Training.py b/tests/test_Training.py index 74a5a34e..f2e3b220 100644 --- a/tests/test_Training.py +++ b/tests/test_Training.py @@ -45,7 +45,7 @@ def test_action_network(self): validation_option = "accuracy" ref_dir = f"{TEST_DIR}/data/ref/" - X = sparse.load_npz(ref_dir + "X_act_train.npz") + X = sparse.load_npz(ref_dir + "X_act_train.npz") assert X.shape==(4,3*nbits) # (4,12288) y = sparse.load_npz(ref_dir + "y_act_train.npz") assert y.shape==(4,1) # (4,1) @@ -86,10 +86,10 @@ def test_action_network(self): trainer.fit(mlp, train_data_iter, valid_data_iter) train_loss = float(trainer.callback_metrics["train_loss"]) - train_loss_ref = 1.4203987121582031 + train_loss_ref = 1.2967982292175293 shutil.rmtree(f"act_{embedding}_{radius}_{nbits}_logs/") - self.assertEqual(train_loss, train_loss_ref) + self.assertAlmostEqual(train_loss, train_loss_ref) def test_reactant1_network(self): """ @@ -98,14 +98,14 @@ def test_reactant1_network(self): embedding = "fp" radius = 2 nbits = 4096 - out_dim = 300 + out_dim = 300 # Note: out_dim 300 = gin embedding batch_size = 10 epochs = 2 ncpu = 2 - validation_option = "nn_accuracy" + validation_option = "nn_accuracy_gin" ref_dir = f"{TEST_DIR}/data/ref/" - # load the reaction data + # load the reaction data X = sparse.load_npz(ref_dir + "X_rt1_train.npz") assert X.shape==(2,3*nbits) # (4,12288) X = torch.Tensor(X.A) @@ -143,10 +143,10 @@ def test_reactant1_network(self): trainer.fit(mlp, train_data_iter, valid_data_iter) train_loss = float(trainer.callback_metrics["train_loss"]) - train_loss_ref = 0.35571354627609253 + train_loss_ref = 0.33368119597435 shutil.rmtree(f"rt1_{embedding}_{radius}_{nbits}_logs/") - self.assertEqual(train_loss, train_loss_ref) + self.assertAlmostEqual(train_loss, train_loss_ref) def test_reaction_network(self): """ @@ -206,7 +206,7 @@ def test_reaction_network(self): train_loss_ref = 1.1214743852615356 shutil.rmtree(f"rxn_{embedding}_{radius}_{nbits}_logs/") - self.assertEqual(train_loss, train_loss_ref) + self.assertAlmostEqual(train_loss, train_loss_ref,places=-6) def test_reactant2_network(self): """ @@ -215,12 +215,12 @@ def test_reactant2_network(self): embedding = "fp" radius = 2 nbits = 4096 - out_dim = 300 + out_dim = 300 # Note: out_dim 300 = gin embedding batch_size = 10 epochs = 2 ncpu = 2 - n_templates = 3 # num templates in `REACTION_TEMPLATES_FILE` - validation_option = "nn_accuracy" + n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt' + validation_option = "nn_accuracy_gin" ref_dir = f"{TEST_DIR}/data/ref/" X = sparse.load_npz(ref_dir + "X_rt2_train.npz") @@ -260,7 +260,7 @@ def test_reactant2_network(self): trainer.fit(mlp, train_data_iter, valid_data_iter) train_loss = float(trainer.callback_metrics["train_loss"]) - train_loss_ref = 0.41246509552001953 + train_loss_ref = 0.3026905953884125 shutil.rmtree(f"rt2_{embedding}_{radius}_{nbits}_logs/") - self.assertEqual(train_loss, train_loss_ref) + self.assertAlmostEqual(train_loss, train_loss_ref) From 2be9daa5beb29f54179d8b73361d38bdb981057d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 15 Aug 2022 10:16:31 -0400 Subject: [PATCH 007/302] fixes markdownlint violations --- README.md | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 325f1b90..5f7c9f0a 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ # SynNet + This repo contains the code and analysis scripts for our amortized approach to synthetic tree generation using neural networks. Our model can serve as both a synthesis planning tool and as a tool for synthesizable molecular design. The method is described in detail in the publication "Amortized tree generation for bottom-up synthesis planning and synthesizable molecular design" available on the [arXiv](https://arxiv.org/abs/2110.06389) and summarized below. ## Summary + ### Overview + We model synthetic pathways as tree structures called *synthetic trees*. A valid synthetic tree has one root node (the final product molecule) linked to purchasable building blocks (encoded as SMILES strings) via feasible reactions according to a list of discrete reaction templates (examples of templates encoded as SMARTS strings in [data/rxn_set_hb.txt](./data/rxn_set_hb.txt)). At a high level, each synthetic tree is constructed one reaction step at a time in a bottom-up manner, starting from purchasable building blocks. The model consists of four modules, each containing a multi-layer perceptron (MLP): @@ -19,6 +22,7 @@ The model consists of four modules, each containing a multi-layer perceptron (ML These four modules predict the probability distributions of actions to be taken within a single reaction step, and determine the nodes to be added to the synthetic tree under construction. All of these networks are conditioned on the target molecule embedding. ### Synthesis planning + This task is to infer the synthetic pathway to a given target molecule. We formulate this problem as generating a synthetic tree such that the product molecule it produces (i.e., the molecule at the root node) matches the desired target molecule. For this task, we can take a molecular embedding for the desired product, and use it as input to our model to produce a synthetic tree. If the desired product is successfully recovered, then the final root molecule will match the desired molecule used to create the input embedding. If the desired product is not successully recovered, it is possible the final root molecule may still be *similar* to the desired molecule used to create the input embedding, and thus our tool can also be used for *synthesizable analog recommendation*. @@ -26,6 +30,7 @@ For this task, we can take a molecular embedding for the desired product, and us ![the generation process](./figures/generation_process.png "generation process") ### Synthesizable molecular design + This task is to optimize a molecular structure with respect to an oracle function (e.g. bioactivity), while ensuring the synthetic accessibility of the molecules. We formulate this problem as optimizing the structure of a synthetic tree with respect to the desired properties of the product molecule it produces. To do this, we optimize the molecular embedding of the molecule using a genetic algorithm and the desired oracle function. The optimized molecule embedding can then be used as input to our model to produce a synthetic tree, where the final root molecule corresponds to the optimized molecule. @@ -33,6 +38,7 @@ To do this, we optimize the molecular embedding of the molecule using a genetic ## Setup instructions ### Setting up the environment + You can use conda to create an environment containing the necessary packages and dependencies for running SynNet by using the provided YAML file: ``` @@ -53,6 +59,7 @@ export PYTHONPATH=`pwd`:$PYTHONPATH ``` ### Unit tests + To check that everything has been set-up correctly, you can run the unit tests from within the [tests/](./tests/). If starting in the main SynNet/ directory, you can run the unit tests as follows: ``` @@ -67,11 +74,15 @@ You should get no errors if everything ran correctly. ### Data #### Templates + The Hartenfeller-Button templates are available in the [./data/](./data/) directory. + #### Building blocks -The Enamine data can be freely downloaded from https://enamine.net/building-blocks/building-blocks-catalog for academic purposes. After downloading the Enamine building blocks, you will need to replace the paths to the Enamine building blocks in the code. This can be done by searching for the string "enamine". + +The Enamine data can be freely downloaded from for academic purposes. After downloading the Enamine building blocks, you will need to replace the paths to the Enamine building blocks in the code. This can be done by searching for the string "enamine". ## Code Structure + The code is structured as follows: ``` @@ -140,6 +151,7 @@ SynNet/ The model implementations can be found in [syn_net/models/](syn_net/models/), with processing and analysis scripts located in [scripts/](./scripts/). ## Instructions + Before running anything, you need to add the root directory to the Python path. One option for doing this is to run the following command in the root `SynNet` directory: ``` @@ -147,37 +159,49 @@ export PYTHONPATH=`pwd`:$PYTHONPATH ``` ## Using pre-trained models + We have made available a set of pre-trained models at the following [link](https://figshare.com/articles/software/Trained_model_parameters_for_SynNet/16799413). The pretrained models correspond to the Action, Reactant 1, Reaction, and Reactant 2 networks, trained on the Hartenfeller-Button dataset using radius 2, length 4096 Morgan fingerprints for the molecular node embeddings, and length 256 fingerprints for the k-NN search. For further details, please see the publication. The models can be uncompressed with: + ``` tar -zxvf hb_fp_2_4096_256.tar.gz ``` ### Synthesis Planning + To perform synthesis planning described in the main text: + ``` python predict_multireactant_mp.py -n -1 --ncpu 36 --data test ``` + This script will feed a list of molecules from the test data and save the decoded results (predicted synthesis trees) to [./results/](./results/). One can use --help to see the instruction of each argument. Note: this file reads parameters from a directory, please specify the path to parameters previously. ### Synthesizable Molecular Design + To perform synthesizable molecular design, under [./scripts/](./scripts/), run: + ``` optimize_ga.py -i path/to/zinc.csv --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk ``` + This script uses a genetic algorithm to optimize molecular embeddings and returns the predicted synthetic trees for the optimized molecular embedding. One can use --help to see the instruction of each argument. If user wants to start from a checkpoint of previous run, run: + ``` optimize_ga.py -i path/to/population.npy --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk --restart ``` + Note: the input file indicated by -i contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. ### Train the model from scratch + Before training any models, you will first need to preprocess the set of reaction templates which you would like to use. You can use either a new set of reaction templates, or the provided Hartenfeller-Button (HB) set of reaction templates (see [data/rxn_set_hb.txt](data/rxn_set_hb.txt)). To preprocess a new dataset, you will need to: + 1. Preprocess the data to identify applicable reactants for each reaction template 2. Generate the synthetic trees by random selection 3. Split the synthetic trees into training, testing, and validation splits @@ -211,6 +235,7 @@ python filter_unmatch.py This will filter out buyable building blocks which didn't match a single template. ### Generating the synthetic path data by random selection + Under [./syn_net/data_generation/](./syn_net/data_generation/), run: ``` @@ -226,6 +251,7 @@ python sample_from_original.py This will filter out the samples where the root node QED is less than 0.5, or randomly with a probability less than 1 - QED/0.5. ### Splitting data into training, validation, and testing sets, and removing duplicates + Under [./scripts/](./scripts/), run: ``` @@ -235,6 +261,7 @@ python st_split.py The default split ratio is 6:2:2 for training, validation, and testing sets. ### Featurizing data + Under [./scripts/](./scripts/), run: ``` @@ -244,6 +271,7 @@ python st2steps.py -r 2 -b 4096 -d train This will featurize the synthetic tree data into step-by-step data which can be used for training. The flag *-r* indicates the fingerprint radius, *-b* indicates the number of bits to use for the fingerprints, and *-d* indicates which dataset split to featurize. ### Preparing training data for each network + Under [./syn_net/models/](./syn_net/models/), run: ``` @@ -261,6 +289,7 @@ python act.py --radius 2 --nbits 4096 This will train the network and save the model parameters at the state with the best validation loss in a logging directory, e.g., **`act_hb_fp_2_4096_logs`**. One can use tensorboard to monitor the training and validation loss. ### Sketching synthetic trees + To visualize the synthetic trees, run: ``` @@ -270,6 +299,7 @@ python scripts/sketch-synthetic-trees.py --file /path/to/st_hb/st_train.json.gz This will sketch 5 synthetic trees with 3 or more actions to the current ("./") directory (you can play around with these variables or just also leave them out to use the defaults). ### Testing the mean reciprocal rank (MRR) of reactant 1 + Under [./scripts/](./scripts/), run: ``` From b627600e89ef5f96126df1a119cb1cda905ca70e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 15 Aug 2022 10:38:28 -0400 Subject: [PATCH 008/302] package-ize with `setuptools` --- README.md | 4 ++-- setup.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 setup.py diff --git a/README.md b/README.md index 5f7c9f0a..f86f6c8f 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,11 @@ If you update the environment and would like to save the updated environment as conda env export > path/to/env.yml ``` -Before running any SynNet code, activate the environment and update the Python path so that the scripts can find the right files. You can do this by typing: +Before running any SynNet code, activate the environment and install the package in development mode. This ensures the scripts can find the right files. You can do this by typing: ``` source activate synthenv -export PYTHONPATH=`pwd`:$PYTHONPATH +python setup.py install ``` ### Unit tests diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..c1057cf5 --- /dev/null +++ b/setup.py @@ -0,0 +1,3 @@ +import setuptools + +setuptools.setup() \ No newline at end of file From 5c2815e6895fe10c24785f37b490962154fd8e27 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 15 Aug 2022 10:45:05 -0400 Subject: [PATCH 009/302] fix: tests are discoverable from root add: note in `README.md` on failing test --- README.md | 9 +++------ tests/__init__.py | 0 2 files changed, 3 insertions(+), 6 deletions(-) create mode 100644 tests/__init__.py diff --git a/README.md b/README.md index f86f6c8f..1f1feead 100644 --- a/README.md +++ b/README.md @@ -60,16 +60,13 @@ python setup.py install ### Unit tests -To check that everything has been set-up correctly, you can run the unit tests from within the [tests/](./tests/). If starting in the main SynNet/ directory, you can run the unit tests as follows: +To check that everything has been set-up correctly, you can run the unit tests. If starting in the main directory, you can run the unit tests as follows: -``` -source activate synthenv -export PYTHONPATH=`pwd`:$PYTHONPATH -cd tests/ +```python python -m unittest ``` -You should get no errors if everything ran correctly. +Except for `tests/test_Training.py`, all tests should succedd. The `test_Training.py` still relies on the embedding of the building blocks, which is tracked in this repostory. ### Data diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b From 10649d5535816ea744e0940a008819f7b7149dab Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 15 Aug 2022 14:52:26 -0400 Subject: [PATCH 010/302] restructure package: move source files to `src/` --- README.md | 11 ++++------- setup.py | 19 ++++++++++++++++++- {syn_net => src/syn_net}/__init__.py | 0 .../syn_net}/data_generation/__init__.py | 0 .../syn_net}/data_generation/_mp_make.py | 0 .../syn_net}/data_generation/_mp_process.py | 0 .../data_generation/check_all_template.py | 0 .../data_generation/filter_unmatch.py | 0 .../syn_net}/data_generation/make_dataset.py | 0 .../data_generation/make_dataset_mp.py | 0 .../data_generation/process_rxn_mp.py | 0 {syn_net => src/syn_net}/models/act.py | 0 {syn_net => src/syn_net}/models/mlp.py | 0 .../syn_net}/models/prepare_data.py | 0 {syn_net => src/syn_net}/models/rt1.py | 0 {syn_net => src/syn_net}/models/rt2.py | 0 {syn_net => src/syn_net}/models/rxn.py | 0 {syn_net => src/syn_net}/utils/__init__.py | 0 {syn_net => src/syn_net}/utils/data_utils.py | 0 {syn_net => src/syn_net}/utils/ga_utils.py | 0 .../syn_net}/utils/predict_beam_utils.py | 0 .../syn_net}/utils/predict_utils.py | 0 {syn_net => src/syn_net}/utils/prep_utils.py | 0 23 files changed, 22 insertions(+), 8 deletions(-) rename {syn_net => src/syn_net}/__init__.py (100%) rename {syn_net => src/syn_net}/data_generation/__init__.py (100%) rename {syn_net => src/syn_net}/data_generation/_mp_make.py (100%) rename {syn_net => src/syn_net}/data_generation/_mp_process.py (100%) rename {syn_net => src/syn_net}/data_generation/check_all_template.py (100%) rename {syn_net => src/syn_net}/data_generation/filter_unmatch.py (100%) rename {syn_net => src/syn_net}/data_generation/make_dataset.py (100%) rename {syn_net => src/syn_net}/data_generation/make_dataset_mp.py (100%) rename {syn_net => src/syn_net}/data_generation/process_rxn_mp.py (100%) rename {syn_net => src/syn_net}/models/act.py (100%) rename {syn_net => src/syn_net}/models/mlp.py (100%) rename {syn_net => src/syn_net}/models/prepare_data.py (100%) rename {syn_net => src/syn_net}/models/rt1.py (100%) rename {syn_net => src/syn_net}/models/rt2.py (100%) rename {syn_net => src/syn_net}/models/rxn.py (100%) rename {syn_net => src/syn_net}/utils/__init__.py (100%) rename {syn_net => src/syn_net}/utils/data_utils.py (100%) rename {syn_net => src/syn_net}/utils/ga_utils.py (100%) rename {syn_net => src/syn_net}/utils/predict_beam_utils.py (100%) rename {syn_net => src/syn_net}/utils/predict_utils.py (100%) rename {syn_net => src/syn_net}/utils/prep_utils.py (100%) diff --git a/README.md b/README.md index 1f1feead..bf0cb1ee 100644 --- a/README.md +++ b/README.md @@ -51,11 +51,12 @@ If you update the environment and would like to save the updated environment as conda env export > path/to/env.yml ``` +pip install -e . Before running any SynNet code, activate the environment and install the package in development mode. This ensures the scripts can find the right files. You can do this by typing: -``` +```shell source activate synthenv -python setup.py install +pip install -e . ``` ### Unit tests @@ -149,11 +150,7 @@ The model implementations can be found in [syn_net/models/](syn_net/models/), wi ## Instructions -Before running anything, you need to add the root directory to the Python path. One option for doing this is to run the following command in the root `SynNet` directory: - -``` -export PYTHONPATH=`pwd`:$PYTHONPATH -``` +Before running anything, set up the environment as decribed above. ## Using pre-trained models diff --git a/setup.py b/setup.py index c1057cf5..958cf08d 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,20 @@ import setuptools -setuptools.setup() \ No newline at end of file +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setuptools.setup( + name="syn_net", + version="0.1.0", + description="Synthetic tree generation using neural networks.", + long_description=long_description, + long_description_content_type="text/markdown", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + package_dir={"": "src"}, + packages=setuptools.find_packages(where="src"), + python_requires=">=3.9", +) \ No newline at end of file diff --git a/syn_net/__init__.py b/src/syn_net/__init__.py similarity index 100% rename from syn_net/__init__.py rename to src/syn_net/__init__.py diff --git a/syn_net/data_generation/__init__.py b/src/syn_net/data_generation/__init__.py similarity index 100% rename from syn_net/data_generation/__init__.py rename to src/syn_net/data_generation/__init__.py diff --git a/syn_net/data_generation/_mp_make.py b/src/syn_net/data_generation/_mp_make.py similarity index 100% rename from syn_net/data_generation/_mp_make.py rename to src/syn_net/data_generation/_mp_make.py diff --git a/syn_net/data_generation/_mp_process.py b/src/syn_net/data_generation/_mp_process.py similarity index 100% rename from syn_net/data_generation/_mp_process.py rename to src/syn_net/data_generation/_mp_process.py diff --git a/syn_net/data_generation/check_all_template.py b/src/syn_net/data_generation/check_all_template.py similarity index 100% rename from syn_net/data_generation/check_all_template.py rename to src/syn_net/data_generation/check_all_template.py diff --git a/syn_net/data_generation/filter_unmatch.py b/src/syn_net/data_generation/filter_unmatch.py similarity index 100% rename from syn_net/data_generation/filter_unmatch.py rename to src/syn_net/data_generation/filter_unmatch.py diff --git a/syn_net/data_generation/make_dataset.py b/src/syn_net/data_generation/make_dataset.py similarity index 100% rename from syn_net/data_generation/make_dataset.py rename to src/syn_net/data_generation/make_dataset.py diff --git a/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py similarity index 100% rename from syn_net/data_generation/make_dataset_mp.py rename to src/syn_net/data_generation/make_dataset_mp.py diff --git a/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py similarity index 100% rename from syn_net/data_generation/process_rxn_mp.py rename to src/syn_net/data_generation/process_rxn_mp.py diff --git a/syn_net/models/act.py b/src/syn_net/models/act.py similarity index 100% rename from syn_net/models/act.py rename to src/syn_net/models/act.py diff --git a/syn_net/models/mlp.py b/src/syn_net/models/mlp.py similarity index 100% rename from syn_net/models/mlp.py rename to src/syn_net/models/mlp.py diff --git a/syn_net/models/prepare_data.py b/src/syn_net/models/prepare_data.py similarity index 100% rename from syn_net/models/prepare_data.py rename to src/syn_net/models/prepare_data.py diff --git a/syn_net/models/rt1.py b/src/syn_net/models/rt1.py similarity index 100% rename from syn_net/models/rt1.py rename to src/syn_net/models/rt1.py diff --git a/syn_net/models/rt2.py b/src/syn_net/models/rt2.py similarity index 100% rename from syn_net/models/rt2.py rename to src/syn_net/models/rt2.py diff --git a/syn_net/models/rxn.py b/src/syn_net/models/rxn.py similarity index 100% rename from syn_net/models/rxn.py rename to src/syn_net/models/rxn.py diff --git a/syn_net/utils/__init__.py b/src/syn_net/utils/__init__.py similarity index 100% rename from syn_net/utils/__init__.py rename to src/syn_net/utils/__init__.py diff --git a/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py similarity index 100% rename from syn_net/utils/data_utils.py rename to src/syn_net/utils/data_utils.py diff --git a/syn_net/utils/ga_utils.py b/src/syn_net/utils/ga_utils.py similarity index 100% rename from syn_net/utils/ga_utils.py rename to src/syn_net/utils/ga_utils.py diff --git a/syn_net/utils/predict_beam_utils.py b/src/syn_net/utils/predict_beam_utils.py similarity index 100% rename from syn_net/utils/predict_beam_utils.py rename to src/syn_net/utils/predict_beam_utils.py diff --git a/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py similarity index 100% rename from syn_net/utils/predict_utils.py rename to src/syn_net/utils/predict_utils.py diff --git a/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py similarity index 100% rename from syn_net/utils/prep_utils.py rename to src/syn_net/utils/prep_utils.py From a287ba96f9aeb46f8ac057364d1349fc81b3e504 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 12:46:23 -0400 Subject: [PATCH 011/302] silence rdlogger --- src/syn_net/data_generation/process_rxn_mp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py index a57faf5f..90f738e2 100644 --- a/src/syn_net/data_generation/process_rxn_mp.py +++ b/src/syn_net/data_generation/process_rxn_mp.py @@ -10,8 +10,9 @@ from syn_net.utils.data_utils import Reaction, ReactionSet import syn_net.data_generation._mp_process as process -import shutup -shutup.please() +# Silence RDKit loggers (https://github.com/rdkit/rdkit/issues/2683) +from rdkit import RDLogger +RDLogger.DisableLog('rdApp.*') if __name__ == '__main__': From 00870132aa92ea671ccc851c1f2b9e9eefe7e318 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 13:11:53 -0400 Subject: [PATCH 012/302] move paths to `config.py` --- src/syn_net/config.py | 8 ++++++++ src/syn_net/data_generation/process_rxn_mp.py | 19 ++++++++++++++----- 2 files changed, 22 insertions(+), 5 deletions(-) create mode 100644 src/syn_net/config.py diff --git a/src/syn_net/config.py b/src/syn_net/config.py new file mode 100644 index 00000000..da48b0e7 --- /dev/null +++ b/src/syn_net/config.py @@ -0,0 +1,8 @@ +"""Central place for all configuration, paths, and parameter.""" + +DATA_DIR = "database" +ASSETS_DIR = "database/assets" + +# +BUILDING_BLOCKS_RAW_DIR = f"{ASSETS_DIR}/building-blocks" +REACTION_TEMPLATE_DIR = f"{ASSETS_DIR}/reaction-templates" diff --git a/src/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py index 90f738e2..02d9f4a3 100644 --- a/src/syn_net/data_generation/process_rxn_mp.py +++ b/src/syn_net/data_generation/process_rxn_mp.py @@ -10,24 +10,33 @@ from syn_net.utils.data_utils import Reaction, ReactionSet import syn_net.data_generation._mp_process as process +from pathlib import Path # Silence RDKit loggers (https://github.com/rdkit/rdkit/issues/2683) from rdkit import RDLogger RDLogger.DisableLog('rdApp.*') +from syn_net.config import REACTION_TEMPLATE_DIR, DATA_DIR if __name__ == '__main__': - name = 'pis' - path_to_rxn_templates = '/home/whgao/scGen/synth_net/data/rxn_set_' + name + '.txt' + name = 'hb' # "pis" or "hb" + + # Load reaction templates and parse + path_to_rxn_templates = f'{REACTION_TEMPLATE_DIR}/{name}.txt' rxn_templates = [] for line in open(path_to_rxn_templates, 'rt'): - rxn = Reaction(line.split('|')[1].strip()) + template = line.split("|")[1].strip() # reaction templates are prefix with "|" + rxn = Reaction(template) rxn_templates.append(rxn) + # Filter building blocks on each reaction pool = mp.Pool(processes=64) - t = time() rxns = pool.map(process.func, rxn_templates) print('Time: ', time() - t, 's') + # Save data to local disk r = ReactionSet(rxns) - r.save('/pool001/whgao/data/synth_net/st_pis/reactions_' + name + '.json.gz') + out_dir = Path(DATA_DIR) / f"pre-process/" + out_dir.mkdir(exist_ok=True, parents=True) + out_file = out_dir / f"st_{name}.json.gz" + r.save(out_file) From d804bac0a3c3b354410f7068d0311de2817a4f74 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 13:12:48 -0400 Subject: [PATCH 013/302] remove "|" prefix in the reaction templates asset --- src/syn_net/data_generation/process_rxn_mp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py index 02d9f4a3..835e74d5 100644 --- a/src/syn_net/data_generation/process_rxn_mp.py +++ b/src/syn_net/data_generation/process_rxn_mp.py @@ -24,7 +24,7 @@ path_to_rxn_templates = f'{REACTION_TEMPLATE_DIR}/{name}.txt' rxn_templates = [] for line in open(path_to_rxn_templates, 'rt'): - template = line.split("|")[1].strip() # reaction templates are prefix with "|" + template = line.strip() rxn = Reaction(template) rxn_templates.append(rxn) From 7701b24f5ba28c88743ba8d6d0df7bf04edc9ecb Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 14:52:18 -0400 Subject: [PATCH 014/302] refactor: consolidate reaction pre-process code - move into a single function - remove hard-coded paths --- src/syn_net/config.py | 3 + src/syn_net/data_generation/_mp_process.py | 13 ---- src/syn_net/data_generation/process_rxn_mp.py | 60 +++++++++++++------ 3 files changed, 44 insertions(+), 32 deletions(-) delete mode 100644 src/syn_net/data_generation/_mp_process.py diff --git a/src/syn_net/config.py b/src/syn_net/config.py index da48b0e7..740829b9 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -6,3 +6,6 @@ # BUILDING_BLOCKS_RAW_DIR = f"{ASSETS_DIR}/building-blocks" REACTION_TEMPLATE_DIR = f"{ASSETS_DIR}/reaction-templates" + +# Pre-processed data +DATA_PREPROCESS_DIR = "database/pre-process" diff --git a/src/syn_net/data_generation/_mp_process.py b/src/syn_net/data_generation/_mp_process.py deleted file mode 100644 index 21947daf..00000000 --- a/src/syn_net/data_generation/_mp_process.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -This file contains a function for search available building blocks -for a matching reaction template. Prepared for multiprocessing. -""" -import pandas as pd - -path_to_building_blocks = '/home/whgao/scGen/synth_net/data/enamine_us.csv.gz' -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -print('Finish reading the building blocks list!') - -def func(rxn_): - rxn_.set_available_reactants(building_blocks) - return rxn_ diff --git a/src/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py index 835e74d5..0c190f8e 100644 --- a/src/syn_net/data_generation/process_rxn_mp.py +++ b/src/syn_net/data_generation/process_rxn_mp.py @@ -3,40 +3,62 @@ reactants from a list of purchasable building blocks. Usage: - python process_rxn_mp.py + python process__rxnmp.py """ import multiprocessing as mp +from functools import partial +from pathlib import Path from time import time -from syn_net.utils.data_utils import Reaction, ReactionSet -import syn_net.data_generation._mp_process as process -from pathlib import Path # Silence RDKit loggers (https://github.com/rdkit/rdkit/issues/2683) -from rdkit import RDLogger -RDLogger.DisableLog('rdApp.*') +from rdkit import RDLogger + +from syn_net.utils.data_utils import Reaction, ReactionSet + +RDLogger.DisableLog("rdApp.*") + + +import pandas as pd + + +def _load_building_blocks(file: Path) -> list[str]: + return pd.read_csv(file)["SMILES"].to_list() + + +def _match_building_blocks_to_rxn(building_blocks: list[str], _rxn: Reaction): + _rxn.set_available_reactants(building_blocks) + return _rxn + + +from syn_net.config import (BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, + REACTION_TEMPLATE_DIR) -from syn_net.config import REACTION_TEMPLATE_DIR, DATA_DIR +if __name__ == "__main__": + reaction_template_id = "hb" # "pis" or "hb" + building_blocks_id = "enamine_us-2021-smiles" -if __name__ == '__main__': - name = 'hb' # "pis" or "hb" + # Load building blocks + building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" + building_blocks = _load_building_blocks(building_blocks_file) # Load reaction templates and parse - path_to_rxn_templates = f'{REACTION_TEMPLATE_DIR}/{name}.txt' - rxn_templates = [] - for line in open(path_to_rxn_templates, 'rt'): - template = line.strip() + path_to__rxntemplates = Path(REACTION_TEMPLATE_DIR) / f"{reaction_template_id}.txt" + _rxntemplates = [] + for line in open(path_to__rxntemplates, "rt"): + template = line.strip() rxn = Reaction(template) - rxn_templates.append(rxn) + _rxntemplates.append(rxn) # Filter building blocks on each reaction - pool = mp.Pool(processes=64) t = time() - rxns = pool.map(process.func, rxn_templates) - print('Time: ', time() - t, 's') + func = partial(_match_building_blocks_to_rxn, building_blocks) + with mp.Pool(processes=64) as pool: + rxns = pool.map(func, _rxntemplates) + print("Time: ", time() - t, "s") # Save data to local disk r = ReactionSet(rxns) - out_dir = Path(DATA_DIR) / f"pre-process/" + out_dir = Path(DATA_PREPROCESS_DIR) out_dir.mkdir(exist_ok=True, parents=True) - out_file = out_dir / f"st_{name}.json.gz" + out_file = out_dir / f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" r.save(out_file) From ed35aee8e1cae18588e753add44e1d6cf4864703 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 15:29:40 -0400 Subject: [PATCH 015/302] refactor: remove hard-coded paths pre-process code --- src/syn_net/data_generation/filter_unmatch.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/syn_net/data_generation/filter_unmatch.py b/src/syn_net/data_generation/filter_unmatch.py index a81db3a6..3dd2b939 100644 --- a/src/syn_net/data_generation/filter_unmatch.py +++ b/src/syn_net/data_generation/filter_unmatch.py @@ -4,22 +4,41 @@ from syn_net.utils.data_utils import * import pandas as pd from tqdm import tqdm +from pathlib import Path +from syn_net.data_generation.process_rxn_mp import _load_building_blocks # TODO: refactor +from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR +import logging +logger = logging.getLogger(__name__) if __name__ == '__main__': - r_path = '/pool001/whgao/data/synth_net/st_pis/reactions_pis.json.gz' - bb_path = '/home/whgao/scGen/synth_net/data/enamine_us.csv.gz' + reaction_template_id = "hb" # "pis" or "hb" + building_blocks_id = "enamine_us-2021-smiles" + + # Load building blocks + building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" + building_blocks = _load_building_blocks(building_blocks_file) + + + # Load genearted reactions (matched reactions <=> building blocks) + reactions_dir = Path(DATA_PREPROCESS_DIR) + reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" r_set = ReactionSet() - r_set.load(r_path) - matched_mols = set() + r_set.load(reactions_dir / reactions_file) + + # Identify all used building blocks (via union of sets) + matched_bblocks = set() for r in tqdm(r_set.rxns): - for a_list in r.available_reactants: - matched_mols = matched_mols | set(a_list) + for reactants in r.available_reactants: + matched_bblocks = matched_bblocks.union(set(reactants)) - original_mols = pd.read_csv(bb_path, compression='gzip')['SMILES'].tolist() - print('Total building blocks number:', len(original_mols)) - print('Matched building blocks number:', len(matched_mols)) + logger.info(f'Total number of building blocks {len(building_blocks):d}') + logger.info(f'Matched number of building blocks {len(matched_bblocks):d}') + logger.info(f"{len(matched_bblocks)/len(building_blocks):.2%} of building blocks are applicable for the reaction template set '{reaction_template_id}'.") - df = pd.DataFrame({'SMILES': list(matched_mols)}) - df.to_csv('/pool001/whgao/data/synth_net/st_pis/enamine_us_matched.csv.gz', compression='gzip') + # Save to local disk + df = pd.DataFrame({'SMILES': list(matched_bblocks)}) + outfile = f"{reaction_template_id}-{building_blocks_id}-matched.csv.gz" + file = Path(DATA_PREPROCESS_DIR) / outfile + df.to_csv(file, compression='gzip') From d73b03e9c273595255b8b851f4fb150a05069fde Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 18:02:23 -0400 Subject: [PATCH 016/302] refactor: consolidate data generation pre process code - move into a single function - remove hard-coded paths - use logging --- src/syn_net/data_generation/_mp_make.py | 29 ---------- .../data_generation/make_dataset_mp.py | 54 ++++++++++++++----- 2 files changed, 40 insertions(+), 43 deletions(-) delete mode 100644 src/syn_net/data_generation/_mp_make.py diff --git a/src/syn_net/data_generation/_mp_make.py b/src/syn_net/data_generation/_mp_make.py deleted file mode 100644 index badd519c..00000000 --- a/src/syn_net/data_generation/_mp_make.py +++ /dev/null @@ -1,29 +0,0 @@ -""" -This file contains a function to generate a single synthetic tree, prepared for -multiprocessing. -""" -import pandas as pd -import numpy as np -# import dill as pickle -# import gzip - -from syn_net.data_generation.make_dataset import synthetic_tree_generator -from syn_net.utils.data_utils import ReactionSet - - -path_reaction_file = '/pool001/whgao/data/synth_net/st_pis/reactions_pis.json.gz' -path_to_building_blocks = '/pool001/whgao/data/synth_net/st_pis/enamine_us_matched.csv.gz' - -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -r_set = ReactionSet() -r_set.load(path_reaction_file) -rxns = r_set.rxns -# with gzip.open(path_reaction_file, 'rb') as f: -# rxns = pickle.load(f) - -print('Finish reading the templates and building blocks list!') - -def func(_): - np.random.seed(_) - tree, action = synthetic_tree_generator(building_blocks, rxns, max_step=15) - return tree, action diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py index b3d596b4..3c84737d 100644 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ b/src/syn_net/data_generation/make_dataset_mp.py @@ -4,24 +4,47 @@ Usage: python make_dataset_mp.py """ -import numpy as np import multiprocessing as mp -from time import time -from syn_net.utils.data_utils import SyntheticTreeSet -import syn_net.data_generation._mp_make as make +import numpy as np +from pathlib import Path +from syn_net.data_generation.make_dataset import synthetic_tree_generator +from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet +from syn_net.data_generation.process_rxn_mp import _load_building_blocks # TODO: refactor +from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR +import logging + +logger = logging.getLogger(__name__) + + +def func(_x): + np.random.seed(_x) # dummy input to generate "unique" seed + tree, action = synthetic_tree_generator(building_blocks, rxns) + return tree, action if __name__ == '__main__': - pool = mp.Pool(processes=100) + reaction_template_id = "hb" # "pis" or "hb" + building_blocks_id = "enamine_us-2021-smiles" + NUM_TREES = 600_000 + + # Load building blocks + building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" + building_blocks = _load_building_blocks(building_blocks_file) - NUM_TREES = 600000 + # Load genearted reactions (matched reactions <=> building blocks) + reactions_dir = Path(DATA_PREPROCESS_DIR) + reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" + r_set = ReactionSet() + r_set.load(reactions_dir / reactions_file) + rxns = r_set.rxns - t = time() - results = pool.map(make.func, np.arange(NUM_TREES).tolist()) - print('Time: ', time() - t, 's') + # Generate synthetic trees + with mp.Pool(processes=64) as pool: + results = pool.map(func, np.arange(NUM_TREES).tolist()) + # Filter out trees that were completed with action="end" trees = [r[0] for r in results if r[1] == 3] actions = [r[1] for r in results] @@ -29,10 +52,13 @@ num_error = actions.count(-1) num_unfinish = NUM_TREES - num_finish - num_error - print('Total trial: ', NUM_TREES) - print('num of finished trees: ', num_finish) - print('num of unfinished tree: ', num_unfinish) - print('num of error processes: ', num_error) + logging.info(f"Total trial {NUM_TREES}") + logging.info(f"Number of finished trees: {num_finish}") + logging.info(f"Number of of unfinished tree: {num_unfinish}") + logging.info(f"Number of error processes: {num_error}") + # Save to local disk tree_set = SyntheticTreeSet(trees) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_data.json.gz') + outfile = f"synthetic-trees_{reaction_template_id}-{building_blocks_id}.json.gz" + file = Path(DATA_PREPROCESS_DIR) / outfile + tree_set.save(file) From 84ac835d4fefb2da04620446f7798a14420df6b3 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 16 Aug 2022 19:58:18 -0400 Subject: [PATCH 017/302] refactor: remove hard-coded paths & cosmetic changes --- scripts/sample_from_original.py | 95 +++++++++++++++++---------------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/scripts/sample_from_original.py b/scripts/sample_from_original.py index 01f79e78..9d32e79e 100644 --- a/scripts/sample_from_original.py +++ b/scripts/sample_from_original.py @@ -1,72 +1,75 @@ """ Filters the synthetic trees by the QEDs of the root molecules. """ -from tdc import Oracle -qed = Oracle(name='qed') +from pathlib import Path + import numpy as np import pandas as pd -from syn_net.utils.data_utils import * +from rdkit import Chem +from tdc import Oracle +from tqdm import tqdm -def is_valid(smi): - """ - Checks if a SMILES string is valid. +from syn_net.config import DATA_PREPROCESS_DIR +from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet - Args: - smi (str): Molecular SMILES string. +DATA_DIR = "pool001/whgao/data/synth_net" +SYNTHETIC_TREES_FILE = "abc-st_data.json.gz" - Returns: - False or str: False if the SMILES is not valid, or the reconverted - SMILES string. - """ - mol = Chem.MolFromSmiles(smi) - if mol is None: - return False - else: - return Chem.MolToSmiles(mol, isomericSmiles=False) +def _is_valid_mol(mol: Chem.rdchem.Mol): + return mol is not None if __name__ == '__main__': + reaction_template_id = "hb" # "pis" or "hb" + building_blocks_id = "enamine_us-2021-smiles" + qed = Oracle(name='qed') - data_path = '/pool001/whgao/data/synth_net/st_pis/st_data.json.gz' + # Load generated synthetic trees + file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees_{reaction_template_id}-{building_blocks_id}.json.gz" st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - print(f'Finish reading, in total {len(data)} synthetic trees.') + st_set.load(file) + synthetic_trees = st_set.sts + print(f'Finish reading, in total {len(synthetic_trees)} synthetic trees.') - filtered_data = [] - original_qed = [] - qeds = [] - generated_smiles = [] + # Filter synthetic trees + # .. based on validity of root molecule + # .. based on drug-like quality + filtered_data: list[SyntheticTree] = [] + original_qed: list[float] = [] + qeds: list[float] = [] + generated_smiles: list[str] = [] threshold = 0.5 - for t in tqdm(data): + for t in tqdm(synthetic_trees): try: - valid_smiles = is_valid(t.root.smiles) - if valid_smiles: - if valid_smiles in generated_smiles: - pass - else: - qed_value = qed(valid_smiles) - original_qed.append(qed_value) + smiles = t.root.smiles + mol = Chem.MolFromSmiles(smiles) + if not _is_valid_mol(mol): + continue + if smiles in generated_smiles: + continue + + qed_value = qed(smiles) + original_qed.append(qed_value) - # filter the trees based on their QEDs - if qed_value > threshold or np.random.random() < (qed_value/threshold): - generated_smiles.append(valid_smiles) - filtered_data.append(t) - qeds.append(qed_value) - else: - pass - else: - pass - except: - pass + # filter the trees based on their QEDs + if qed_value > threshold or np.random.random() < (qed_value/threshold): + generated_smiles.append(smiles) + filtered_data.append(t) + qeds.append(qed_value) + + except Exception as e: + print(e) print(f'Finish sampling, remaining {len(filtered_data)} synthetic trees.') + # Save to local disk st_set = SyntheticTreeSet(filtered_data) - st_set.save('/pool001/whgao/data/synth_net/st_pis/st_data_filtered.json.gz') + file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees_{reaction_template_id}-{building_blocks_id}-filtered.json.gz" + st_set.save(file) df = pd.DataFrame({'SMILES': generated_smiles, 'qed': qeds}) - df.to_csv('/pool001/whgao/data/synth_net/st_pis/filtered_smiles.csv.gz', compression='gzip', index=False) + file = Path(DATA_PREPROCESS_DIR) / f"filtered-smiles_{reaction_template_id}-{building_blocks_id}-filtered.csv.gz" + df.to_csv(file, compression='gzip', index=False) print('Finish!') From b98a6fb804be717d466a83f8fab3df9120f958fb Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 10:32:37 -0400 Subject: [PATCH 018/302] refactor: remove hard-coded paths & cosmetic changes --- scripts/st_split.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/scripts/st_split.py b/scripts/st_split.py index 60497304..5a4f2d75 100644 --- a/scripts/st_split.py +++ b/scripts/st_split.py @@ -1,40 +1,47 @@ """ Reads synthetic tree data and splits it into training, validation and testing sets. """ -from syn_net.utils.data_utils import * - +from syn_net.utils.data_utils import SyntheticTreeSet +from pathlib import Path +from syn_net.config import DATA_PREPROCESS_DIR, DATA_PREPARED_DIR if __name__ == "__main__": + reaction_template_id = "hb" # "pis" or "hb" + building_blocks_id = "enamine_us-2021-smiles" + # Load filtered synthetic trees st_set = SyntheticTreeSet() - path_to_data = '/pool001/whgao/data/synth_net/st_pis/st_data_filtered.json.gz' - print('Reading data from ', path_to_data) - st_set.load(path_to_data) + file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees_{reaction_template_id}-{building_blocks_id}-filtered.json.gz" + print(f'Reading data from {file}') + st_set.load(file) data = st_set.sts del st_set num_total = len(data) - print("In total we have: ", num_total, "paths.") + print(f"There are {len(data)} synthetic trees.") - split_ratio = [0.6, 0.2, 0.2] + # Split data + SPLIT_RATIO = [0.6, 0.2, 0.2] - num_train = int(split_ratio[0] * num_total) - num_valid = int(split_ratio[1] * num_total) + num_train = int(SPLIT_RATIO[0] * num_total) + num_valid = int(SPLIT_RATIO[1] * num_total) num_test = num_total - num_train - num_valid data_train = data[:num_train] data_valid = data[num_train: num_train + num_valid] data_test = data[num_train + num_valid: ] + # Save to local disk + print("Saving training dataset: ", len(data_train)) - tree_set = SyntheticTreeSet(data_train) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_train.json.gz') + trees = SyntheticTreeSet(data_train) + trees.save(f'{DATA_PREPARED_DIR}/synthetic-trees-train.json.gz') print("Saving validation dataset: ", len(data_valid)) - tree_set = SyntheticTreeSet(data_valid) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_valid.json.gz') + trees = SyntheticTreeSet(data_valid) + trees.save(f'{DATA_PREPARED_DIR}/synthetic-trees-valid.json.gz') print("Saving testing dataset: ", len(data_test)) - tree_set = SyntheticTreeSet(data_test) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_test.json.gz') + trees = SyntheticTreeSet(data_test) + trees.save(f'{DATA_PREPARED_DIR}/synthetic-trees-test.json.gz') print("Finish!") From 2ff4c5ee135320b1e1091ab5cedc56306231d4de Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:16:46 -0400 Subject: [PATCH 019/302] refactor: remove hard-coded paths & cosmetic changes --- scripts/st2steps.py | 52 +++++++++++++++++++++++++------------------ src/syn_net/config.py | 6 +++++ 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index 892ae30e..ad395624 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -1,13 +1,13 @@ """ Splits a synthetic tree into states and steps. """ -import os +from pathlib import Path from tqdm import tqdm from scipy import sparse -from syn_net.utils.data_utils import * +from syn_net.utils.data_utils import SyntheticTreeSet from syn_net.utils.prep_utils import organize - +from syn_net.config import DATA_PREPARED_DIR, DATA_FEATURIZED_DIR if __name__ == '__main__': @@ -15,34 +15,41 @@ parser = argparse.ArgumentParser() parser.add_argument("-n", "--numbersave", type=int, default=999999999999, help="Save number") - parser.add_argument("-v", "--verbose", action="store_true", default=False, - help="Increase output verbosity") parser.add_argument("-e", "--targetembedding", type=str, default='fp', help="Choose from ['fp', 'gin']") - parser.add_argument("-o", "--outputembedding", type=str, default='gin', + parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") parser.add_argument("-r", "--radius", type=int, default=2, help="Radius for Morgan Fingerprint") parser.add_argument("-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint") - parser.add_argument("-d", "--datasettype", type=str, default='train', + parser.add_argument("-d", "--datasettype", type=str, choices=["train","valid","test"], help="Choose from ['train', 'valid', 'test']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', + parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], help="Choose from ['hb', 'pis']") args = parser.parse_args() + # Parse & set inputs + reaction_template_id = args.rxn_template + building_blocks_id = "enamine_us-2021-smiles" dataset_type = args.datasettype embedding = args.targetembedding - path_st = '/pool001/whgao/data/synth_net/st_hb/st_' + dataset_type + '.json.gz' - save_dir = '/pool001/whgao/data/synth_net/hb_' + embedding + '_' + str(args.radius) + '_' + str(args.nbits) + '_' + str(args.outputembedding) + '/' + assert dataset_type is not None, "Must specify which dataset to use." + # Load synthetic trees subset {train,valid,test} + file = f'{DATA_PREPARED_DIR}/synthetic-trees-{dataset_type}.json.gz' st_set = SyntheticTreeSet() - st_set.load(path_st) + st_set.load(file) print('Original length: ', len(st_set.sts)) - data = st_set.sts + data: list = st_set.sts del st_set print('Working length: ', len(data)) + + # Set output directory + save_dir = Path(DATA_FEATURIZED_DIR) / f'{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/' + Path(save_dir).mkdir(parents=1,exist_ok=1) + # Start splitting synthetic trees in states and steps states = [] steps = [] @@ -51,7 +58,10 @@ save_idx = 0 for st in tqdm(data): try: - state, step = organize(st, target_embedding=embedding, radius=args.radius, nBits=args.nbits, output_embedding=args.outputembedding) + state, step = organize(st, target_embedding=embedding, + radius=args.radius, + nBits=args.nbits, + output_embedding=args.outputembedding) except Exception as e: print(e) continue @@ -62,8 +72,8 @@ print('Saving......') states = sparse.vstack(states) steps = sparse.vstack(steps) - sparse.save_npz(save_dir + 'states_' + str(save_idx) + '_' + dataset_type + '.npz', states) - sparse.save_npz(save_dir + 'steps_' + str(save_idx) + '_' + dataset_type + '.npz', steps) + sparse.save_npz(save_dir / f"states_{save_idx}_{dataset_type}.npz", states) + sparse.save_npz(save_dir / f"steps_{save_idx}_{dataset_type}.npz", steps) save_idx += 1 del states del steps @@ -72,15 +82,13 @@ del data + # Finally, save again. (Potentially overwrite existing files) if len(steps) != 0: + print('Saving......') states = sparse.vstack(states) steps = sparse.vstack(steps) - - print('Saving......') - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - sparse.save_npz(save_dir + 'states_' + str(save_idx) + '_' + dataset_type + '.npz', states) - sparse.save_npz(save_dir + 'steps_' + str(save_idx) + '_' + dataset_type + '.npz', steps) + sparse.save_npz(save_dir / f"states_{save_idx}_{dataset_type}.npz", states) + sparse.save_npz(save_dir / f"steps_{save_idx}_{dataset_type}.npz", steps) print('Finish!') + diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 740829b9..1438a155 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -9,3 +9,9 @@ # Pre-processed data DATA_PREPROCESS_DIR = "database/pre-process" + +# Prepared data +DATA_PREPARED_DIR = "database/prepared" + +# Prepared data +DATA_FEATURIZED_DIR = "database/featurized" \ No newline at end of file From 7254d7dbb794ec5c3afd289205881ac43eaeca68 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:18:37 -0400 Subject: [PATCH 020/302] remove code that allows saving at an interval --- scripts/st2steps.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index ad395624..9db766b5 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -13,8 +13,6 @@ import argparse parser = argparse.ArgumentParser() - parser.add_argument("-n", "--numbersave", type=int, default=999999999999, - help="Save number") parser.add_argument("-e", "--targetembedding", type=str, default='fp', help="Choose from ['fp', 'gin']") parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', @@ -54,8 +52,6 @@ steps = [] num_save = args.numbersave - idx = 0 - save_idx = 0 for st in tqdm(data): try: state, step = organize(st, target_embedding=embedding, @@ -67,28 +63,14 @@ continue states.append(state) steps.append(step) - idx += 1 - if idx % num_save == 0: - print('Saving......') - states = sparse.vstack(states) - steps = sparse.vstack(steps) - sparse.save_npz(save_dir / f"states_{save_idx}_{dataset_type}.npz", states) - sparse.save_npz(save_dir / f"steps_{save_idx}_{dataset_type}.npz", steps) - save_idx += 1 - del states - del steps - states = [] - steps = [] - del data - # Finally, save again. (Potentially overwrite existing files) - if len(steps) != 0: - print('Saving......') - states = sparse.vstack(states) - steps = sparse.vstack(steps) - sparse.save_npz(save_dir / f"states_{save_idx}_{dataset_type}.npz", states) - sparse.save_npz(save_dir / f"steps_{save_idx}_{dataset_type}.npz", steps) + # Finally, save. + print('Saving......') + states = sparse.vstack(states) + steps = sparse.vstack(steps) + sparse.save_npz(save_dir / f"states_{dataset_type}.npz", states) + sparse.save_npz(save_dir / f"steps_{dataset_type}.npz", steps) print('Finish!') From b492cddf70a848eaaa94dcd5ab6a5e090f9682d6 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:34:39 -0400 Subject: [PATCH 021/302] use logging instead of print --- scripts/st2steps.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index 9db766b5..eb07ee62 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -7,6 +7,9 @@ from syn_net.utils.data_utils import SyntheticTreeSet from syn_net.utils.prep_utils import organize +import logging +logger = logging.getLogger(__file__) + from syn_net.config import DATA_PREPARED_DIR, DATA_FEATURIZED_DIR if __name__ == '__main__': @@ -26,6 +29,7 @@ parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], help="Choose from ['hb', 'pis']") args = parser.parse_args() + logger.info(vars(args)) # Parse & set inputs reaction_template_id = args.rxn_template @@ -38,10 +42,9 @@ file = f'{DATA_PREPARED_DIR}/synthetic-trees-{dataset_type}.json.gz' st_set = SyntheticTreeSet() st_set.load(file) - print('Original length: ', len(st_set.sts)) + logger.info("Number of synthetic trees: {len(st_set.sts}") data: list = st_set.sts del st_set - print('Working length: ', len(data)) # Set output directory save_dir = Path(DATA_FEATURIZED_DIR) / f'{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/' @@ -51,7 +54,6 @@ states = [] steps = [] - num_save = args.numbersave for st in tqdm(data): try: state, step = organize(st, target_embedding=embedding, @@ -59,18 +61,18 @@ nBits=args.nbits, output_embedding=args.outputembedding) except Exception as e: - print(e) + logger.exception(exc_info=e) continue states.append(state) steps.append(step) # Finally, save. - print('Saving......') + logger.info(f"Saving to {save_dir}") states = sparse.vstack(states) steps = sparse.vstack(steps) sparse.save_npz(save_dir / f"states_{dataset_type}.npz", states) sparse.save_npz(save_dir / f"steps_{dataset_type}.npz", steps) - print('Finish!') + logger.info("Save successful.") From 87ade74d2c5f368ad2b188a817dad4b47181551a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 11:48:52 -0400 Subject: [PATCH 022/302] change default: use `fp_256` as output embedding --- src/syn_net/models/prepare_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/models/prepare_data.py b/src/syn_net/models/prepare_data.py index 04400c2d..2a1d2bf7 100644 --- a/src/syn_net/models/prepare_data.py +++ b/src/syn_net/models/prepare_data.py @@ -18,7 +18,7 @@ help="Radius for Morgan fingerprint.") parser.add_argument("--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--outputembedding", type=str, default='gin', + parser.add_argument("--outputembedding", type=str, default='fp_256', help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") args = parser.parse_args() rxn_template = args.rxn_template From 39b41f4e6085a3ccca0338b9b8ba25d578c53990 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 12:05:29 -0400 Subject: [PATCH 023/302] remove code that allows saving at an interval (see 7254d7d) --- src/syn_net/utils/prep_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 9839bfe3..a689bc2c 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -251,9 +251,9 @@ def prep_data(main_dir, num_rxn, out_dim): print('Reading ' + dataset + ' data ......') states_list = [] steps_list = [] - for i in range(1): - states_list.append(sparse.load_npz(f'{main_dir}states_{i}_{dataset}.npz')) - steps_list.append(sparse.load_npz(f'{main_dir}steps_{i}_{dataset}.npz')) + + states_list.append(sparse.load_npz(f'{main_dir}states_{dataset}.npz')) + steps_list.append(sparse.load_npz(f'{main_dir}steps_{dataset}.npz')) states = sparse.csc_matrix(sparse.vstack(states_list)) steps = sparse.csc_matrix(sparse.vstack(steps_list)) From 66e6227763d3080a1afa2514852aaa3543480bb6 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 13:50:25 -0400 Subject: [PATCH 024/302] refactor: remove hard-coded paths & cosmetic changes --- src/syn_net/models/prepare_data.py | 55 ++++++++++++++++-------------- src/syn_net/utils/prep_utils.py | 30 +++++++++------- 2 files changed, 47 insertions(+), 38 deletions(-) diff --git a/src/syn_net/models/prepare_data.py b/src/syn_net/models/prepare_data.py index 2a1d2bf7..0e0644a6 100644 --- a/src/syn_net/models/prepare_data.py +++ b/src/syn_net/models/prepare_data.py @@ -4,43 +4,48 @@ Action, Reactant 1, Reactant 2, and Reaction files. """ from syn_net.utils.prep_utils import prep_data - +from syn_net.config import DATA_FEATURIZED_DIR +from pathlib import Path +import logging +logger = logging.getLogger(__file__) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', + parser.add_argument("-e", "--targetembedding", type=str, default='fp', help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--outputembedding", type=str, default='fp_256', + parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") + parser.add_argument("-r", "--radius", type=int, default=2, + help="Radius for Morgan Fingerprint") + parser.add_argument("-b", "--nbits", type=int, default=4096, + help="Number of Bits for Morgan Fingerprint") + parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], + help="Choose from ['hb', 'pis']") + args = parser.parse_args() - rxn_template = args.rxn_template - featurize = args.featurize + reaction_template_id = args.rxn_template + embedding = args.targetembedding output_emb = args.outputembedding - main_dir = '/pool001/whgao/data/synth_net/' + rxn_template + '_' + featurize + '_' + str(args.radius) + '_' + str(args.nbits) + '_' + str(args.outputembedding) + '/' - if rxn_template == 'hb': + main_dir = Path(DATA_FEATURIZED_DIR) / f'{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/' # must match with dir in `st2steps.py` + if reaction_template_id == 'hb': num_rxn = 91 - elif rxn_template == 'pis': + elif reaction_template_id == 'pis': num_rxn = 4700 - if output_emb == 'gin': - out_dim = 300 - elif output_emb == 'rdkit2d': - out_dim = 200 - elif output_emb == 'fp_4096': - out_dim = 4096 - elif output_emb == 'fp_256': - out_dim = 256 - - prep_data(main_dir=main_dir, out_dim=out_dim) + # Get dimension of output embedding + OUTPUT_EMBEDDINGS = { + "gin": 300, + "fp_4096": 4096, + "fp_256": 256, + "rdkit2d": 200, + } + out_dim = OUTPUT_EMBEDDINGS[output_emb] + logger.info("Start splitting data.") + # Split datasets for each MLP + prep_data(main_dir, num_rxn, out_dim) - print('Finish!') + logger.info("Successfully splitted data.") diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index a689bc2c..fe4fde82 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -10,7 +10,7 @@ from syn_net.utils.predict_utils import (can_react, get_action_mask, get_reaction_mask, mol_fp, get_mol_embedding) - +from pathlib import Path def rdkit2d_embedding(smi): """ @@ -245,15 +245,15 @@ def prep_data(main_dir, num_rxn, out_dim): num_rxn (int): Number of reactions in the dataset. out_dim (int): Size of the output feature vectors. """ - + main_dir = Path(main_dir) for dataset in ['train', 'valid', 'test']: - print('Reading ' + dataset + ' data ......') + print(f'Reading {dataset} data ...') states_list = [] steps_list = [] - states_list.append(sparse.load_npz(f'{main_dir}states_{dataset}.npz')) - steps_list.append(sparse.load_npz(f'{main_dir}steps_{dataset}.npz')) + states_list.append(sparse.load_npz(main_dir / f'states_{dataset}.npz')) + steps_list.append(sparse.load_npz(main_dir / f'steps_{dataset}.npz')) states = sparse.csc_matrix(sparse.vstack(states_list)) steps = sparse.csc_matrix(sparse.vstack(steps_list)) @@ -261,17 +261,19 @@ def prep_data(main_dir, num_rxn, out_dim): # extract Action data X = states y = steps[:, 0] - sparse.save_npz(f'{main_dir}X_act_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_act_{dataset}.npz', y) + sparse.save_npz(main_dir / f'X_act_{dataset}.npz', X) + sparse.save_npz(main_dir / f'y_act_{dataset}.npz', y) states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) + print(f' saved data for "Action"') # extract Reaction data X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) y = steps[:, out_dim + 1] - sparse.save_npz(f'{main_dir}X_rxn_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_rxn_{dataset}.npz', y) + sparse.save_npz(main_dir / f'X_rxn_{dataset}.npz', X) + sparse.save_npz(main_dir / f'y_rxn_{dataset}.npz', y) + print(f' saved data for "Reaction"') states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).reshape(-1, )]) @@ -287,8 +289,9 @@ def prep_data(main_dir, num_rxn, out_dim): sparse.csc_matrix(enc.transform(steps[:, out_dim+1].A.reshape((-1, 1))).toarray())] ) y = steps[:, (out_dim+2): (2 * out_dim + 2)] - sparse.save_npz(f'{main_dir}X_rt2_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_rt2_{dataset}.npz', y) + sparse.save_npz(main_dir / f'X_rt2_{dataset}.npz', X) + sparse.save_npz(main_dir / f'y_rt2_{dataset}.npz', y) + print(f' saved data for "Reactant 2"') states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).reshape(-1, )]) @@ -296,7 +299,8 @@ def prep_data(main_dir, num_rxn, out_dim): # extract Reactant 1 data X = states y = steps[:, 1: (out_dim+1)] - sparse.save_npz(f'{main_dir}X_rt1_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_rt1_{dataset}.npz', y) + sparse.save_npz(main_dir / f'X_rt1_{dataset}.npz', X) + sparse.save_npz(main_dir / f'y_rt1_{dataset}.npz', y) + print(f' saved data for "Reactant 1"') return None From ed0b0b809a5bd8bc2c0cd0d2f38eb6b6f8d024d1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 18 Aug 2022 14:12:19 -0400 Subject: [PATCH 025/302] refactor: remove hard-coded paths & cosmetic changes --- scripts/compute_embedding_mp.py | 99 ++++++++++++++++++--------------- src/syn_net/config.py | 1 + 2 files changed, 54 insertions(+), 46 deletions(-) diff --git a/scripts/compute_embedding_mp.py b/scripts/compute_embedding_mp.py index 421503ec..d8a4ef7e 100644 --- a/scripts/compute_embedding_mp.py +++ b/scripts/compute_embedding_mp.py @@ -1,57 +1,64 @@ """ Computes the molecular embeddings of the purchasable building blocks. + +The embeddings are also referred to as "output embedding". +In the embedding space, a kNN-search will identify the 1st or 2nd reactant. """ +import logging import multiprocessing as mp -from scripts.compute_embedding import * -from rdkit import RDLogger -from syn_net.utils.predict_utils import mol_embedding, fp_4096, fp_2048, fp_1024, fp_512, fp_256, rdkit2d_embedding -RDLogger.DisableLog('*') +from pathlib import Path + +import numpy as np +import pandas as pd + +from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_PREPROCESS_DIR +from syn_net.utils.predict_utils import fp_256, fp_512, fp_1024, fp_2048, fp_4096, mol_embedding, rdkit2d_embedding + +logger = logging.getLogger(__file__) -if __name__ == '__main__': +FUNCTIONS = { + "gin": mol_embedding, + "fp_4096": fp_4096, + "fp_2048": fp_2048, + "fp_1024": fp_1024, + "fp_512": fp_512, + "fp_256": fp_256, + "rdkit2d": rdkit2d_embedding, +} + + +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument("--feature", type=str, default="gin", - help="Objective function to optimize") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") + parser.add_argument("--feature", type=str, default="fp_256", choices=FUNCTIONS.keys(), help="Objective function to optimize") + parser.add_argument("--ncpu", type=int, default=64, help="Number of cpus") + parser.add_argument("-rxn", "--rxn_template", type=str, default="hb", choices=["hb", "pis"], help="Choose from ['hb', 'pis']") + parser.add_argument("--input", type=str, help="Input file with SMILES strings (One per line).") args = parser.parse_args() - # define the path to which data will be saved - path = '/pool001/whgao/data/synth_net/st_hb/' - ## path = './tests/data/' ## for debugging - - # load the building blocks - data = pd.read_csv(path + 'enamine_us_matched.csv.gz', compression='gzip')['SMILES'].tolist() - ## data = pd.read_csv(path + 'building_blocks_matched.csv.gz', compression='gzip')['SMILES'].tolist() ## for debugging - print('Total data: ', len(data)) - - if args.feature == 'gin': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(mol_embedding, data) - elif args.feature == 'fp_4096': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_4096, data) - elif args.feature == 'fp_2048': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_2048, data) - elif args.feature == 'fp_1024': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_1024, data) - elif args.feature == 'fp_512': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_512, data) - elif args.feature == 'fp_256': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_256, data) - elif args.feature == 'rdkit2d': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(rdkit2d_embedding, data) - - embedding = np.array(embeddings) - - # import ipdb; ipdb.set_trace(context=9) - np.save(path + 'enamine_us_emb_' + args.feature + '.npy', embeddings) - - print('Finish!') + reaction_template_id = args.rxn_template + building_blocks_id = "enamine_us-2021-smiles" + + # Load building blocks + file = Path(DATA_PREPROCESS_DIR) / f"{reaction_template_id}-{building_blocks_id}-matched.csv.gz" + + data = pd.read_csv(file)["SMILES"].tolist() + logger.info(f"Successfully read {file}.") + logger.info(f"Total number of building blocks: {len(data)}.") + + func = FUNCTIONS[args.feature] + with mp.Pool(processes=args.ncpu) as pool: + embeddings = pool.map(func, data) + + # Save embeddings + embeddings = np.array(embeddings) + + path = Path(DATA_EMBEDDINGS_DIR) + path.mkdir(exist_ok=1, parents=1) + outfile = path / f"{reaction_template_id}-{building_blocks_id}-embeddings.npy" + + np.save(outfile, embeddings) + logger.info(f"Successfully saved to {outfile}.") diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 1438a155..2273b0af 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -9,6 +9,7 @@ # Pre-processed data DATA_PREPROCESS_DIR = "database/pre-process" +DATA_EMBEDDINGS_DIR = "database/pre-process/embeddings" # Prepared data DATA_PREPARED_DIR = "database/prepared" From d7af05f743d335ae67573731a8672af21668fb98 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 19 Aug 2022 10:59:26 -0400 Subject: [PATCH 026/302] refactor functions from main --- scripts/predict_multireactant_mp.py | 50 ++++++++++++++++++----------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index f8ee0da9..ceeeef12 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -6,7 +6,27 @@ import pandas as pd import _mp_predict_multireactant as predict from syn_net.utils.data_utils import SyntheticTreeSet +from pathlib import Path +from syn_net.config import DATA_PREPARED_DIR, DATA_RESULT_DIR +Path(DATA_RESULT_DIR).mkdir(exist_ok=True) + +def _fetch_data_chembl(name: str) -> list[str]: + raise NotImplementedError + df = pd.read_csv(f'{DATA_DIR}/chembl_20k.csv') + smis_query = df.smiles.to_list() + return smis_query + +def _fetch_data(name: str) -> list[str]: + if args.data in ["train", "valid", "test"]: + file = Path(DATA_PREPARED_DIR) / f"synthetic-trees-{args.data}.json.gz" + print(f'Reading data from {file}') + sts = SyntheticTreeSet() + sts.load(file) + smis_query = [st.root.smiles for st in sts.sts] + else: + smis_query = _fetch_data_chembl(name) + return smis_query if __name__ == '__main__': @@ -25,23 +45,15 @@ args = parser.parse_args() # load the query molecules (i.e. molecules to decode) - if args.data != 'chembl': - path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - smis_query = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - smis_query = smis_query[:args.num] - else: - df = pd.read_csv('/home/whgao/synth_net/chembl_20k.csv') - smis_query = df.smiles.to_list() + smiles_queries = _fetch_data(args.data) + + # Select only n queries + if args.num > 0: + smiles_queries = smiles_queries[:args.num] - print('Start to decode!') + print(f'Start to decode {len(smiles_queries)} target molecules.') with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(predict.func, smis_query) + results = pool.map(predict.func, smiles_queries) smis_decoded = [r[0] for r in results] similarities = [r[1] for r in results] @@ -52,15 +64,15 @@ print(f'Average similarity {args.data}: {np.mean(np.array(similarities))}') print('Saving ......') - save_path = '../results/' - df = pd.DataFrame({'query SMILES' : smis_query, + save_path = DATA_RESULT_DIR + df = pd.DataFrame({'query SMILES' : smiles_queries, 'decode SMILES': smis_decoded, 'similarity' : similarities}) - df.to_csv(f'{save_path}decode_result_{args.data}.csv.gz', + df.to_csv(f'{save_path}/decode_result_{args.data}.csv.gz', compression='gzip', index=False) synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}decoded_st_{args.data}.json.gz') + synthetic_tree_set.save(f'{save_path}/decoded_st_{args.data}.json.gz') print('Finish!') From 2c0cc201632e3b4caa9ac80465281e45e8519323 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:11:47 -0400 Subject: [PATCH 027/302] wip: clean up code & comments - use `np.atleast_2d` instead of try/excepts - use dict-lookup instead of if/elseif --- src/syn_net/utils/predict_utils.py | 80 +++++++++++++++--------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 36ab443b..c6541fcf 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -839,61 +839,59 @@ def synthetic_tree_decoder_rt1(z_target, # Initialization tree = SyntheticTree() mol_recent = None - kdtree = BallTree(bb_emb, metric=cosine_distance) - + kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class + z_target = np.atleast_2d(z_target) # Start iteration for i in range(max_step): # Encode current state - state = tree.get_state() # a set - try: - z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) - except: - z_target = np.expand_dims(z_target, axis=0) + state = tree.get_state() # a list z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) # Predict action type, masked selection # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) + action_proba = action_net(torch.Tensor(z_state)) # (1,4) action_proba = action_proba.squeeze().detach().numpy() + 1e-10 action_mask = get_action_mask(tree.get_state(), reaction_templates) act = np.argmax(action_proba * action_mask) + # Continue growing tree? + if act == 3: # End + break + z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() + z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) + # Select first molecule - if act == 3: - # End - break - elif act == 0: - # Add + if act == 0: # Add if mol_recent is not None: dist, ind = nn_search(z_mol1) mol1 = building_blocks[ind] - else: - dist, ind = nn_search_rt1(z_mol1, _tree=kdtree, _k=rt1_index+1) + else: # no recent mol + dist, ind = nn_search_rt1(z_mol1, _tree=kdtree, _k=rt1_index+1) # TODO: why is there an option to select the k-th? rt1_index (???) mol1 = building_blocks[ind[rt1_index]] - else: + elif act==1 or act==2: # Expand or Merge mol1 = mol_recent + else: + raise ValueError(f"Unexpected action {act}.") - # z_mol1 = get_mol_embedding(mol1, mol_embedder) - z_mol1 = mol_fp(mol1) + z_mol1 = mol_fp(mol1) # (dimension_input_embedding=d), default (4096,) + z_mol1 = np.atleast_2d(z_mol1) # (1,4096) # Select reaction - try: - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - except: - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 + z = np.concatenate([z_state, z_mol1], axis=1) # (1,4d) + reaction_proba = rxn_net(torch.Tensor(z)) + reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate) - if act != 2: + if act != 2: # add or expand reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: + else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] + available_list = [[] for rxn in reaction_templates] # TODO: if act=merge, this is not used at all + # If we ended up in a state where no reaction is possible, + # end this iteration. if reaction_mask is None: if len(state) == 1: act = 3 @@ -901,28 +899,30 @@ def synthetic_tree_decoder_rt1(z_target, else: break + # Select reaction template rxn_id = np.argmax(reaction_proba * reaction_mask) rxn = reaction_templates[rxn_id] + NUMBER_OF_REACTION_TEMPLATES = { + "hb": 91, + "pis": 4700, + "unittest": 3, + } # TODO: Refactor / use class + + # Select 2nd reactant if rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge + if act == 2: # Merge temp = set(state) - set([mol1]) mol2 = temp.pop() - else: - # Add or Expand - if rxn_template == 'hb': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 91)], axis=1))) - elif rxn_template == 'pis': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 4700)], axis=1))) - elif rxn_template == 'unittest': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 3)], axis=1))) + else: # Add or Expand + x_rxn = one_hot_encoder(rxn_id,NUMBER_OF_REACTION_TEMPLATES[rxn_template]) + x_rct2 = np.concatenate([z_state,z_mol1, x_rxn],axis=1) + z_mol2 = reactant2_net(torch.Tensor(x_rct2)) z_mol2 = z_mol2.detach().numpy() available = available_list[rxn_id] available = [bb_dict[available[i]] for i in range(len(available))] temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) + available_tree = BallTree(temp_emb, metric=cosine_distance) # TODO: evaluate if distance matrix is faster/feasible as this BallTree is discarded immediately. dist, ind = nn_search(z_mol2, _tree=available_tree) mol2 = building_blocks[available[ind]] else: From 57c19e8c0f399850d657a3e4e6a82e59cbbad69c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:25:53 -0400 Subject: [PATCH 028/302] delete redundant code --- src/syn_net/models/mlp.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 4db3574f..600b9c4c 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -130,16 +130,7 @@ def nn_search(_e, _tree, _k=1): return ind[0][0] def nn_search_list(y, out_feat, kdtree): - if out_feat == 'gin': return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - elif out_feat == 'fp_4096': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - elif out_feat == 'fp_256': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - elif out_feat == 'rdkit2d': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - else: - raise ValueError if __name__ == '__main__': From 142051b765d966a00aa2a1e3fbec8c0ea91b2e06 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:42:22 -0400 Subject: [PATCH 029/302] use `lru_cache` to fetch pretrained gin --- src/syn_net/utils/predict_utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index c6541fcf..50371d39 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -17,16 +17,20 @@ from tdc.chem_utils import MolConvert from syn_net.models.mlp import MLP from syn_net.utils.data_utils import SyntheticTree - +import functools # create a random seed for NumPy np.random.seed(6) -# get a GIN pretrained model to use for creating molecular embeddings -model_type = 'gin_supervised_contextpred' + +@functools.lru_cache(1) +def _fetch_gin_pretrained_model(name: str): + """Get a GIN pretrained model to use for creating molecular embeddings""" + # name = 'gin_supervised_contextpred' device = 'cpu' -gin_pretrained_model = load_pretrained(model_type).to(device) # used to learn embedding + gin_pretrained_model = load_pretrained(name).to(device) # used to learn embedding gin_pretrained_model.eval() + return gin_pretrained_model # general functions @@ -184,6 +188,8 @@ def mol_embedding(smi, device='cpu', readout=AvgPooling()): Returns: np.ndarray: Either a zeros array or the graph embedding. """ + name = 'gin_supervised_contextpred' + gin_pretrained_model = _fetch_gin_pretrained_model(name) # get the embedding if smi is None: From 7dd09b8cb5e819b7c903aa62eb009b6993fe9810 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:50:27 -0400 Subject: [PATCH 030/302] wip: align fct signature + body (duplicate code; needs refactor) --- src/syn_net/utils/predict_utils.py | 9 ++++----- src/syn_net/utils/prep_utils.py | 14 ++++++++++---- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 50371d39..99b5b194 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -24,13 +24,12 @@ @functools.lru_cache(1) -def _fetch_gin_pretrained_model(name: str): +def _fetch_gin_pretrained_model(model_name: str): """Get a GIN pretrained model to use for creating molecular embeddings""" - # name = 'gin_supervised_contextpred' device = 'cpu' - gin_pretrained_model = load_pretrained(name).to(device) # used to learn embedding -gin_pretrained_model.eval() - return gin_pretrained_model + model = load_pretrained(model_name).to(device) # used to learn embedding + model.eval() + return model # general functions diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index fe4fde82..f7e6bca3 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -29,6 +29,15 @@ def rdkit2d_embedding(smi): rdkit2d = MolConvert(src = 'SMILES', dst = 'RDKit2D') return rdkit2d(smi).reshape(-1, ) +import functools +@functools.lru_cache(maxsize=1) +def _fetch_gin_pretrained_model(model_name: str): + """Get a GIN pretrained model to use for creating molecular embeddings""" + device = 'cpu' + model = load_pretrained(model_name).to(device) + model.eval() + return model + def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, output_embedding='gin'): @@ -56,10 +65,7 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, sparse.csc_matrix: Actions pulled from the tree. """ # define model to use for molecular embedding - model_type = 'gin_supervised_contextpred' - device = 'cpu' - model = load_pretrained(model_type).to(device) - model.eval() + model = _fetch_gin_pretrained_model("gin_supervised_contextpred") states = [] steps = [] From 07a84635dd07f55f72e085c190c72c8fff6f2985 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 19 Aug 2022 15:52:10 -0400 Subject: [PATCH 031/302] wip: avoid hard-coded paths, slight code clean up --- scripts/_mp_predict_multireactant.py | 33 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/scripts/_mp_predict_multireactant.py b/scripts/_mp_predict_multireactant.py index e3ab8f13..9b3c32cd 100644 --- a/scripts/_mp_predict_multireactant.py +++ b/scripts/_mp_predict_multireactant.py @@ -5,37 +5,38 @@ import numpy as np from syn_net.utils.data_utils import ReactionSet from syn_net.utils.predict_utils import synthetic_tree_decoder_multireactant, load_modules_from_checkpoint, mol_fp - +from pathlib import Path +from syn_net.config import DATA_PREPROCESS_DIR, DATA_EMBEDDINGS_DIR, CHECKPOINTS_DIR # define some constants (here, for the Hartenfeller-Button test set) nbits = 4096 -out_dim = 256 +out_dim = 256 # <=> morgan fingerprint with 256 bits rxn_template = 'hb' +building_blocks_id = "enamine_us-2021-smiles" featurize = 'fp' param_dir = 'hb_fp_2_4096_256' ncpu = 1 # load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') +file = Path(DATA_EMBEDDINGS_DIR) / f"{rxn_template}-{building_blocks_id}-embeddings.npy" +bb_emb = np.load(file) -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz' -path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz' # define paths to pretrained modules -param_path = f'/home/whgao/synth_net/synth_net/params/{param_dir}/' -path_to_act = f'{param_path}act.ckpt' -path_to_rt1 = f'{param_path}rt1.ckpt' -path_to_rxn = f'{param_path}rxn.ckpt' -path_to_rt2 = f'{param_path}rt2.ckpt' +path_to_act = Path(CHECKPOINTS_DIR) / f"{param_dir}/act.ckpt" +path_to_rt1 = Path(CHECKPOINTS_DIR) / f"{param_dir}/rt1.ckpt" +path_to_rxn = Path(CHECKPOINTS_DIR) / f"{param_dir}/rxn.ckpt" +path_to_rt2 = Path(CHECKPOINTS_DIR) / f"{param_dir}/rt2.ckpt" -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} +# Load building blocks +building_blocks_file = Path(DATA_PREPROCESS_DIR) / f"{rxn_template}-{building_blocks_id}-matched.csv.gz" +building_blocks = pd.read_csv(building_blocks_file, compression='gzip')['SMILES'].tolist() +bb_dict = {block: i for i,block in enumerate(building_blocks)} # dict is useful as lookup table for 2nd reactant during inference -# load the reaction templates as a ReactionSet object +# Load reaction templates +reaction_file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{rxn_template}_{building_blocks_id}.json.gz" rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) +rxn_set.load(reaction_file) rxns = rxn_set.rxns # load the pre-trained modules From 3f44da4df39aaa6fd2b1c23e8272814702886642 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 11:05:58 -0400 Subject: [PATCH 032/302] refactor: consolidate `predict_multireactant_mp.py` - move into a single function - remove hard-coded paths - move code into functions --- scripts/_mp_predict_multireactant.py | 92 --------------- scripts/predict_multireactant_mp.py | 161 +++++++++++++++++++++++---- src/syn_net/utils/predict_utils.py | 10 +- 3 files changed, 146 insertions(+), 117 deletions(-) diff --git a/scripts/_mp_predict_multireactant.py b/scripts/_mp_predict_multireactant.py index 9b3c32cd..e69de29b 100644 --- a/scripts/_mp_predict_multireactant.py +++ b/scripts/_mp_predict_multireactant.py @@ -1,92 +0,0 @@ -""" -This file contains a function to decode a single synthetic tree. -""" -import pandas as pd -import numpy as np -from syn_net.utils.data_utils import ReactionSet -from syn_net.utils.predict_utils import synthetic_tree_decoder_multireactant, load_modules_from_checkpoint, mol_fp -from pathlib import Path -from syn_net.config import DATA_PREPROCESS_DIR, DATA_EMBEDDINGS_DIR, CHECKPOINTS_DIR - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 # <=> morgan fingerprint with 256 bits -rxn_template = 'hb' -building_blocks_id = "enamine_us-2021-smiles" -featurize = 'fp' -param_dir = 'hb_fp_2_4096_256' -ncpu = 1 - -# load the purchasable building block embeddings -file = Path(DATA_EMBEDDINGS_DIR) / f"{rxn_template}-{building_blocks_id}-embeddings.npy" -bb_emb = np.load(file) - - -# define paths to pretrained modules -path_to_act = Path(CHECKPOINTS_DIR) / f"{param_dir}/act.ckpt" -path_to_rt1 = Path(CHECKPOINTS_DIR) / f"{param_dir}/rt1.ckpt" -path_to_rxn = Path(CHECKPOINTS_DIR) / f"{param_dir}/rxn.ckpt" -path_to_rt2 = Path(CHECKPOINTS_DIR) / f"{param_dir}/rt2.ckpt" - -# Load building blocks -building_blocks_file = Path(DATA_PREPROCESS_DIR) / f"{rxn_template}-{building_blocks_id}-matched.csv.gz" -building_blocks = pd.read_csv(building_blocks_file, compression='gzip')['SMILES'].tolist() -bb_dict = {block: i for i,block in enumerate(building_blocks)} # dict is useful as lookup table for 2nd reactant during inference - -# Load reaction templates -reaction_file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{rxn_template}_{building_blocks_id}.json.gz" -rxn_set = ReactionSet() -rxn_set.load(reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(smi): - """ - Generates the synthetic tree for the input molecular embedding. - - Args: - smi (str): SMILES string corresponding to the molecule to decode. - - Returns: - smi (str): SMILES for the final chemical node in the tree. - similarity (float): Similarity measure between the final chemical node - and the input molecule. - tree (SyntheticTree): The generated synthetic tree. - """ - emb = mol_fp(smi) - try: - smi, similarity, tree, action = synthetic_tree_decoder_multireactant( - z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_fp, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=nbits, - beam_width=3, - max_step=15) - except Exception as e: - print(e) - action = -1 - - if action != 3: - return None, 0, None - else: - return smi, similarity, tree diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index ceeeef12..1df15d2a 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -2,12 +2,17 @@ Generate synthetic trees for a set of specified query molecules. Multiprocessing. """ import multiprocessing as mp +from pathlib import Path + import numpy as np import pandas as pd -import _mp_predict_multireactant as predict -from syn_net.utils.data_utils import SyntheticTreeSet -from pathlib import Path -from syn_net.config import DATA_PREPARED_DIR, DATA_RESULT_DIR + +from syn_net.config import (CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, + DATA_PREPARED_DIR, DATA_PREPROCESS_DIR, + DATA_RESULT_DIR) +from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet +from syn_net.utils.predict_utils import (load_modules_from_checkpoint, mol_fp, + synthetic_tree_decoder_multireactant) Path(DATA_RESULT_DIR).mkdir(exist_ok=True) @@ -28,51 +33,167 @@ def _fetch_data(name: str) -> list[str]: smis_query = _fetch_data_chembl(name) return smis_query +def _fetch_reaction_templates(file: str): + # Load reaction templates + rxn_set = ReactionSet() + rxn_set.load(file) + return rxn_set.rxns + +def _fetch_building_blocks_embeddings(file: str): + """Load the purchasable building block embeddings.""" + return np.load(file) + +def _fetch_building_blocks(file: str): + """Load the building blocks""" + return pd.read_csv(file, compression='gzip')['SMILES'].tolist() + +def _load_pretrained_model(path_to_checkpoints: str): + """Wrapper to load modules from checkpoint.""" + # Define paths to pretrained models. + path_to_act = Path(path_to_checkpoints) / "act.ckpt" + path_to_rt1 = Path(path_to_checkpoints) / "rt1.ckpt" + path_to_rxn = Path(path_to_checkpoints) / "rxn.ckpt" + path_to_rt2 = Path(path_to_checkpoints) / "rt2.ckpt" + + # Load the pre-trained models. + act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( + path_to_act=path_to_act, + path_to_rt1=path_to_rt1, + path_to_rxn=path_to_rxn, + path_to_rt2=path_to_rt2, + featurize=featurize, + rxn_template=rxn_template, + out_dim=out_dim, + nbits=nbits, + ncpu=ncpu, + ) + return act_net, rt1_net, rxn_net, rt2_net + +def func(smiles: str): + """ + Generates the synthetic tree for the input molecular embedding. + + Args: + smi (str): SMILES string corresponding to the molecule to decode. + + Returns: + smi (str): SMILES for the final chemical node in the tree. + similarity (float): Similarity measure between the final chemical node + and the input molecule. + tree (SyntheticTree): The generated synthetic tree. + """ + emb = mol_fp(smiles) + try: + smi, similarity, tree, action = synthetic_tree_decoder_multireactant( + z_target=emb, + building_blocks=building_blocks, + bb_dict=building_blocks_dict, + reaction_templates=rxns, + mol_embedder=mol_fp, + action_net=act_net, + reactant1_net=rt1_net, + rxn_net=rxn_net, + reactant2_net=rt2_net, + bb_emb=bb_emb, + rxn_template=rxn_template, + n_bits=nbits, + beam_width=3, + max_step=15) + except Exception as e: + print(e) + action = -1 + + if action != 3: # aka tree has not been properly ended + smi = None + similarity = 0 + tree = None + + return smi, similarity, tree + + if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument("-f", "--featurize", type=str, default='fp', help="Choose from ['fp', 'gin']") + parser.add_argument("--radius", type=int, default=2, + help="Radius for Morgan Fingerprint") + parser.add_argument("-b", "--nbits", type=int, default=4096, + help="Number of Bits for Morgan Fingerprint") parser.add_argument("-r", "--rxn_template", type=str, default='hb', help="Choose from ['hb', 'pis']") - parser.add_argument("--ncpu", type=int, default=16, + parser.add_argument("--ncpu", type=int, default=1, help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=-1, + parser.add_argument("-n", "--num", type=int, default=1, help="Number of molecules to predict.") parser.add_argument("-d", "--data", type=str, default='test', help="Choose from ['train', 'valid', 'test', 'chembl']") + parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', + help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") args = parser.parse_args() - # load the query molecules (i.e. molecules to decode) - smiles_queries = _fetch_data(args.data) + nbits = args.nbits + out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits + rxn_template = args.rxn_template + building_blocks_id = "enamine_us-2021-smiles" + featurize = args.featurize + radius = args.radius + ncpu = args.ncpu + param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" - # Select only n queries - if args.num > 0: + # Load data ... + # ... query molecules (i.e. molecules to decode) + smiles_queries = _fetch_data(args.data) + if args.num > 0: # Select only n queries smiles_queries = smiles_queries[:args.num] + # ... building blocks + file = Path(DATA_PREPROCESS_DIR) / f"{rxn_template}-{building_blocks_id}-matched.csv.gz" + building_blocks = _fetch_building_blocks(file) + building_blocks_dict = {block: i for i,block in enumerate(building_blocks)} # dict is used as lookup table for 2nd reactant during inference + + # ... reaction templates + file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{rxn_template}_{building_blocks_id}.json.gz" + rxns = _fetch_reaction_templates(file) + + # ... building blocks + file = Path(DATA_EMBEDDINGS_DIR) / f"{rxn_template}-{building_blocks_id}-embeddings.npy" + bb_emb = _fetch_building_blocks_embeddings(file) + + # ... models + path = Path(CHECKPOINTS_DIR) / f"{param_dir}" + act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(path) + + + # Decode queries, i.e. the target molecules. print(f'Start to decode {len(smiles_queries)} target molecules.') with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(predict.func, smiles_queries) + results = pool.map(func, smiles_queries) + print('Finished decoding.') + + # Print some results from the prediction smis_decoded = [r[0] for r in results] - similarities = [r[1] for r in results] + similarities = [r[1] for r in results] trees = [r[2] for r in results] - print('Finish decoding') - print(f'Recovery rate {args.data}: {np.sum(np.array(similarities) == 1.0) / len(similarities)}') - print(f'Average similarity {args.data}: {np.mean(np.array(similarities))}') + recovery_rate = (np.asfarray(similarities)==1.0).sum()/len(similarities) + avg_similarity = np.mean(similarities) + print(f"For {args.data}:") + print(f" {recovery_rate=}") + print(f" {avg_similarity=}") - print('Saving ......') - save_path = DATA_RESULT_DIR + # Save to local dir + print('Saving results to {DATA_RESULT_DIR} ...') df = pd.DataFrame({'query SMILES' : smiles_queries, 'decode SMILES': smis_decoded, 'similarity' : similarities}) - df.to_csv(f'{save_path}/decode_result_{args.data}.csv.gz', + df.to_csv(f'{DATA_RESULT_DIR}/decode_result_{args.data}.csv.gz', compression='gzip', - index=False) + index=False,) synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}/decoded_st_{args.data}.json.gz') + synthetic_tree_set.save(f'{DATA_RESULT_DIR}/decoded_st_{args.data}.json.gz') print('Finish!') diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 99b5b194..dbf3c937 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -26,7 +26,7 @@ @functools.lru_cache(1) def _fetch_gin_pretrained_model(model_name: str): """Get a GIN pretrained model to use for creating molecular embeddings""" -device = 'cpu' + device = 'cpu' model = load_pretrained(model_name).to(device) # used to learn embedding model.eval() return model @@ -594,7 +594,7 @@ def load_modules_from_checkpoint(path_to_act, path_to_rt1, path_to_rxn, path_to_ rt1_net = MLP.load_from_checkpoint(path_to_rt1, input_dim=int(3 * nbits), - output_dim=out_dim, + output_dim=int(out_dim), hidden_dim=1200, num_layers=5, dropout=0.5, @@ -624,7 +624,7 @@ def load_modules_from_checkpoint(path_to_act, path_to_rt1, path_to_rxn, path_to_ rt2_net = MLP.load_from_checkpoint(path_to_rt2, input_dim=int(4 * nbits + 91), - output_dim=out_dim, + output_dim=int(out_dim), hidden_dim=3000, num_layers=5, dropout=0.5, @@ -835,7 +835,7 @@ def synthetic_tree_decoder_rt1(z_target, synthetic tree rt1_index (int, optional): Index for molecule in the building blocks corresponding to reactant 1. - + Returns: tree (SyntheticTree): The final synthetic tree act (int): The final action (to know if the tree was "properly" @@ -850,7 +850,7 @@ def synthetic_tree_decoder_rt1(z_target, for i in range(max_step): # Encode current state state = tree.get_state() # a list - z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) + z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) # Predict action type, masked selection # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) From 4efc530491cf587bd6fb3ff1d3c68ee8712bf374 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 13:30:00 -0400 Subject: [PATCH 033/302] return `self` upon loading data in `SyntheticTreeSet` --- scripts/predict-beam-fullTree.py | 3 +-- scripts/predict.py | 3 +-- scripts/predict_mp.py | 3 +-- src/syn_net/utils/data_utils.py | 4 ++++ tests/test_DataPreparation.py | 6 ++---- tests/test_Predict.py | 3 +-- 6 files changed, 10 insertions(+), 12 deletions(-) diff --git a/scripts/predict-beam-fullTree.py b/scripts/predict-beam-fullTree.py index 7bf001c8..b25d974d 100644 --- a/scripts/predict-beam-fullTree.py +++ b/scripts/predict-beam-fullTree.py @@ -124,8 +124,7 @@ def decode_one_molecule(query_smi): # load the purchasable building blocks to decode path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) + sts = SyntheticTreeSet().load(path_to_data) query_smis = [st.root.smiles for st in sts.sts] if args.num == -1: pass diff --git a/scripts/predict.py b/scripts/predict.py index 2c40c83e..8cefe777 100644 --- a/scripts/predict.py +++ b/scripts/predict.py @@ -122,8 +122,7 @@ def decode_one_molecule(query_smi): path_to_data = '/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/st_' + args.data +'.json.gz' print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) + sts = SyntheticTreeSet().load(path_to_data) query_smis = [st.root.smiles for st in sts.sts] if args.num == -1: pass diff --git a/scripts/predict_mp.py b/scripts/predict_mp.py index fe521d48..17e3284c 100644 --- a/scripts/predict_mp.py +++ b/scripts/predict_mp.py @@ -27,8 +27,7 @@ # load the query molecules (i.e. molecules to decode) path_to_data = '/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/st_' + args.data +'.json.gz' print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) + sts = SyntheticTreeSet().load(path_to_data) smis_query = [st.root.smiles for st in sts.sts] if args.num == -1: pass diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 7d055c34..2bc4a51d 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -764,6 +764,9 @@ def __init__(self, sts=None): else: self.sts = sts + def __len__(self): + return len(self.sts) + def load(self, json_file): """ A function that loads a JSON-formatted synthetic tree file. @@ -780,6 +783,7 @@ def load(self, json_file): else: st = SyntheticTree(st_dict) self.sts.append(st) + return self def save(self, json_file): """ diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index f3e40993..c22a1535 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -107,8 +107,7 @@ def test_synthetic_tree_prep(self): # check here that the synthetic trees were correctly saved by # comparing to a provided reference file in 'SynNet/tests/data/ref/' - sts_ref = SyntheticTreeSet() - sts_ref.load(f"{TEST_DIR}/data/ref/st_data.json.gz") + sts_ref = SyntheticTreeSet().load(f"{TEST_DIR}/data/ref/st_data.json.gz") for st_idx, st in enumerate(sts_ref.sts): st = st.__dict__ ref_st = sts_ref.sts[st_idx].__dict__ @@ -128,8 +127,7 @@ def test_featurization(self): save_dir = f"{TEST_DIR}/data/" reference_data_dir = f"{TEST_DIR}/data/ref/" - st_set = SyntheticTreeSet() - st_set.load(path_st) + st_set = SyntheticTreeSet().load(path_st) data = st_set.sts del st_set diff --git a/tests/test_Predict.py b/tests/test_Predict.py index 6ca8856a..367386a8 100644 --- a/tests/test_Predict.py +++ b/tests/test_Predict.py @@ -76,8 +76,7 @@ def test_predict(self): # load the query molecules (i.e. molecules to decode) path_to_data = f"{ref_dir}st_data.json.gz" - sts = SyntheticTreeSet() - sts.load(path_to_data) + sts = SyntheticTreeSet().load(path_to_data) smis_query = [st.root.smiles for st in sts.sts] # start to decode the query molecules (no multiprocessing for the unit tests here) From 729cd80c62ff9bee56a51bbaeeed7b12efe4947b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 13:31:57 -0400 Subject: [PATCH 034/302] return `self` upon loading data in `ReactionSet` --- scripts/_mp_decode.py | 3 +-- scripts/_mp_predict.py | 3 +-- scripts/_mp_predict_beam.py | 3 +-- scripts/predict-beam-fullTree.py | 3 +-- scripts/predict-beam-reactantOnly.py | 3 +-- scripts/predict.py | 3 +-- scripts/predict_multireactant_mp.py | 3 +-- src/syn_net/data_generation/filter_unmatch.py | 3 +-- src/syn_net/data_generation/make_dataset_mp.py | 3 +-- src/syn_net/utils/data_utils.py | 1 + tests/filter_unmatch_tests.py | 3 +-- tests/test_DataPreparation.py | 3 +-- tests/test_Predict.py | 3 +-- 13 files changed, 13 insertions(+), 24 deletions(-) diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py index 2bb1029f..e5e833c2 100644 --- a/scripts/_mp_decode.py +++ b/scripts/_mp_decode.py @@ -41,8 +41,7 @@ bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) +rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules diff --git a/scripts/_mp_predict.py b/scripts/_mp_predict.py index 3ef1a7f1..10884b6c 100644 --- a/scripts/_mp_predict.py +++ b/scripts/_mp_predict.py @@ -40,8 +40,7 @@ bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) +rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules diff --git a/scripts/_mp_predict_beam.py b/scripts/_mp_predict_beam.py index e415079f..4beaedfe 100644 --- a/scripts/_mp_predict_beam.py +++ b/scripts/_mp_predict_beam.py @@ -42,8 +42,7 @@ bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) +rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules diff --git a/scripts/predict-beam-fullTree.py b/scripts/predict-beam-fullTree.py index b25d974d..c2356d2c 100644 --- a/scripts/predict-beam-fullTree.py +++ b/scripts/predict-beam-fullTree.py @@ -73,8 +73,7 @@ bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) + rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules diff --git a/scripts/predict-beam-reactantOnly.py b/scripts/predict-beam-reactantOnly.py index daf9b3c0..a3dfa78c 100644 --- a/scripts/predict-beam-reactantOnly.py +++ b/scripts/predict-beam-reactantOnly.py @@ -75,8 +75,7 @@ bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) + rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules diff --git a/scripts/predict.py b/scripts/predict.py index 8cefe777..e2885f97 100644 --- a/scripts/predict.py +++ b/scripts/predict.py @@ -72,8 +72,7 @@ bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) + rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 1df15d2a..b724cf1b 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -35,8 +35,7 @@ def _fetch_data(name: str) -> list[str]: def _fetch_reaction_templates(file: str): # Load reaction templates - rxn_set = ReactionSet() - rxn_set.load(file) + rxn_set = ReactionSet().load(file) return rxn_set.rxns def _fetch_building_blocks_embeddings(file: str): diff --git a/src/syn_net/data_generation/filter_unmatch.py b/src/syn_net/data_generation/filter_unmatch.py index 3dd2b939..c8272589 100644 --- a/src/syn_net/data_generation/filter_unmatch.py +++ b/src/syn_net/data_generation/filter_unmatch.py @@ -23,8 +23,7 @@ # Load genearted reactions (matched reactions <=> building blocks) reactions_dir = Path(DATA_PREPROCESS_DIR) reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" - r_set = ReactionSet() - r_set.load(reactions_dir / reactions_file) + r_set = ReactionSet().load(reactions_dir / reactions_file) # Identify all used building blocks (via union of sets) matched_bblocks = set() diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py index 3c84737d..92cf3571 100644 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ b/src/syn_net/data_generation/make_dataset_mp.py @@ -36,8 +36,7 @@ def func(_x): # Load genearted reactions (matched reactions <=> building blocks) reactions_dir = Path(DATA_PREPROCESS_DIR) reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" - r_set = ReactionSet() - r_set.load(reactions_dir / reactions_file) + r_set = ReactionSet().load(reactions_dir / reactions_file) rxns = r_set.rxns # Generate synthetic trees diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 2bc4a51d..8467eb55 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -396,6 +396,7 @@ def load(self, json_file): r = Reaction() r.load(**r_dict) self.rxns.append(r) + return self def save(self, json_file): """ diff --git a/tests/filter_unmatch_tests.py b/tests/filter_unmatch_tests.py index d4c6fe67..e77f18db 100644 --- a/tests/filter_unmatch_tests.py +++ b/tests/filter_unmatch_tests.py @@ -10,8 +10,7 @@ if __name__ == '__main__': r_path = './data/ref/rxns_hb.json.gz' bb_path = '/home/whgao/scGen/synth_net/data/enamine_us.csv.gz' - r_set = ReactionSet() - r_set.load(r_path) + r_set = ReactionSet().load(r_path) matched_mols = set() for r in tqdm(r_set.rxns): for a_list in r.available_reactants: diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index c22a1535..bdfd92b9 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -53,8 +53,7 @@ def test_process_rxn_templates(self): # load the reference reaction templates path_to_ref_rxn_templates = f"{TEST_DIR}/data/ref/rxns_hb.json.gz" - r_ref = ReactionSet() - r_ref.load(path_to_ref_rxn_templates) + r_ref = ReactionSet().load(path_to_ref_rxn_templates) # check here that the templates were correctly saved as a ReactionSet by # comparing to a provided reference file in 'SynNet/tests/data/ref/' diff --git a/tests/test_Predict.py b/tests/test_Predict.py index 367386a8..ada21154 100644 --- a/tests/test_Predict.py +++ b/tests/test_Predict.py @@ -57,8 +57,7 @@ def test_predict(self): bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) + rxn_set = ReactionSet().load(path_to_reaction_file) rxns = rxn_set.rxns # load the pre-trained modules From b0a77ef6cb575b2c3a34a100de22a3454e18bb93 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 13:37:40 -0400 Subject: [PATCH 035/302] fixes unit test: - change hard coded file name to reflect change in 7254d7d - fix indentation that would prematurely end for loop and skip valid + test data --- src/syn_net/utils/prep_utils.py | 7 +++++-- .../ref/{states_0_train.npz => states_train.npz} | Bin .../data/ref/{steps_0_train.npz => steps_train.npz} | Bin tests/test_DataPreparation.py | 6 +++--- 4 files changed, 8 insertions(+), 5 deletions(-) rename tests/data/ref/{states_0_train.npz => states_train.npz} (100%) rename tests/data/ref/{steps_0_train.npz => steps_train.npz} (100%) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index f7e6bca3..978d3f2e 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -240,7 +240,7 @@ def synthetic_tree_generator(building_blocks, reaction_templates, max_step=15): return tree, action -def prep_data(main_dir, num_rxn, out_dim): +def prep_data(main_dir, num_rxn, out_dim, datasets=None): """ Loads the states and steps from preprocessed *.npz files and saves data specific to the Action, Reactant 1, Reaction, and Reactant 2 networks in @@ -251,8 +251,11 @@ def prep_data(main_dir, num_rxn, out_dim): num_rxn (int): Number of reactions in the dataset. out_dim (int): Size of the output feature vectors. """ + if datasets is None: + datasets = ['train', 'valid', 'test'] main_dir = Path(main_dir) - for dataset in ['train', 'valid', 'test']: + + for dataset in datasets: print(f'Reading {dataset} data ...') states_list = [] diff --git a/tests/data/ref/states_0_train.npz b/tests/data/ref/states_train.npz similarity index 100% rename from tests/data/ref/states_0_train.npz rename to tests/data/ref/states_train.npz diff --git a/tests/data/ref/steps_0_train.npz b/tests/data/ref/steps_train.npz similarity index 100% rename from tests/data/ref/steps_0_train.npz rename to tests/data/ref/steps_train.npz diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index bdfd92b9..f58eeafb 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -183,8 +183,8 @@ def test_dataprep(self): main_dir = f"{TEST_DIR}/data/" ref_dir = f"{TEST_DIR}/data/ref/" # copy data from the reference directory to use for this particular test - copyfile(f"{ref_dir}states_0_train.npz", f"{main_dir}states_0_train.npz") - copyfile(f"{ref_dir}steps_0_train.npz", f"{main_dir}steps_0_train.npz") + copyfile(f"{ref_dir}states_train.npz", f"{main_dir}states_train.npz") + copyfile(f"{ref_dir}steps_train.npz", f"{main_dir}steps_train.npz") # the lines below will save Action-, Reactant 1-, Reaction-, and Reactant 2- # specific files directly to the 'SynNet/tests/data/' directory (e.g. @@ -192,7 +192,7 @@ def test_dataprep(self): # 'X_rt1_{train/test/valid}.npz' and 'y_rt1_{train/test/valid}.npz' # 'X_rxn_{train/test/valid}.npz' and 'y_rxn_{train/test/valid}.npz' # 'X_rt2_{train/test/valid}.npz' and 'y_rt2_{train/test/valid}.npz' - prep_data(main_dir=main_dir, num_rxn=3, out_dim=300) + prep_data(main_dir=main_dir, num_rxn=3, out_dim=300,datasets=["train"]) # check that the saved files match the reference files in # 'SynNet/tests/data/ref': From 115ebbe9998fc7f2fea0351abdfe8bf8c3694881 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 14:37:52 -0400 Subject: [PATCH 036/302] fix unittest `test_reactant1_network` & `test_reactant2_network` - loads pre-computed building-block embeddings during validation step --- src/syn_net/models/mlp.py | 44 ++++++++++++++++++++++++++------------- tests/test_Training.py | 4 ++-- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 600b9c4c..2ad4d52e 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -72,21 +72,35 @@ def training_step(self, batch, batch_idx): self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss + def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: + """Helper function to load the pre-computed building block embeddings + as a BallTree. + + TODO: Remove hard-coded paths. + """ + if out_feat == 'gin': + bb_emb_gin = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_gin.npy') + kdtree = BallTree(bb_emb_gin, metric='euclidean') + elif out_feat == 'fp_4096': + bb_emb_fp_4096 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_4096.npy') + kdtree = BallTree(bb_emb_fp_4096, metric='euclidean') + elif out_feat == 'fp_256': + bb_emb_fp_256 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') + kdtree = BallTree(bb_emb_fp_256, metric=cosine_distance) + elif out_feat == 'rdkit2d': + bb_emb_rdkit2d = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_rdkit2d.npy') + kdtree = BallTree(bb_emb_rdkit2d, metric='euclidean') + elif out_feat == "gin_unittest": + # The embeddings are pre-computed based on the building blocks + # under 'tests/assets/building_blocks_matched.csv.gz'. + emb = np.load("tests/data/building_blocks_emb.npy") + kdtree = BallTree(emb,metric="euclidean") + else: + raise ValueError + return kdtree + def validation_step(self, batch, batch_idx): if self.trainer.current_epoch % self.val_freq == 0: - out_feat = self.valid_loss[12:] - if out_feat == 'gin': - bb_emb_gin = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_gin.npy') - kdtree = BallTree(bb_emb_gin, metric='euclidean') - elif out_feat == 'fp_4096': - bb_emb_fp_4096 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_4096.npy') - kdtree = BallTree(bb_emb_fp_4096, metric='euclidean') - elif out_feat == 'fp_256': - bb_emb_fp_256 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - kdtree = BallTree(bb_emb_fp_256, metric=cosine_distance) - elif out_feat == 'rdkit2d': - bb_emb_rdkit2d = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_rdkit2d.npy') - kdtree = BallTree(bb_emb_rdkit2d, metric='euclidean') x, y = batch y_hat = self.layers(x) if self.valid_loss == 'cross_entropy': @@ -95,6 +109,8 @@ def validation_step(self, batch, batch_idx): y_hat = torch.argmax(y_hat, axis=1) loss = 1 - (sum(y_hat == y) / len(y)) elif self.valid_loss[:11] == 'nn_accuracy': + out_feat = self.valid_loss[12:] + kdtree = self._load_building_blocks_kdtree(out_feat) y = nn_search_list(y.detach().cpu().numpy(), out_feat=out_feat, kdtree=kdtree) y_hat = nn_search_list(y_hat.detach().cpu().numpy(), out_feat=out_feat, kdtree=kdtree) loss = 1 - (sum(y_hat == y) / len(y)) @@ -130,7 +146,7 @@ def nn_search(_e, _tree, _k=1): return ind[0][0] def nn_search_list(y, out_feat, kdtree): - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) + return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) if __name__ == '__main__': diff --git a/tests/test_Training.py b/tests/test_Training.py index f2e3b220..2db0b90b 100644 --- a/tests/test_Training.py +++ b/tests/test_Training.py @@ -102,7 +102,7 @@ def test_reactant1_network(self): batch_size = 10 epochs = 2 ncpu = 2 - validation_option = "nn_accuracy_gin" + validation_option = "nn_accuracy_gin_unittest" ref_dir = f"{TEST_DIR}/data/ref/" # load the reaction data @@ -220,7 +220,7 @@ def test_reactant2_network(self): epochs = 2 ncpu = 2 n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt' - validation_option = "nn_accuracy_gin" + validation_option = "nn_accuracy_gin_unittest" ref_dir = f"{TEST_DIR}/data/ref/" X = sparse.load_npz(ref_dir + "X_rt2_train.npz") From 09535c5f6c1607930ca1614f860c6caa28c33c1c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 14:38:11 -0400 Subject: [PATCH 037/302] fix: change hardcoded paths in `test_featurization` due to change in 7254d7d --- tests/test_DataPreparation.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index f58eeafb..64ac7461 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -133,7 +133,6 @@ def test_featurization(self): states = [] steps = [] - save_idx = 0 for st in tqdm(data): try: state, step = organize( @@ -155,15 +154,15 @@ def test_featurization(self): if not os.path.exists(save_dir): os.makedirs(save_dir) - sparse.save_npz(f"{save_dir}states_{save_idx}_{dataset_type}.npz", states) - sparse.save_npz(f"{save_dir}steps_{save_idx}_{dataset_type}.npz", steps) + sparse.save_npz(f"{save_dir}states_{dataset_type}.npz", states) + sparse.save_npz(f"{save_dir}steps_{dataset_type}.npz", steps) # load the reference data, which we will compare against states_ref = sparse.load_npz( - f"{reference_data_dir}states_{save_idx}_{dataset_type}.npz" + f"{reference_data_dir}states_{dataset_type}.npz" ) steps_ref = sparse.load_npz( - f"{reference_data_dir}steps_{save_idx}_{dataset_type}.npz" + f"{reference_data_dir}steps_{dataset_type}.npz" ) # check here that states and steps were correctly saved (need to convert the From 7e09572379f2f02237140525eff26c1bc2c0024b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 14:51:01 -0400 Subject: [PATCH 038/302] delete empty file (see 3f44da4) --- scripts/_mp_predict_multireactant.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 scripts/_mp_predict_multireactant.py diff --git a/scripts/_mp_predict_multireactant.py b/scripts/_mp_predict_multireactant.py deleted file mode 100644 index e69de29b..00000000 From 7eeab167546b60288d027c85879c30f43a33f1fa Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 22 Aug 2022 14:52:50 -0400 Subject: [PATCH 039/302] delete unused code --- scripts/_mp_sum.py | 8 -------- 1 file changed, 8 deletions(-) delete mode 100644 scripts/_mp_sum.py diff --git a/scripts/_mp_sum.py b/scripts/_mp_sum.py deleted file mode 100644 index 9cb0591f..00000000 --- a/scripts/_mp_sum.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Computes the sum of a single molecular embedding. -""" -import numpy as np - - -def func(emb): - return np.sum(emb) From cb8411bc487f9f5e460dada81fa5b5f82771ff9a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 10:51:13 -0400 Subject: [PATCH 040/302] rewrite `README.md` and add `INSTRUCTIONS.md` - the detailed step-by-step instructions to train from scratch are move to `INSTRUCTIONS.md` --- INSTRUCTIONS.md | 113 ++++++++++++++++++++ README.md | 270 +++++++++++------------------------------------- 2 files changed, 173 insertions(+), 210 deletions(-) create mode 100644 INSTRUCTIONS.md diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md new file mode 100644 index 00000000..45d9a4c6 --- /dev/null +++ b/INSTRUCTIONS.md @@ -0,0 +1,113 @@ +# Instructions + +This documents outlines the process to train SynNet from scratch step-by-step. + +> :warning: It is still a WIP to match the filenames of the scripts to the instructions here and to simplify the dependency on parameters/filenames. + +You can use any set of reaction templates and building blocks, but we will illustrate the process with the *Hartenfeller-Button* reaction templates and *Enamine building blocks*. + +*Note*: This project depends on a lot of exact filenames. +For example, one script will save to file, the next will read that file for further processing. +It is not a perfect approach - we are open to feedback - and advise to revise the parameters defined in each script. + +Let's start. + +## Step-by-Step + +0. Prepare reaction templates and building blocks. + + Extract SMILES from the `.sdf` file from enamine.net. + + ```shell + python scripts/00-extract-smiles-from-sdf.py --file="data/assets/building-blocks/enamine-us.sdf" + ``` + +1. Filter building blocks. + + We proprocess the building blocks to identify applicable reactants for each reaction template. + In other words, filter out all building blocks that do not match any reaction template. + There is no need to keep them, as they cannot act as reactant. + In a first step, we match all building blocks with each reaction template. + In a second step, we save a set of all matched building blocks. + + ```bash + # Match + python scripts/01-process_rxn.py + # Filter + python scripts/02-filter-unmatched.py + ``` + + > :bulb: All following steps use this matched building blocks <-> reaction template data. As of now, you still have to specify these parameters again for every script to that it can load the right data. + +2. Generate *synthetic trees* + + Herein we generate the data used for training the networks. + The data is generated by randomly selecting building blocks, reaction templates and directives to grow a synthetic tree. + In a second step, we filter out some synthetic trees to make the data pharmaceutically more interesting. + That is, we filter out trees, whose root node molecule has a QED < 0.5, or randomly with a probability less than 1 - QED/0.5. + + ```bash + # Generate synthetic trees + python scripts/03-make_dataset_mp.py + # Filter + python scripts/04-sample_from_original.py + ``` + + Each *synthetic tree* is serializable and so we save all trees in a compressed `.json` file. + +3. Split *synthetic trees* into train,valid,test-data + + We load the `.json`-file with all *synthetic trees* and + straightforward split it into three files: `{train,test,valid}.json`. + The default split ratio is 6:2:2. + + ```bash + python scripts/05-st_split.py + ``` + +4. Featurization and + + > :bulb: All following steps depend on the representations for the data. Hence, you have to specify the parameters for the reprensations as input argument for most of the scripts so that it can operate on the right data. + + We organize each *synthetic tree* into states and actions. + That is, we break down each tree to the action at each iteration ("Add", "Expand", "Extend", "End") and a corresponding "super state" vector. + We call it "super state" here, as it contains all states for all networks. + However, recall that the input that the state vector is different for each network. + + ```bash + python scripts/06-st2steps.py + ``` + +5. Split features + + Up to this point, we worked with a (featurized) *synthetic tree* as a whole, + now we split it up to into "consumable" input data for each of the four networks. + This includes picking the right state(s) from the "super state" vector from the previous step. + + ```bash + python scripts/07-prepare_data.py + ``` + +6. Train the networks + + Finally, we can train each of the four networks in `src/syn_net/models/` separately: + + ```bash + python src/syn_net/models/act.py + ``` + +After training a new model, you can then use the trained model to make predictions and construct synthetic trees for a list given set of molecules. + +You can also perform molecular optimization using a genetic algorithm. + +Please refer to the [README.md](./README.md) for inference instructions. + +## Auxiallary Scripts + +### Visualizing trees + +To be added. + +### Mean reciprocal rank + +To be added. diff --git a/README.md b/README.md index bf0cb1ee..604385fa 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,26 @@ The method is described in detail in the publication "Amortized tree generation ## Summary -### Overview +We model synthetic pathways as tree structures called *synthetic trees*. +A synthetic tree has a single root node and one or more child nodes. +Every node is chemical molecule: + +- The root node is the final product molecule +- The leaf nodes consist of purchasable building blocks. +- All other inner nodes are constrained to be a product of allowed chemical reactions. + +At a high level, each synthetic tree is constructed one reaction step at a time in a bottom-up manner, that is starting from purchasable building blocks. -We model synthetic pathways as tree structures called *synthetic trees*. A valid synthetic tree has one root node (the final product molecule) linked to purchasable building blocks (encoded as SMILES strings) via feasible reactions according to a list of discrete reaction templates (examples of templates encoded as SMARTS strings in [data/rxn_set_hb.txt](./data/rxn_set_hb.txt)). At a high level, each synthetic tree is constructed one reaction step at a time in a bottom-up manner, starting from purchasable building blocks. +### Overview The model consists of four modules, each containing a multi-layer perceptron (MLP): -1. An *Action Type* selection function that classifies action types among the four possible actions (“Add”, “Expand”, “Merge”, and “End”) in building the synthetic tree. -2. A *First Reactant* selection function that predicts an embedding for the first reactant. A candidate molecule is identified for the first reactant through a k-nearest neighbors (k-NN) search from the list of potential building blocks. -3. A *Reaction* selection function whose output is a probability distribution over available reaction templates, from which inapplicable reactions are masked (based on reactant 1) and a suitable template is then sampled using a greedy search. +1. An *Action Type* selection function that classifies action types among the four possible actions (“Add”, “Expand”, “Merge”, and “End”) in building the synthetic tree. Each action increases the depth of the synthetic tree by one. + +2. A *First Reactant* selection function that selects the first reactant. A MLP predicts a molecular embedding and a first reactant is identified from the pool of building blocks through a k-nearest neighbors (k-NN) search. + +3. A *Reaction* selection function that select reaction. The whose output is a probability distribution over available reaction templates, from which inapplicable reactions are masked (based on reactant 1) and a suitable template is then sampled using a greedy search. + 4. A *Second Reactant* selection function that identifies the second reactant if the sampled template is bi-molecular. The model predicts an embedding for the second reactant, and a candidate is then sampled via a k-NN search from the masked set of building blocks. ![the model](./figures/network.png "model scheme") @@ -39,20 +50,14 @@ To do this, we optimize the molecular embedding of the molecule using a genetic ### Setting up the environment -You can use conda to create an environment containing the necessary packages and dependencies for running SynNet by using the provided YAML file: +Conda is used to create the environment for running SynNet. -``` +```shell +# Install environment from file conda env create -f environment.yml ``` -If you update the environment and would like to save the updated environment as a new YAML file using conda, use: - -``` -conda env export > path/to/env.yml -``` - -pip install -e . -Before running any SynNet code, activate the environment and install the package in development mode. This ensures the scripts can find the right files. You can do this by typing: +Before running any SynNet code, activate the environment and install this package in development mode. This ensures the scripts can find the right files. You can do this by typing: ```shell source activate synthenv @@ -67,235 +72,80 @@ To check that everything has been set-up correctly, you can run the unit tests. python -m unittest ``` -Except for `tests/test_Training.py`, all tests should succedd. The `test_Training.py` still relies on the embedding of the building blocks, which is tracked in this repostory. - ### Data -#### Templates +SyNNet relies on two datasources: -The Hartenfeller-Button templates are available in the [./data/](./data/) directory. +1. reaction templates and +2. building blocks. -#### Building blocks +The data used for the publication are 1) the *Hartenfeller-Button* reaction templates, which are available under [data/assets/reaction-templates/hb.txt](data/assets/reaction-templates/hb.txt) and 2) *Enamine building blocks*. +The building blocks are not freely available. -The Enamine data can be freely downloaded from for academic purposes. After downloading the Enamine building blocks, you will need to replace the paths to the Enamine building blocks in the code. This can be done by searching for the string "enamine". +To obtain the data, go to [https://enamine.net/building-blocks/building-blocks-catalog](https://enamine.net/building-blocks/building-blocks-catalog). +We used the "Building Blocks, US Stock" data. You need to first register and then request access to download the dataset. The people from enamine.net manually approve you, so please be nice and patient. ## Code Structure -The code is structured as follows: - -``` -SynNet/ -├── data -│ └── rxn_set_hb.txt -├── environment.yml -├── LICENSE -├── README.md -├── scripts -│ ├── compute_embedding_mp.py -│ ├── compute_embedding.py -│ ├── generation_fp.py -│ ├── generation.py -│ ├── gin_supervised_contextpred_pre_trained.pth -│ ├── _mp_decode.py -│ ├── _mp_predict_beam.py -│ ├── _mp_predict_multireactant.py -│ ├── _mp_predict.py -│ ├── _mp_search_similar.py -│ ├── _mp_sum.py -│ ├── mrr.py -│ ├── optimize_ga.py -│ ├── predict-beam-fullTree.py -│ ├── predict_beam_mp.py -│ ├── predict-beam-reactantOnly.py -│ ├── predict_mp.py -│ ├── predict_multireactant_mp.py -│ ├── predict.py -│ ├── read_st_data.py -│ ├── sample_from_original.py -│ ├── search_similar.py -│ ├── sketch-synthetic-trees.py -│ ├── st2steps.py -│ ├── st_split.py -│ └── temp.py -├── syn_net -│ ├── data_generation -│ │ ├── check_all_template.py -│ │ ├── filter_unmatch.py -│ │ ├── __init__.py -│ │ ├── make_dataset_mp.py -│ │ ├── make_dataset.py -│ │ ├── _mp_make.py -│ │ ├── _mp_process.py -│ │ └── process_rxn_mp.py -│ ├── __init__.py -│ ├── models -│ │ ├── act.py -│ │ ├── mlp.py -│ │ ├── prepare_data.py -│ │ ├── rt1.py -│ │ ├── rt2.py -│ │ └── rxn.py -│ └── utils -│ ├── data_utils.py -│ ├── ga_utils.py -│ ├── predict_beam_utils.py -│ ├── predict_utils.py -│ └── __init__.py -└── tests - ├── create-unittest-data.py - └── test_DataPreparation.py -``` +The model implementations can be found in [src/syn_net/models/](src/syn_net/models/). +The pre-processing and analysis scripts are in [scripts/](scripts/). -The model implementations can be found in [syn_net/models/](syn_net/models/), with processing and analysis scripts located in [scripts/](./scripts/). - -## Instructions +## Reproducing results Before running anything, set up the environment as decribed above. -## Using pre-trained models +### Using pre-trained models -We have made available a set of pre-trained models at the following [link](https://figshare.com/articles/software/Trained_model_parameters_for_SynNet/16799413). The pretrained models correspond to the Action, Reactant 1, Reaction, and Reactant 2 networks, trained on the Hartenfeller-Button dataset using radius 2, length 4096 Morgan fingerprints for the molecular node embeddings, and length 256 fingerprints for the k-NN search. For further details, please see the publication. +We have made available a set of pre-trained models at the following [link](https://figshare.com/articles/software/Trained_model_parameters_for_SynNet/16799413). +The pretrained models correspond to the Action, Reactant 1, Reaction, and Reactant 2 networks, trained on the *Hartenfeller-Button* dataset and *Enamine* building blocks using radius 2, length 4096 Morgan fingerprints for the molecular node embeddings, and length 256 fingerprints for the k-NN search. +For further details, please see the publication. -The models can be uncompressed with: +To download the pre-trained model to `./pre-trained-model`: -``` -tar -zxvf hb_fp_2_4096_256.tar.gz +```shell +mkdir pre-trained-model && cd pre-trained-model +# Download +wget -O hb_fp_2_4096_256.tar.gz https://figshare.com/ndownloader/files/31067692 +# Extract +tar -vxf hb_fp_2_4096_256.tar.gz ``` -### Synthesis Planning +The following scripts are run from the command line. +Use `python some_script.py --help` or check the source code to see the instructions of each argument. + +#### Synthesis Planning To perform synthesis planning described in the main text: +```shell +python scripts/predict_multireactant_mp.py -n -1 --ncpu 10 --data "data/assets/molecules/sample-targets.txt" ``` -python predict_multireactant_mp.py -n -1 --ncpu 36 --data test -``` -This script will feed a list of molecules from the test data and save the decoded results (predicted synthesis trees) to [./results/](./results/). -One can use --help to see the instruction of each argument. -Note: this file reads parameters from a directory, please specify the path to parameters previously. +This script will feed a list of ten randomly selected molecules from the validation to SynNet. +The decoded results, i.e. the predicted synthesis trees, are saved to `DATA_RESULT_DIR`. +(Paths are defined in [src/syn_net/config.py](src/syn_net/config.py).) + +*Note*: To do synthesis planning, you will need a list of target molecules, building blocks and compute their embedding. As mentioned, we cannot share the building blocks from enamine.net and you will have to request access yourselfs. -### Synthesizable Molecular Design +#### Synthesizable Molecular Design To perform synthesizable molecular design, under [./scripts/](./scripts/), run: -``` -optimize_ga.py -i path/to/zinc.csv --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk +```shell +python scripts/optimize_ga.py -i path/to/zinc.csv --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk ``` This script uses a genetic algorithm to optimize molecular embeddings and returns the predicted synthetic trees for the optimized molecular embedding. -One can use --help to see the instruction of each argument. + If user wants to start from a checkpoint of previous run, run: -``` -optimize_ga.py -i path/to/population.npy --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk --restart +```shell +python scripts/optimize_ga.py -i path/to/population.npy --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk --restart ``` Note: the input file indicated by -i contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. ### Train the model from scratch -Before training any models, you will first need to preprocess the set of reaction templates which you would like to use. You can use either a new set of reaction templates, or the provided Hartenfeller-Button (HB) set of reaction templates (see [data/rxn_set_hb.txt](data/rxn_set_hb.txt)). To preprocess a new dataset, you will need to: - -1. Preprocess the data to identify applicable reactants for each reaction template -2. Generate the synthetic trees by random selection -3. Split the synthetic trees into training, testing, and validation splits -4. Featurize the nodes in the synthetic trees using molecular fingerprints -5. Prepare the training data for each of the four networks - -Once you have preprocessed a training set, you can begin to train a model by training each of the four networks separately (the *Action*, *First Reactant*, *Reaction*, and *Second Reactant* networks). - -After training a new model, you can then use the trained model to make predictions and construct synthetic trees for a list given set of molecules. - -You can also perform molecular optimization using a genetic algorithm. - -Instructions for all of the aforementioned steps are described in detail below. - -In addition to the aforementioned types of jobs, we have also provide below instructions for (1) sketching synthetic trees and (2) calculating the mean reciprocal rank of reactant 1. - -### Processing the data: reaction templates and applicable reactants - -Given a set of reaction templates and a list of buyable building blocks, we first need to assign applicable reactants for each template. Under [./syn_net/data_generation/](./syn_net/data_generation/), run: - -``` -python process_rxn_mp.py -``` - -This will save the reaction templates and their corresponding building blocks in a JSON file. Then, run: - -``` -python filter_unmatch.py -``` - -This will filter out buyable building blocks which didn't match a single template. - -### Generating the synthetic path data by random selection - -Under [./syn_net/data_generation/](./syn_net/data_generation/), run: - -``` -python make_dataset_mp.py -``` - -This will generate synthetic path data saved in a JSON file. Then, to make the dataset more pharmaceutically revelant, we can change to [./scripts/](./scripts/) and run: - -``` -python sample_from_original.py -``` - -This will filter out the samples where the root node QED is less than 0.5, or randomly with a probability less than 1 - QED/0.5. - -### Splitting data into training, validation, and testing sets, and removing duplicates - -Under [./scripts/](./scripts/), run: - -``` -python st_split.py -``` - -The default split ratio is 6:2:2 for training, validation, and testing sets. - -### Featurizing data - -Under [./scripts/](./scripts/), run: - -``` -python st2steps.py -r 2 -b 4096 -d train -``` - -This will featurize the synthetic tree data into step-by-step data which can be used for training. The flag *-r* indicates the fingerprint radius, *-b* indicates the number of bits to use for the fingerprints, and *-d* indicates which dataset split to featurize. - -### Preparing training data for each network - -Under [./syn_net/models/](./syn_net/models/), run: - -``` -python prepare_data.py --radius 2 --nbits 4096 -``` - -This will prepare the training data for the networks. - -Each is a training script and can be used as follows (using the action network as an example): - -``` -python act.py --radius 2 --nbits 4096 -``` - -This will train the network and save the model parameters at the state with the best validation loss in a logging directory, e.g., **`act_hb_fp_2_4096_logs`**. One can use tensorboard to monitor the training and validation loss. - -### Sketching synthetic trees - -To visualize the synthetic trees, run: - -``` -python scripts/sketch-synthetic-trees.py --file /path/to/st_hb/st_train.json.gz --saveto ./ --nsketches 5 --actions 3 -``` - -This will sketch 5 synthetic trees with 3 or more actions to the current ("./") directory (you can play around with these variables or just also leave them out to use the defaults). - -### Testing the mean reciprocal rank (MRR) of reactant 1 - -Under [./scripts/](./scripts/), run: - -``` -python mrr.py --distance cosine -``` +Before training any models, you will first need to some data preprocessing. +Please see [INSTRUCTIONS.md](INSTRUCTIONS.md) for a complete guide. From fd294e3d2b658880ec797d3c2e9690156952b2c5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 10:52:43 -0400 Subject: [PATCH 041/302] format `README.md` --- README.md | 50 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 604385fa..b0563a21 100644 --- a/README.md +++ b/README.md @@ -52,14 +52,14 @@ To do this, we optimize the molecular embedding of the molecule using a genetic Conda is used to create the environment for running SynNet. -```shell +```bash # Install environment from file conda env create -f environment.yml ``` Before running any SynNet code, activate the environment and install this package in development mode. This ensures the scripts can find the right files. You can do this by typing: -```shell +```bash source activate synthenv pip install -e . ``` @@ -74,7 +74,7 @@ python -m unittest ### Data -SyNNet relies on two datasources: +SynNet relies on two datasources: 1. reaction templates and 2. building blocks. @@ -100,10 +100,10 @@ We have made available a set of pre-trained models at the following [link](https The pretrained models correspond to the Action, Reactant 1, Reaction, and Reactant 2 networks, trained on the *Hartenfeller-Button* dataset and *Enamine* building blocks using radius 2, length 4096 Morgan fingerprints for the molecular node embeddings, and length 256 fingerprints for the k-NN search. For further details, please see the publication. -To download the pre-trained model to `./pre-trained-model`: +To download the pre-trained model to `./checkpoints`: -```shell -mkdir pre-trained-model && cd pre-trained-model +```bash +mkdir -p checkpoints && cd checkpoints # Download wget -O hb_fp_2_4096_256.tar.gz https://figshare.com/ndownloader/files/31067692 # Extract @@ -113,12 +113,26 @@ tar -vxf hb_fp_2_4096_256.tar.gz The following scripts are run from the command line. Use `python some_script.py --help` or check the source code to see the instructions of each argument. +### Prerequisites + +In addition to the necessary data, see [Data](#data), we pre-compute an embedding of the building blocks. Please double-check the filename of your building blocks. + +```bash +python scripts/compute_embedding_mp.py \ + --feature "fp_256" \ + --rxn-template "hb" \ + --ncpu 10 +``` + #### Synthesis Planning To perform synthesis planning described in the main text: -```shell -python scripts/predict_multireactant_mp.py -n -1 --ncpu 10 --data "data/assets/molecules/sample-targets.txt" +```bash +python scripts/predict_multireactant_mp.py \ + -n -1 \ + --data "data/assets/molecules/sample-targets.txt" \ + --ncpu 10 ``` This script will feed a list of ten randomly selected molecules from the validation to SynNet. @@ -129,21 +143,29 @@ The decoded results, i.e. the predicted synthesis trees, are saved to `DATA_RESU #### Synthesizable Molecular Design -To perform synthesizable molecular design, under [./scripts/](./scripts/), run: +To perform synthesizable molecular design, run: -```shell -python scripts/optimize_ga.py -i path/to/zinc.csv --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk +```bash +python scripts/optimize_ga.py \ + -i path/to/zinc.csv \ + --radius 2 --nbits 4096 \ + --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk \ + --ncpu 32 ``` This script uses a genetic algorithm to optimize molecular embeddings and returns the predicted synthetic trees for the optimized molecular embedding. If user wants to start from a checkpoint of previous run, run: -```shell -python scripts/optimize_ga.py -i path/to/population.npy --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk --restart +```bash +python scripts/optimize_ga.py \ + -i path/to/population.npy \ + --radius 2 --nbits 4096 \ + --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk --restart \ + --ncpu 32 ``` -Note: the input file indicated by -i contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. +Note: the input file indicated by `-i` contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. ### Train the model from scratch From c02185e014bed1c6c4775adc9088cd9a54924312 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:08:14 -0400 Subject: [PATCH 042/302] add helper to go from sdf->smiles --- scripts/00-extract-smiles-from-sdf.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 scripts/00-extract-smiles-from-sdf.py diff --git a/scripts/00-extract-smiles-from-sdf.py b/scripts/00-extract-smiles-from-sdf.py new file mode 100644 index 00000000..5c37729a --- /dev/null +++ b/scripts/00-extract-smiles-from-sdf.py @@ -0,0 +1,24 @@ +from syn_net.utils.prep_utils import Sdf2SmilesExtractor +from pathlib import Path +import logging + +logger = logging.getLogger(__name__) + + +def main(file): + file = Path(file) + if not file.exists(): + raise FileNotFoundError(file) + outfile = file.with_name(f"{file.name}-smiles").with_suffix(".csv.gz") + Sdf2SmilesExtractor().from_sdf(file).to_file(outfile) + +if __name__=="__main__": + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("-f", "--file", type=str, help="An *.sdf file") + args = parser.parse_args() + logger.info(f"Arguments: {vars(args)}") + file = args.file + main(file) + logger.info(f"Success.") + From 77dec043ac6c680cfe958bde76473d1e007d9546 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:13:41 -0400 Subject: [PATCH 043/302] feature: predict from file (smiles only, no need for syntrees) --- scripts/predict_multireactant_mp.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index b724cf1b..3e70cc28 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -22,6 +22,11 @@ def _fetch_data_chembl(name: str) -> list[str]: smis_query = df.smiles.to_list() return smis_query +def _fetch_data_from_file(name: str) -> list[str]: + with open(name,"rt") as f: + smis_query = [line.strip() for line in f] + return smis_query + def _fetch_data(name: str) -> list[str]: if args.data in ["train", "valid", "test"]: file = Path(DATA_PREPARED_DIR) / f"synthetic-trees-{args.data}.json.gz" @@ -29,8 +34,10 @@ def _fetch_data(name: str) -> list[str]: sts = SyntheticTreeSet() sts.load(file) smis_query = [st.root.smiles for st in sts.sts] - else: + elif args.data in ["chembl"]: smis_query = _fetch_data_chembl(name) + else: # Hopefully got a filename instead + smis_query = _fetch_data_from_file(name) return smis_query def _fetch_reaction_templates(file: str): @@ -127,9 +134,11 @@ def func(smiles: str): parser.add_argument("-n", "--num", type=int, default=1, help="Number of molecules to predict.") parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test', 'chembl']") + help="Choose from ['train', 'valid', 'test', 'chembl'] or provide a file with one SMILES per line.") parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") + parser.add_argument("--output-dir", type=str, default=None, + help="Directory to save output.") args = parser.parse_args() nbits = args.nbits @@ -184,15 +193,16 @@ def func(smiles: str): print(f" {avg_similarity=}") # Save to local dir - print('Saving results to {DATA_RESULT_DIR} ...') + output_dir = DATA_RESULT_DIR if args.output_dir is None else args.output_dir + print('Saving results to {output_dir} ...') df = pd.DataFrame({'query SMILES' : smiles_queries, - 'decode SMILES': smis_decoded, - 'similarity' : similarities}) - df.to_csv(f'{DATA_RESULT_DIR}/decode_result_{args.data}.csv.gz', - compression='gzip', - index=False,) + 'decode SMILES': smis_decoded, + 'similarity' : similarities}) + df.to_csv(f'{output_dir}/decode_result_{args.data}.csv.gz', + compression='gzip', + index=False,) synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{DATA_RESULT_DIR}/decoded_st_{args.data}.json.gz') + synthetic_tree_set.save(f'{output_dir}/decoded_st_{args.data}.json.gz') print('Finish!') From ac83368eb3d667eddecdc03f7db71ff887921278 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:15:31 -0400 Subject: [PATCH 044/302] simplify list comp & comments --- src/syn_net/utils/predict_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index dbf3c937..a6f27bd2 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -924,8 +924,8 @@ def synthetic_tree_decoder_rt1(z_target, x_rct2 = np.concatenate([z_state,z_mol1, x_rxn],axis=1) z_mol2 = reactant2_net(torch.Tensor(x_rct2)) z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] + available = available_list[rxn_id] # list[str], list of reactants for this rxn + available = [bb_dict[smiles] for smiles in available] # list[int] temp_emb = bb_emb[available] available_tree = BallTree(temp_emb, metric=cosine_distance) # TODO: evaluate if distance matrix is faster/feasible as this BallTree is discarded immediately. dist, ind = nn_search(z_mol2, _tree=available_tree) From fd479209d8ad3fea0363b75b9f3e4919bbc7b138 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:16:51 -0400 Subject: [PATCH 045/302] refactor: move code into fcts --- scripts/compute_embedding_mp.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/scripts/compute_embedding_mp.py b/scripts/compute_embedding_mp.py index d8a4ef7e..59a35eca 100644 --- a/scripts/compute_embedding_mp.py +++ b/scripts/compute_embedding_mp.py @@ -27,6 +27,14 @@ "rdkit2d": rdkit2d_embedding, } +def _load_building_blocks(file: Path) -> list[str]: + return pd.read_csv(file)["SMILES"].to_list() + +def _save_embedding(file: str, embeddings: list[list[float]]): + embeddings = np.array(embeddings) + + np.save(file, embeddings) + logger.info(f"Successfully saved to {file}.") if __name__ == "__main__": @@ -35,7 +43,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--feature", type=str, default="fp_256", choices=FUNCTIONS.keys(), help="Objective function to optimize") parser.add_argument("--ncpu", type=int, default=64, help="Number of cpus") - parser.add_argument("-rxn", "--rxn_template", type=str, default="hb", choices=["hb", "pis"], help="Choose from ['hb', 'pis']") + parser.add_argument("-rxn", "--rxn-template", type=str, default="hb", choices=["hb", "pis"], help="Choose from ['hb', 'pis']") parser.add_argument("--input", type=str, help="Input file with SMILES strings (One per line).") args = parser.parse_args() @@ -44,8 +52,7 @@ # Load building blocks file = Path(DATA_PREPROCESS_DIR) / f"{reaction_template_id}-{building_blocks_id}-matched.csv.gz" - - data = pd.read_csv(file)["SMILES"].tolist() + data = _load_building_blocks(file) logger.info(f"Successfully read {file}.") logger.info(f"Total number of building blocks: {len(data)}.") @@ -53,12 +60,7 @@ with mp.Pool(processes=args.ncpu) as pool: embeddings = pool.map(func, data) - # Save embeddings - embeddings = np.array(embeddings) - path = Path(DATA_EMBEDDINGS_DIR) path.mkdir(exist_ok=1, parents=1) outfile = path / f"{reaction_template_id}-{building_blocks_id}-embeddings.npy" - - np.save(outfile, embeddings) - logger.info(f"Successfully saved to {outfile}.") + _save_embedding(file,embeddings) From 9761acf640da108ef3f5c26ca578640388b46541 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:17:28 -0400 Subject: [PATCH 046/302] refactor: dict instead of if..elif --- src/syn_net/utils/prep_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 978d3f2e..ce9710f9 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -70,14 +70,14 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, states = [] steps = [] - if output_embedding == 'gin': - d_mol = 300 - elif output_embedding == 'fp_4096': - d_mol = 4096 - elif output_embedding == 'fp_256': - d_mol = 256 - elif output_embedding == 'rdkit2d': - d_mol = 200 + OUTPUT_EMBEDDINGS_DIMS = { + "gin": 300, + "fp_4096": 4096, + "fp_256": 256, + "rdkit2d": 200, + } + + d_mol = OUTPUT_EMBEDDINGS_DIMS[output_embedding] if target_embedding == 'fp': target = mol_fp(st.root.smiles, radius, nBits).tolist() From 1bf76fb20968e03fe1af9ad8417f36cee4285040 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:21:26 -0400 Subject: [PATCH 047/302] forgot to add code used in c02185e --- src/syn_net/utils/prep_utils.py | 41 +++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index ce9710f9..94d203e8 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -11,6 +11,8 @@ get_reaction_mask, mol_fp, get_mol_embedding) from pathlib import Path +import logging +logger = logging.getLogger(__name__) def rdkit2d_embedding(smi): """ @@ -313,3 +315,42 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): print(f' saved data for "Reactant 1"') return None + +class Sdf2SmilesExtractor: + """Helper class for data generation.""" + + def __init__(self) -> None: + self.smiles: Iterator[str] + + def from_sdf(self, file: Union[str, Path]): + """Extract chemicals as SMILES from `*.sdf` file. + + See also: + https://www.rdkit.org/docs/GettingStartedInPython.html#reading-sets-of-molecules + """ + file = str(Path(file).resolve()) + suppl = Chem.SDMolSupplier(file) + self.smiles = (Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False) for mol in suppl) + logger.info(f"Read data from {file}") + + return self + + def _to_csv_gz(self, file: Path) -> None: + import gzip + + with gzip.open(file, "wt") as f: + f.writelines("SMILES\n") + f.writelines((s + "\n" for s in self.smiles)) + + def _to_csv_gz(self, file: Path) -> None: + with open(file, "wt") as f: + f.writelines("SMILES\n") + f.writelines((s + "\n" for s in self.smiles)) + + def to_file(self, file: Union[str, Path]) -> None: + + if Path(file).suffixes == [".csv", ".gz"]: + self._to_csv_gz(file) + else: + self._to_txt(file) + logger.info(f"Saved data to {file}") \ No newline at end of file From 86ec38590862396bd10ef73f9b96253db7a137f2 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:22:55 -0400 Subject: [PATCH 048/302] add type hints, comments --- src/syn_net/utils/prep_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 94d203e8..04a2b0aa 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -1,16 +1,18 @@ """ This file contains various utils for data preparation and preprocessing. """ +from typing import Iterator, Union import numpy as np from scipy import sparse from dgllife.model import load_pretrained from tdc.chem_utils import MolConvert from sklearn.preprocessing import OneHotEncoder -from syn_net.utils.data_utils import SyntheticTree +from syn_net.utils.data_utils import Reaction, SyntheticTree from syn_net.utils.predict_utils import (can_react, get_action_mask, get_reaction_mask, mol_fp, get_mol_embedding) from pathlib import Path +from rdkit import Chem import logging logger = logging.getLogger(__name__) @@ -81,6 +83,7 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, d_mol = OUTPUT_EMBEDDINGS_DIMS[output_embedding] + # Compute embedding of target molecule, i.e. the root of the synthetic tree if target_embedding == 'fp': target = mol_fp(st.root.smiles, radius, nBits).tolist() elif target_embedding == 'gin': @@ -94,7 +97,7 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, most_recent_mol_embedding = mol_fp(most_recent_mol, radius, nBits).tolist() other_root_mol_embedding = mol_fp(other_root_mol, radius, nBits).tolist() - state = most_recent_mol_embedding + other_root_mol_embedding + target + state = most_recent_mol_embedding + other_root_mol_embedding + target # (3d,1) if action == 3: step = [3] + [0]*d_mol + [-1] + [0]*d_mol + [0]*nBits @@ -148,7 +151,9 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, return sparse.csc_matrix(np.array(states)), sparse.csc_matrix(np.array(steps)) -def synthetic_tree_generator(building_blocks, reaction_templates, max_step=15): +def synthetic_tree_generator( + building_blocks: list[str], reaction_templates: list[Reaction], max_step: int = 15 +) -> tuple[SyntheticTree, int]: """ Generates a synthetic tree from the available building blocks and reaction templates. Used in preparing the training/validation/testing data. @@ -186,7 +191,7 @@ def synthetic_tree_generator(building_blocks, reaction_templates, max_step=15): break elif action == 0: # Add - mol1 = np.random.choice(building_blocks) + mol1 = np.random.choice(building_blocks) # TODO: convert to nparray to avoid costly conversion upon each function call else: # Expand or Merge mol1 = mol_recent @@ -194,10 +199,10 @@ def synthetic_tree_generator(building_blocks, reaction_templates, max_step=15): # Select reaction reaction_proba = np.random.rand(len(reaction_templates)) - if action != 2: - rxn_mask, available = get_reaction_mask(smi=mol1, + if action != 2: # = action == 0 or action == 1 + rxn_mask, available = get_reaction_mask(smi=mol1, rxns=reaction_templates) - else: + else: # merge tree _, rxn_mask = can_react(tree.get_state(), reaction_templates) available = [[] for rxn in reaction_templates] @@ -278,7 +283,7 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) print(f' saved data for "Action"') - + # extract Reaction data X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) y = steps[:, out_dim + 1] @@ -314,7 +319,7 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): sparse.save_npz(main_dir / f'y_rt1_{dataset}.npz', y) print(f' saved data for "Reactant 1"') - return None + return None class Sdf2SmilesExtractor: """Helper class for data generation.""" From 772491f43e9d5c26335954dfb50889cb0a7b4723 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:27:07 -0400 Subject: [PATCH 049/302] refactor: loop over networks --- tests/test_DataPreparation.py | 48 ++++++++--------------------------- 1 file changed, 11 insertions(+), 37 deletions(-) diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index 64ac7461..95debcc9 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -195,46 +195,20 @@ def test_dataprep(self): # check that the saved files match the reference files in # 'SynNet/tests/data/ref': + def _compare_to_reference(network_type: str): + X = sparse.load_npz(f"{main_dir}X_{network_type}_train.npz") + y = sparse.load_npz(f"{main_dir}y_{network_type}_train.npz") + + Xref = sparse.load_npz(f"{ref_dir}X_{network_type}_train.npz") + yref = sparse.load_npz(f"{ref_dir}y_{network_type}_train.npz") - # Action network data - X_act = sparse.load_npz(f"{main_dir}X_act_train.npz") - y_act = sparse.load_npz(f"{main_dir}y_act_train.npz") + self.assertEqual(X.toarray().all(), Xref.toarray().all(),msg=f"{network_type=}") + self.assertEqual(y.toarray().all(), yref.toarray().all(),msg=f"{network_type=}") - X_act_ref = sparse.load_npz(f"{ref_dir}X_act_train.npz") - y_act_ref = sparse.load_npz(f"{ref_dir}y_act_train.npz") + for network in ["act", "rt1", "rxn", "rt2"]: + _compare_to_reference(network) - self.assertEqual(X_act.toarray().all(), X_act_ref.toarray().all()) - self.assertEqual(y_act.toarray().all(), y_act_ref.toarray().all()) - - # Reactant 1 network data - X_rt1 = sparse.load_npz(f"{main_dir}X_rt1_train.npz") - y_rt1 = sparse.load_npz(f"{main_dir}y_rt1_train.npz") - - X_rt1_ref = sparse.load_npz(f"{ref_dir}X_rt1_train.npz") - y_rt1_ref = sparse.load_npz(f"{ref_dir}y_rt1_train.npz") - - self.assertEqual(X_rt1.toarray().all(), X_rt1_ref.toarray().all()) - self.assertEqual(y_rt1.toarray().all(), y_rt1_ref.toarray().all()) - - # Reaction network data - X_rxn = sparse.load_npz(f"{main_dir}X_rxn_train.npz") - y_rxn = sparse.load_npz(f"{main_dir}y_rxn_train.npz") - - X_rxn_ref = sparse.load_npz(f"{ref_dir}X_rxn_train.npz") - y_rxn_ref = sparse.load_npz(f"{ref_dir}y_rxn_train.npz") - - self.assertEqual(X_rxn.toarray().all(), X_rxn_ref.toarray().all()) - self.assertEqual(y_rxn.toarray().all(), y_rxn_ref.toarray().all()) - - # Reactant 2 network data - X_rt2 = sparse.load_npz(f"{main_dir}X_rt2_train.npz") - y_rt2 = sparse.load_npz(f"{main_dir}y_rt2_train.npz") - - X_rt2_ref = sparse.load_npz(f"{ref_dir}X_rt2_train.npz") - y_rt2_ref = sparse.load_npz(f"{ref_dir}y_rt2_train.npz") - - self.assertEqual(X_rt2.toarray().all(), X_rt2_ref.toarray().all()) - self.assertEqual(y_rt2.toarray().all(), y_rt2_ref.toarray().all()) + def test_bb_emb(self): """ From 05bc5646d23ba716969373f326440548d271e96c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:32:33 -0400 Subject: [PATCH 050/302] move hb reactions templates file --- data/assets/reaction-templates/hb.txt | 91 +++++++++++++++++++++++++++ data/rxn_set_hb.txt | 91 --------------------------- 2 files changed, 91 insertions(+), 91 deletions(-) create mode 100644 data/assets/reaction-templates/hb.txt delete mode 100644 data/rxn_set_hb.txt diff --git a/data/assets/reaction-templates/hb.txt b/data/assets/reaction-templates/hb.txt new file mode 100644 index 00000000..ff4b4727 --- /dev/null +++ b/data/assets/reaction-templates/hb.txt @@ -0,0 +1,91 @@ +[cH1:1]1:[c:2](-[CH2:7]-[CH2:8]-[NH2:9]):[c:3]:[c:4]:[c:5]:[c:6]:1.[#6:11]-[CH1;R0:10]=[OD1]>>[c:1]12:[c:2](-[CH2:7]-[CH2:8]-[NH1:9]-[C:10]-2(-[#6:11])):[c:3]:[c:4]:[c:5]:[c:6]:1 +[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[#8;H1,$(O-[CH3])]>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 +[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 +[c;r6:1](-[SH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[s:2]:[c:5](-[#6:6]):[n:4]2 +[c:1](-[OH1;$(Oc1ccccc1):2]):[c;r6:3](-[NH2:4]).[c:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[o:2]:[c:5](-[c:6]):[n:4]2 +[c;r6:1](-[OH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[OH1]>>[c:3]2:[c:1]:[o:2]:[c:5](-[#6:6]):[n:4]2 +[#6:6]-[C;R0:1](=[OD1])-[CH1;R0:5](-[#6:7])-[*;#17,#35,#53].[NH2:2]-[C:3]=[SD1:4]>>[c:1]2(-[#6:6]):[n:2]:[c:3]:[s:4][c:5]([#6:7]):2 +[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[OH1]):[c:4](-[NH2:5]).[N;!H0;!$(N-N);!$(N-C=N);!$(N(-C=O)-C=O):6]-[C;H1,$(C-[#6]):7]=[OD1]>>[c:4]2:[c:1]-[C:2](=[O:3])-[N:6]-[C:7]=[N:5]-2 +[CH0;$(C-[#6]):1]#[NH0:2]>>[C:1]1=[N:2]-N-N=N-1 +[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N(-[C:3])-N=N-1 +[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N=N-N-1(-[C:3]) +[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N(-[C:3])-N=N-1 +[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 +[CH0;$(C-[#6]):1]#[CH0;$(C-[#6]):2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 +[CH0;$(C-[#6]):1]#[NH0:2].[NH2:3]-[NH1:4]-[CH0;$(C-[#6]);R0:5]=[OD1]>>[N:2]1-[C:1]=[N:3]-[N:4]-[C:5]=1 +[CH0;$(C-[#6]):1]#[NH0:2].[CH0;$(C-[#6]);R0:5](=[OD1])-[#8;H1,$(O-[CH3]),$(O-[CH2]-[CH3])]>>[N:2]1-[C:1]=N-N-[C:5]=1 +[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[CH3:4]):[c:5](-[OH1:6]).[C;$(C1-[CH2]-[CH2]-[N,C]-[CH2]-[CH2]-1):7](=[OD1])>>[O:6]1-[c:5]:[c:1]-[C:2](=[OD1:3])-[C:4]-[C:7]-1 +[c;r6:1](-[C;$(C=O):6]-[OH1]):[c;r6:2]-[C;H1,$(C-C):3]=[OD1].[NH2:4]-[NH1;$(N-[#6]);!$(NC=[O,S,N]):5]>>[c:1]1:[c:2]-[C:3]=[N:4]-[N:5]-[C:6]-1 +[C;$(C-c1ccccc1):1](=[OD1])-[C;D3;$(C-c1ccccc1):2]~[O;D1,H1].[CH1;$(C-c):3]=[OD1]>>[C:1]1-N=[C:3]-[NH1]-[C:2]=1 +[NH1;$(N-c1ccccc1):1](-[NH2])-[c:5]:[cH1:4].[C;$(C([#6])[#6]):2](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):3]>>[C:5]1-[N:1]-[C:2]=[C:3]-[C:4]:1 +[NH2;$(N-c1ccccc1):1]-[c:2]:[c:3]-[CH1:4]=[OD1].[C;$(C([#6])[#6]):6](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):5]>>[N:1]1-[c:2]:[c:3]-[C:4]=[C:5]-[C:6]:1 +[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[OH1:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[O:3]-[C:4]=[C:5]-1 +[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[SD2:3]-[CH3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[S:3]-[C:4]=[C:5]-1 +[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[NH2:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[N:3]-[C:4]=[C:5]-1 +[#6:6][C:5]#[#7;D1:4].[#6:1][C:2](=[OD1:3])[OH1]>>[#6:6][c:5]1[n:4][o:3][c:2]([#6:1])n1 +[#6;$([#6]~[#6]);!$([#6]=O):2][#8;H1:3].[Cl,Br,I][#6;H2;$([#6]~[#6]):4]>>[CH2:4][O:3][#6:2] +[#6;H0;D3;$([#6](~[#6])~[#6]):1]B(O)O.[#6;H0;D3;$([#6](~[#6])~[#6]):2][Cl,Br,I]>>[#6:2][#6:1] +[c;H1:3]1:[c:4]:[c:5]:[c;H1:6]:[c:7]2:[nH:8]:[c:9]:[c;H1:1]:[c:2]:1:2.O=[C:10]1[#6;H2:11][#6;H2:12][N:13][#6;H2:14][#6;H2:15]1>>[#6;H2:12]3[#6;H1:11]=[C:10]([c:1]1:[c:9]:[n:8]:[c:7]2:[c:6]:[c:5]:[c:4]:[c:3]:[c:2]:1:2)[#6;H2:15][#6;H2:14][N:13]3 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N(C=O)C=O):2]>>[C:1][N:2] +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[OH1;$(Oc1ccccc1):2]>>[C:1][O:2] +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N([#6])S(=O)=O):2]>>[C:1][N:2] +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[C:1][#7:2]1:[#7:3]:[#7:4]:[#7:5]:[#6:6]:1 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[#7H0:2]1:[#7:3]:[#7H0:4]([C:1]):[#7:5]:[#6:6]:1 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[C:1][#7H0:2]1:[#7:3]:[#7H0:4]:[#7:5]:[#6:6]:1 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[#7:2]1:[#7:3]:[#7:4]([C:1]):[#7:5]:[#6:6]:1 +[#6;$(C=C-[#6]),$(c:c):1][Br,I].[Cl,Br,I][c:2]>>[c:2][#6:1] +[#6:1][C:2]#[#7;D1].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):3]>>[#6:1][C:2](=O)[#6:3] +[#6:1][C;H1,$([C]([#6])[#6]):2]=[OD1:3].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):4]>>[C:1][#6:2]([OH1:3])[#6:4] +[S;$(S(=O)(=O)[C,N]):1][Cl].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[S:1][N+0:2] +[c:1]B(O)O.[nH1;+0;r5;!$(n[#6]=[O,S,N]);!$(n~n~n);!$(n~n~c~n);!$(n~c~n~n):2]>>[c:1][n:2] +[#6:3]-[C;H1,$([CH0](-[#6])[#6]);!$(CC=O):1]=[OD1].[Cl,Br,I][C;H2;$(C-[#6]);!$(CC[I,Br]);!$(CCO[CH3]):2]>>[C:3][C:1]=[C:2] +[Cl,Br,I][c;$(c1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):1].[N;$(NC)&!$(N=*)&!$([N-])&!$(N#*)&!$([ND3])&!$([ND4])&!$(N[c,O])&!$(N[C,S]=[S,O,N]),H2&$(Nc1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):2]>>[c:1][N:2] +[C;$(C([#6])[#6;!$([#6]Br)]):4](=[OD1])[CH;$(C([#6])[#6]):5]Br.[#7;H2:3][C;$(C(=N)(N)[c,#7]):2]=[#7;H1;D1:1]>>[C:4]1=[CH0:5][NH:3][C:2]=[N:1]1 +[c;$(c1[c;$(c[C,S,N](=[OD1])[*;R0;!OH1])]cccc1):1][C;$(C(=O)[O;H1])].[c;$(c1aaccc1):2][Cl,Br,I]>>[c:1][c:2] +[c;!$(c1ccccc1);$(c1[n,c]c[n,c]c[n,c]1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] +[c;$(c1c(N(~O)~O)cccc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] +[c;$(c1ccc(N(~O)~O)cc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] +[N;$(N-[#6]):3]=[C;$(C=O):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] +[N;$(N-[#6]):3]=[C;$(C=S):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] +[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[C:10]-1=[N+0:9]-[C:8]-[C:7]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 +[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$([C](c)(C)(C)),C$([CH](c)(C)):7]([O$(OC),OH])-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[c:10]-1[n:9][c:8][c:7][c:6]2[c:5][c:4][c:3][c:2][c:1]-12 +[NH3+,NH2]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[c:1][c:2][nH:3][cH:5]1.[CH:10](-[CX4:12])=[O:11]>>[c,C:12]-[CH:10]-1-[N]-[C:8]-[C:7]-[c:6]2[c:1][c:2][nH:3][c:5]-12 +[NH2,NH3+1:8]-[c:5]1[cH:4][c:3][c:2][c:1][c:6]1.[Br:18][C$([CH2](C)(Br)),C$([CH](C)(C)(Br)):17]-[C:15](=[O:16])-[c:10]1[c:11][c:12][c:13][c:14][c:9]1>>[c:13]1[c:12][c:11][c:10]([c:9][c:14]1)-[c:15]1[c:17][c:4]2[c:3][c:2][c:1][c:6][c:5]2[nH+0:8]1 +[Cl:1][CH2:2]-[C$([CH](C)),C$(C(C)(C)):3]=[O:4].[OH:12]-[c:11]1[c:6][c:7][c:8][c:9][c:10]1-[CH:13]=[O:14]>>[C:3](=[O:4])-[c:2]1[c:13][c:10]2[c:9][c:8][c:7][c:6][c:11]2[o:12]1 +[NH2,NH3+]-[C$([CX4](N)([c,C])([c,C])([c,C])),C$([CH](N)([c,C])([c,C])),C$([CH2](N)([c,C])),C$([CH3](N)):2].[NH2:12]-[c:7]1[c:6][c:5][c:4][c:3][c:8]1-[C:9](-[OH,O-:11])=[O:10]>>[C:2]-[n+0]-1[c:13][n:12][c:7]2[c:6][c:5][c:4][c:3][c:8]2[c:9]-1=[O:10] +[N$([NH2]([CX4])),N$([NH3+1]([CX4])):1].[O:5]-[C$([CH]([CX4])(C)(O)),C$([CH2]([CX4])(O)):3][C$(C([CX4])(=O)([CX4])),C$([CH]([CX4])(=O)):4]=[O:6]>[O:15]=[C:9]-1-[CH2:10]-[CH2:11]-[CH2:12]-[CH2:13]-[CH2:14]-1>[c:4]1[c:3][n+0:1][c:10]2-[C:11]-[C:12]-[C:13]-[C:14]-[c:9]12 +[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):2](=[O:6])-[C$([CH]([CX4])),C$([CH2]):3]-[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):4]=[O:7].[NH2:8]-[C:9](=[O:10])-[CH2:11][C:12]#[N:13]>>[OH:10]-[c:9]1[n:8][c:4][c:3][c:2][c:11]1[C:12]#[N:13] +[C$(C(#C)([CX4])):2]#[C$(C(#C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[c:2]1[c:1][n:5][n][n]1 +[C$(C(=C)([CX4])):2]=[C$(C(=C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[C:2]1[C:1][N:5][N]=[N]1 +[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):1]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1[C:2][C:3][C:4]=[C:5][C:6]1 +[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):1]#[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1=[C:2][C:3][C:4]=[C:5][C:6]1 +[NH2,NH3+:3]-[N$([NH](N)([CX4])):2].[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):6](-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):5]=[O:9])-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):7]=[O:10]>>[c:7]1[n:3][n:2][c:5][c:6]1 +[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):1](=[O:2])-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):3]-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):4]-[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):5]=[O:6].[N$([NH2,NH3+1]([CX4])):7]>>[c:5]1[c:4][c:3][c:1][n+0:7]1 +[CH:7](=[O:8])-[c:1]1[c:2][c:3][c:4][c:5][c:6]1.[O:24]=[C:23](-[C:22](=[O:25])-[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1>[NH4].[O-]C(=O)C>[nH:27]-1[c:7]([n:26][c:23]([c:22]-1[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1)-[c:1]1[c:2][c:3][c:4][c:5][c:6]1 +[OH:7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1.[O$(O(C)([CX4])):12]-[C:11](=[O:15])-[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):10]-[C:8]=[O:16]>>[C:8]-1=[C:10]-[C:11](=[O:15])-[O]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 +[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:4][C:3][C:2]1=[O:1] +[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:11][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:11][C:4][C:3][C:2]1=[O:1] +[Cl:9][C:7](=[O:8])-[c:3]1[c:2][c:1][c:6][c:5][c:4]1.[C$([CH2](C)([CX4])),C$([CH3](C)):18]-[C:16](=[O:17])-[c:14]1[c:13][c:12][c:11][c:10][c:15]1-[OH:19]>>[O:17]=[C:16]-1-[C:18]=[C:7](-[O:8]-[c:15]2[c:10][c:11][c:12][c:13][c:14]-12)-[c:3]1[c:2][c:1][c:6][c:5][c:4]1 +[C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):1][C$([CH]([C,N])([C,N])([CX4])),C$([CH2]([C,N])([C,N])):2][C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):3].[C$(C(C)(#N)),C$(C(C)([CX4,OX2&H0])([CX4,OX2&H0])([OX2&H0])),C$([CH](C)([CX4,OX2&H0])([OX2&H0])),C$([CH2](C)([OX2&H0])),C$(C(C)(=O)([OX2&H0])):6][CH:5]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):4]>>[C:6][C:5][C:4][C:2]([C:1])[C:3] +[C$([C](O)([CX4])([CX4])([CX4])),C$([CH](O)([CX4])([CX4])),C$([CH2](O)([CX4])):4]-[O:3]-[C$(C(=O)([CX4])),C$([CH](=O)):2]=[O:5].[C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):7]-[C$(C(=O)([CX4])),C$([CH](=O)):8]=[O:9]>>[C:7](-[C:2]=[O:5])-[C:8]=[O:9] +[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[O$([OH]([CX4,c])),O$([OH]([CX4,c])([CX4,c])),S$([SH]([CX4,c])),S$([SH]([CX4,c])([CX4,c])):6]>>[*:6]-[C:2]=[O:4] +[C$(C(=O)([CX4,c])([CX4,c])),C$([CH](=O)([CX4,c])):1]=[O:2].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):3]>>[N+0:3][C:1] +[Br:1][c$(c(Br)),n$(n(Br)),o$(o(Br)),C$([CH](Br)(=C)):2].[C$(C(B)([CX4])([CX4])([CX4])),C$([CH](B)([CX4])([CX4])),C$([CH2](B)([CX4])),C$([CH2](B)),C$(C(B)(=C)),c$(c(B)),o$(o(B)),n$(n(B)):3][B$(B([C,c,n,o])([OH,$(OC)])([OH,$(OC)])),B$([B-1]([C,c,n,o])(N)([OH,$(OC)])([OH,$(OC)])):4]>>[C,c,n,o:2][C,c,n,o:3] +[Br,I:1][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):2].[Br,I:3][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):4]>>[C,c:2][C,c:4] +[OH,O-]-[C$(C(=O)(O)([CX4,c])):2]=[O:3].[OH:8]-[C$([CH](O)([CX4,c])([CX4,c])),C$([CH2](O)([CX4,c])),C$([CH3](O)):6]>>[C:6][O]-[C:2]=[O:3] +[C$([CH](=C)([CX4])),C$([CH2](=C)):2]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):3].[Br,I:7][C$([CX4]([Br,I])),c$([c]([Br,I])):4]>>[C,c:4][C:2]=[C:3] +[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):6]>>[N+0:6]-[C:2]=[O:4] +[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):1]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):2].[SH:4]-[CX4:5][Br,Cl,I]>>[C:1]-[C:2]-[S:4][C:5] +[C$([C](=O)([CX4])),C$([CH](=O)):2](=[O:1])[OH,Cl,O-:6].[SH:4]-[CX4:5][Br,Cl,I]>>[CH2:2]-[S:4][C:5] +[I:1][C$(C(I)([CX4,c])([CX4,c])([CX4,c])),C$([CH](I)([CX4,c])([CX4,c])),C$([CH2](I)([CX4,c])),C$([CH3](I)):2].[C$(C(=O)([Cl,OH,O-])([CX4,c])),C$([CH]([Cl,OH,O-])(=O)):3](=[O:6])[Cl,OH,O-:5]>>[C:2]-[C:3]=[O:6] +[Cl:5][S$(S(=O)(=O)(Cl)([CX4])):2](=[O:3])=[O:4].[NH2+0,NH3+:6]-[C$(C(N)([CX4,c])([CX4,c])([CX4,c])),C$([CH](N)([CX4,c])([CX4,c])),C$([CH2](N)([CX4,c])),C$([CH3](N)),c$(c(N)):7]>>[C,c:7]-[NH+0:6][S:2](=[O:4])=[O:3] +[*:1][C:2]#[CH:3].[Br,I:4][C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):5]>>[C,c:5][C:3]#[C:2][*:1] +[C$(C(C)([CX4])([CX4])([CX4])),C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):1][C:2]#[CH:3].[Br,I:4][C$(C(=O)([Br,I])([CX4])),C$([CH](=O)([Br,I])):5]=[O:6]>>[C:1][C:2]#[C:3][C:5]=[O:6] +[OH,O-:4]-[C$(C(=O)([OH,O-])([CX4])),C$([CH](=O)([OH,O-])):2]=[O:3]>>[Cl:5][C:2]=[O:3] +[OH:2]-[$([CX4]),c:1]>>[Br:3][C,c:1] +[OH:2]-[$([CX4]),c:1]>>[Cl:3][C,c:1] +[OH,O-:3][S$(S([CX4])):2](=[O:4])=[O:5]>>[Cl:6][S:2](=[O:5])=[O:4] +[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Br:6]) +[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Cl:6]) +[Cl,I,Br:7][c:1]1[c:2][c:3][c:4][c:5][c:6]1>>[N:9]#[C:8][c:1]1[c:2][c:3][c:4][c:5][c:6]1 +[OH,NH2,NH3+:3]-[CH2:2]-[C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):1]>>[C,c:1][C:2]#[N:4] diff --git a/data/rxn_set_hb.txt b/data/rxn_set_hb.txt deleted file mode 100644 index fa917eff..00000000 --- a/data/rxn_set_hb.txt +++ /dev/null @@ -1,91 +0,0 @@ -|[cH1:1]1:[c:2](-[CH2:7]-[CH2:8]-[NH2:9]):[c:3]:[c:4]:[c:5]:[c:6]:1.[#6:11]-[CH1;R0:10]=[OD1]>>[c:1]12:[c:2](-[CH2:7]-[CH2:8]-[NH1:9]-[C:10]-2(-[#6:11])):[c:3]:[c:4]:[c:5]:[c:6]:1 -|[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[#8;H1,$(O-[CH3])]>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 -|[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 -|[c;r6:1](-[SH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[s:2]:[c:5](-[#6:6]):[n:4]2 -|[c:1](-[OH1;$(Oc1ccccc1):2]):[c;r6:3](-[NH2:4]).[c:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[o:2]:[c:5](-[c:6]):[n:4]2 -|[c;r6:1](-[OH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[OH1]>>[c:3]2:[c:1]:[o:2]:[c:5](-[#6:6]):[n:4]2 -|[#6:6]-[C;R0:1](=[OD1])-[CH1;R0:5](-[#6:7])-[*;#17,#35,#53].[NH2:2]-[C:3]=[SD1:4]>>[c:1]2(-[#6:6]):[n:2]:[c:3]:[s:4][c:5]([#6:7]):2 -|[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[OH1]):[c:4](-[NH2:5]).[N;!H0;!$(N-N);!$(N-C=N);!$(N(-C=O)-C=O):6]-[C;H1,$(C-[#6]):7]=[OD1]>>[c:4]2:[c:1]-[C:2](=[O:3])-[N:6]-[C:7]=[N:5]-2 -|[CH0;$(C-[#6]):1]#[NH0:2]>>[C:1]1=[N:2]-N-N=N-1 -|[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N(-[C:3])-N=N-1 -|[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N=N-N-1(-[C:3]) -|[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N(-[C:3])-N=N-1 -|[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 -|[CH0;$(C-[#6]):1]#[CH0;$(C-[#6]):2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 -|[CH0;$(C-[#6]):1]#[NH0:2].[NH2:3]-[NH1:4]-[CH0;$(C-[#6]);R0:5]=[OD1]>>[N:2]1-[C:1]=[N:3]-[N:4]-[C:5]=1 -|[CH0;$(C-[#6]):1]#[NH0:2].[CH0;$(C-[#6]);R0:5](=[OD1])-[#8;H1,$(O-[CH3]),$(O-[CH2]-[CH3])]>>[N:2]1-[C:1]=N-N-[C:5]=1 -|[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[CH3:4]):[c:5](-[OH1:6]).[C;$(C1-[CH2]-[CH2]-[N,C]-[CH2]-[CH2]-1):7](=[OD1])>>[O:6]1-[c:5]:[c:1]-[C:2](=[OD1:3])-[C:4]-[C:7]-1 -|[c;r6:1](-[C;$(C=O):6]-[OH1]):[c;r6:2]-[C;H1,$(C-C):3]=[OD1].[NH2:4]-[NH1;$(N-[#6]);!$(NC=[O,S,N]):5]>>[c:1]1:[c:2]-[C:3]=[N:4]-[N:5]-[C:6]-1 -|[C;$(C-c1ccccc1):1](=[OD1])-[C;D3;$(C-c1ccccc1):2]~[O;D1,H1].[CH1;$(C-c):3]=[OD1]>>[C:1]1-N=[C:3]-[NH1]-[C:2]=1 -|[NH1;$(N-c1ccccc1):1](-[NH2])-[c:5]:[cH1:4].[C;$(C([#6])[#6]):2](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):3]>>[C:5]1-[N:1]-[C:2]=[C:3]-[C:4]:1 -|[NH2;$(N-c1ccccc1):1]-[c:2]:[c:3]-[CH1:4]=[OD1].[C;$(C([#6])[#6]):6](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):5]>>[N:1]1-[c:2]:[c:3]-[C:4]=[C:5]-[C:6]:1 -|[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[OH1:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[O:3]-[C:4]=[C:5]-1 -|[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[SD2:3]-[CH3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[S:3]-[C:4]=[C:5]-1 -|[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[NH2:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[N:3]-[C:4]=[C:5]-1 -|[#6:6][C:5]#[#7;D1:4].[#6:1][C:2](=[OD1:3])[OH1]>>[#6:6][c:5]1[n:4][o:3][c:2]([#6:1])n1 -|[#6;$([#6]~[#6]);!$([#6]=O):2][#8;H1:3].[Cl,Br,I][#6;H2;$([#6]~[#6]):4]>>[CH2:4][O:3][#6:2] -|[#6;H0;D3;$([#6](~[#6])~[#6]):1]B(O)O.[#6;H0;D3;$([#6](~[#6])~[#6]):2][Cl,Br,I]>>[#6:2][#6:1] -|[c;H1:3]1:[c:4]:[c:5]:[c;H1:6]:[c:7]2:[nH:8]:[c:9]:[c;H1:1]:[c:2]:1:2.O=[C:10]1[#6;H2:11][#6;H2:12][N:13][#6;H2:14][#6;H2:15]1>>[#6;H2:12]3[#6;H1:11]=[C:10]([c:1]1:[c:9]:[n:8]:[c:7]2:[c:6]:[c:5]:[c:4]:[c:3]:[c:2]:1:2)[#6;H2:15][#6;H2:14][N:13]3 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N(C=O)C=O):2]>>[C:1][N:2] -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[OH1;$(Oc1ccccc1):2]>>[C:1][O:2] -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N([#6])S(=O)=O):2]>>[C:1][N:2] -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[C:1][#7:2]1:[#7:3]:[#7:4]:[#7:5]:[#6:6]:1 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[#7H0:2]1:[#7:3]:[#7H0:4]([C:1]):[#7:5]:[#6:6]:1 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[C:1][#7H0:2]1:[#7:3]:[#7H0:4]:[#7:5]:[#6:6]:1 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[#7:2]1:[#7:3]:[#7:4]([C:1]):[#7:5]:[#6:6]:1 -|[#6;$(C=C-[#6]),$(c:c):1][Br,I].[Cl,Br,I][c:2]>>[c:2][#6:1] -|[#6:1][C:2]#[#7;D1].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):3]>>[#6:1][C:2](=O)[#6:3] -|[#6:1][C;H1,$([C]([#6])[#6]):2]=[OD1:3].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):4]>>[C:1][#6:2]([OH1:3])[#6:4] -|[S;$(S(=O)(=O)[C,N]):1][Cl].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[S:1][N+0:2] -|[c:1]B(O)O.[nH1;+0;r5;!$(n[#6]=[O,S,N]);!$(n~n~n);!$(n~n~c~n);!$(n~c~n~n):2]>>[c:1][n:2] -|[#6:3]-[C;H1,$([CH0](-[#6])[#6]);!$(CC=O):1]=[OD1].[Cl,Br,I][C;H2;$(C-[#6]);!$(CC[I,Br]);!$(CCO[CH3]):2]>>[C:3][C:1]=[C:2] -|[Cl,Br,I][c;$(c1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):1].[N;$(NC)&!$(N=*)&!$([N-])&!$(N#*)&!$([ND3])&!$([ND4])&!$(N[c,O])&!$(N[C,S]=[S,O,N]),H2&$(Nc1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):2]>>[c:1][N:2] -|[C;$(C([#6])[#6;!$([#6]Br)]):4](=[OD1])[CH;$(C([#6])[#6]):5]Br.[#7;H2:3][C;$(C(=N)(N)[c,#7]):2]=[#7;H1;D1:1]>>[C:4]1=[CH0:5][NH:3][C:2]=[N:1]1 -|[c;$(c1[c;$(c[C,S,N](=[OD1])[*;R0;!OH1])]cccc1):1][C;$(C(=O)[O;H1])].[c;$(c1aaccc1):2][Cl,Br,I]>>[c:1][c:2] -|[c;!$(c1ccccc1);$(c1[n,c]c[n,c]c[n,c]1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] -|[c;$(c1c(N(~O)~O)cccc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] -|[c;$(c1ccc(N(~O)~O)cc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] -|[N;$(N-[#6]):3]=[C;$(C=O):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] -|[N;$(N-[#6]):3]=[C;$(C=S):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] -|[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[C:10]-1=[N+0:9]-[C:8]-[C:7]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 -|[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$([C](c)(C)(C)),C$([CH](c)(C)):7]([O$(OC),OH])-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[c:10]-1[n:9][c:8][c:7][c:6]2[c:5][c:4][c:3][c:2][c:1]-12 -|[NH3+,NH2]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[c:1][c:2][nH:3][cH:5]1.[CH:10](-[CX4:12])=[O:11]>>[c,C:12]-[CH:10]-1-[N]-[C:8]-[C:7]-[c:6]2[c:1][c:2][nH:3][c:5]-12 -|[NH2,NH3+1:8]-[c:5]1[cH:4][c:3][c:2][c:1][c:6]1.[Br:18][C$([CH2](C)(Br)),C$([CH](C)(C)(Br)):17]-[C:15](=[O:16])-[c:10]1[c:11][c:12][c:13][c:14][c:9]1>>[c:13]1[c:12][c:11][c:10]([c:9][c:14]1)-[c:15]1[c:17][c:4]2[c:3][c:2][c:1][c:6][c:5]2[nH+0:8]1 -|[Cl:1][CH2:2]-[C$([CH](C)),C$(C(C)(C)):3]=[O:4].[OH:12]-[c:11]1[c:6][c:7][c:8][c:9][c:10]1-[CH:13]=[O:14]>>[C:3](=[O:4])-[c:2]1[c:13][c:10]2[c:9][c:8][c:7][c:6][c:11]2[o:12]1 -|[NH2,NH3+]-[C$([CX4](N)([c,C])([c,C])([c,C])),C$([CH](N)([c,C])([c,C])),C$([CH2](N)([c,C])),C$([CH3](N)):2].[NH2:12]-[c:7]1[c:6][c:5][c:4][c:3][c:8]1-[C:9](-[OH,O-:11])=[O:10]>>[C:2]-[n+0]-1[c:13][n:12][c:7]2[c:6][c:5][c:4][c:3][c:8]2[c:9]-1=[O:10] -|[N$([NH2]([CX4])),N$([NH3+1]([CX4])):1].[O:5]-[C$([CH]([CX4])(C)(O)),C$([CH2]([CX4])(O)):3][C$(C([CX4])(=O)([CX4])),C$([CH]([CX4])(=O)):4]=[O:6]>[O:15]=[C:9]-1-[CH2:10]-[CH2:11]-[CH2:12]-[CH2:13]-[CH2:14]-1>[c:4]1[c:3][n+0:1][c:10]2-[C:11]-[C:12]-[C:13]-[C:14]-[c:9]12 -|[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):2](=[O:6])-[C$([CH]([CX4])),C$([CH2]):3]-[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):4]=[O:7].[NH2:8]-[C:9](=[O:10])-[CH2:11][C:12]#[N:13]>>[OH:10]-[c:9]1[n:8][c:4][c:3][c:2][c:11]1[C:12]#[N:13] -|[C$(C(#C)([CX4])):2]#[C$(C(#C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[c:2]1[c:1][n:5][n][n]1 -|[C$(C(=C)([CX4])):2]=[C$(C(=C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[C:2]1[C:1][N:5][N]=[N]1 -|[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):1]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1[C:2][C:3][C:4]=[C:5][C:6]1 -|[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):1]#[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1=[C:2][C:3][C:4]=[C:5][C:6]1 -|[NH2,NH3+:3]-[N$([NH](N)([CX4])):2].[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):6](-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):5]=[O:9])-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):7]=[O:10]>>[c:7]1[n:3][n:2][c:5][c:6]1 -|[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):1](=[O:2])-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):3]-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):4]-[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):5]=[O:6].[N$([NH2,NH3+1]([CX4])):7]>>[c:5]1[c:4][c:3][c:1][n+0:7]1 -|[CH:7](=[O:8])-[c:1]1[c:2][c:3][c:4][c:5][c:6]1.[O:24]=[C:23](-[C:22](=[O:25])-[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1>[NH4].[O-]C(=O)C>[nH:27]-1[c:7]([n:26][c:23]([c:22]-1[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1)-[c:1]1[c:2][c:3][c:4][c:5][c:6]1 -|[OH:7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1.[O$(O(C)([CX4])):12]-[C:11](=[O:15])-[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):10]-[C:8]=[O:16]>>[C:8]-1=[C:10]-[C:11](=[O:15])-[O]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 -|[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:4][C:3][C:2]1=[O:1] -|[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:11][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:11][C:4][C:3][C:2]1=[O:1] -|[Cl:9][C:7](=[O:8])-[c:3]1[c:2][c:1][c:6][c:5][c:4]1.[C$([CH2](C)([CX4])),C$([CH3](C)):18]-[C:16](=[O:17])-[c:14]1[c:13][c:12][c:11][c:10][c:15]1-[OH:19]>>[O:17]=[C:16]-1-[C:18]=[C:7](-[O:8]-[c:15]2[c:10][c:11][c:12][c:13][c:14]-12)-[c:3]1[c:2][c:1][c:6][c:5][c:4]1 -|[C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):1][C$([CH]([C,N])([C,N])([CX4])),C$([CH2]([C,N])([C,N])):2][C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):3].[C$(C(C)(#N)),C$(C(C)([CX4,OX2&H0])([CX4,OX2&H0])([OX2&H0])),C$([CH](C)([CX4,OX2&H0])([OX2&H0])),C$([CH2](C)([OX2&H0])),C$(C(C)(=O)([OX2&H0])):6][CH:5]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):4]>>[C:6][C:5][C:4][C:2]([C:1])[C:3] -|[C$([C](O)([CX4])([CX4])([CX4])),C$([CH](O)([CX4])([CX4])),C$([CH2](O)([CX4])):4]-[O:3]-[C$(C(=O)([CX4])),C$([CH](=O)):2]=[O:5].[C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):7]-[C$(C(=O)([CX4])),C$([CH](=O)):8]=[O:9]>>[C:7](-[C:2]=[O:5])-[C:8]=[O:9] -|[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[O$([OH]([CX4,c])),O$([OH]([CX4,c])([CX4,c])),S$([SH]([CX4,c])),S$([SH]([CX4,c])([CX4,c])):6]>>[*:6]-[C:2]=[O:4] -|[C$(C(=O)([CX4,c])([CX4,c])),C$([CH](=O)([CX4,c])):1]=[O:2].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):3]>>[N+0:3][C:1] -|[Br:1][c$(c(Br)),n$(n(Br)),o$(o(Br)),C$([CH](Br)(=C)):2].[C$(C(B)([CX4])([CX4])([CX4])),C$([CH](B)([CX4])([CX4])),C$([CH2](B)([CX4])),C$([CH2](B)),C$(C(B)(=C)),c$(c(B)),o$(o(B)),n$(n(B)):3][B$(B([C,c,n,o])([OH,$(OC)])([OH,$(OC)])),B$([B-1]([C,c,n,o])(N)([OH,$(OC)])([OH,$(OC)])):4]>>[C,c,n,o:2][C,c,n,o:3] -|[Br,I:1][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):2].[Br,I:3][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):4]>>[C,c:2][C,c:4] -|[OH,O-]-[C$(C(=O)(O)([CX4,c])):2]=[O:3].[OH:8]-[C$([CH](O)([CX4,c])([CX4,c])),C$([CH2](O)([CX4,c])),C$([CH3](O)):6]>>[C:6][O]-[C:2]=[O:3] -|[C$([CH](=C)([CX4])),C$([CH2](=C)):2]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):3].[Br,I:7][C$([CX4]([Br,I])),c$([c]([Br,I])):4]>>[C,c:4][C:2]=[C:3] -|[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):6]>>[N+0:6]-[C:2]=[O:4] -|[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):1]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):2].[SH:4]-[CX4:5][Br,Cl,I]>>[C:1]-[C:2]-[S:4][C:5] -|[C$([C](=O)([CX4])),C$([CH](=O)):2](=[O:1])[OH,Cl,O-:6].[SH:4]-[CX4:5][Br,Cl,I]>>[CH2:2]-[S:4][C:5] -|[I:1][C$(C(I)([CX4,c])([CX4,c])([CX4,c])),C$([CH](I)([CX4,c])([CX4,c])),C$([CH2](I)([CX4,c])),C$([CH3](I)):2].[C$(C(=O)([Cl,OH,O-])([CX4,c])),C$([CH]([Cl,OH,O-])(=O)):3](=[O:6])[Cl,OH,O-:5]>>[C:2]-[C:3]=[O:6] -|[Cl:5][S$(S(=O)(=O)(Cl)([CX4])):2](=[O:3])=[O:4].[NH2+0,NH3+:6]-[C$(C(N)([CX4,c])([CX4,c])([CX4,c])),C$([CH](N)([CX4,c])([CX4,c])),C$([CH2](N)([CX4,c])),C$([CH3](N)),c$(c(N)):7]>>[C,c:7]-[NH+0:6][S:2](=[O:4])=[O:3] -|[*:1][C:2]#[CH:3].[Br,I:4][C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):5]>>[C,c:5][C:3]#[C:2][*:1] -|[C$(C(C)([CX4])([CX4])([CX4])),C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):1][C:2]#[CH:3].[Br,I:4][C$(C(=O)([Br,I])([CX4])),C$([CH](=O)([Br,I])):5]=[O:6]>>[C:1][C:2]#[C:3][C:5]=[O:6] -|[OH,O-:4]-[C$(C(=O)([OH,O-])([CX4])),C$([CH](=O)([OH,O-])):2]=[O:3]>>[Cl:5][C:2]=[O:3] -|[OH:2]-[$([CX4]),c:1]>>[Br:3][C,c:1] -|[OH:2]-[$([CX4]),c:1]>>[Cl:3][C,c:1] -|[OH,O-:3][S$(S([CX4])):2](=[O:4])=[O:5]>>[Cl:6][S:2](=[O:5])=[O:4] -|[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Br:6]) -|[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Cl:6]) -|[Cl,I,Br:7][c:1]1[c:2][c:3][c:4][c:5][c:6]1>>[N:9]#[C:8][c:1]1[c:2][c:3][c:4][c:5][c:6]1 -|[OH,NH2,NH3+:3]-[CH2:2]-[C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):1]>>[C,c:1][C:2]#[N:4] From 1dca7581380dad7d83c699b663ec4857f8c86785 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:32:50 -0400 Subject: [PATCH 051/302] update paths in config --- src/syn_net/config.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 2273b0af..7f563286 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -1,18 +1,24 @@ """Central place for all configuration, paths, and parameter.""" -DATA_DIR = "database" -ASSETS_DIR = "database/assets" +DATA_DIR = "data" +ASSETS_DIR = "data/assets" # BUILDING_BLOCKS_RAW_DIR = f"{ASSETS_DIR}/building-blocks" REACTION_TEMPLATE_DIR = f"{ASSETS_DIR}/reaction-templates" # Pre-processed data -DATA_PREPROCESS_DIR = "database/pre-process" -DATA_EMBEDDINGS_DIR = "database/pre-process/embeddings" +DATA_PREPROCESS_DIR = "data/pre-process" +DATA_EMBEDDINGS_DIR = "data/pre-process/embeddings" # Prepared data -DATA_PREPARED_DIR = "database/prepared" +DATA_PREPARED_DIR = "data/prepared" # Prepared data -DATA_FEATURIZED_DIR = "database/featurized" \ No newline at end of file +DATA_FEATURIZED_DIR = "data/featurized" + +# Results +DATA_RESULT_DIR = "results" + +# Checkpoints (& pre-trained weights) +CHECKPOINTS_DIR = "checkpoints" # \ No newline at end of file From 6958c5909a45023b26758142e56922c94828399f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:43:25 -0400 Subject: [PATCH 052/302] add sample targets (smiles) for inference --- data/assets/molecules/sample-targets.txt | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 data/assets/molecules/sample-targets.txt diff --git a/data/assets/molecules/sample-targets.txt b/data/assets/molecules/sample-targets.txt new file mode 100644 index 00000000..4d4aa219 --- /dev/null +++ b/data/assets/molecules/sample-targets.txt @@ -0,0 +1,10 @@ +COc1cc(Cn2c(C)c(Cc3ccccc3)c3c2CCCC3)ccc1OCC(=O)N(C)C +CCC1CCCC(Nc2cc(C(F)(F)F)c(Cl)cc2SC)CC1 +Clc1cc(Cl)c(C2=NC(c3cccc4c(Br)cccc34)=NN2)nn1 +COc1ccc(S(=O)(=O)c2ccc(-c3nc(-c4cc(B(O)O)ccc4O)no3)cn2)cc1 +CNS(=O)(=O)c1ccc(-c2cc3c4c(ccc3[nH]2)CCCN4C(N)=O)cc1 +CC(NC(=O)C1Cn2c(O)nnc2CN1)c1cc(F)ccc1N1CCC(n2nnn(-c3ccc(Br)cc3)c2=S)CC1 +COc1cc(-c2nc(-c3ccccc3)c(-c3ccccc3)s2)ccn1 +CCCn1c(C)nnc1CC(C)(O)C(=C(C)C)c1nccnc1S(=O)(=O)F +CN(c1ccccc1)c1ccc(-c2nc3ncccc3s2)cn1 +COc1cc(-c2nc(-c3ccc(F)cc3)c(-c3ccc(F)cc3)n2c2cc(Cl)ccc2Cl)ccc1Oc1ccc(S(=O)(=O)N2CCCCC2)cc1[N+](=O)[O-] From f6f2afe39c1f3228631a7655191ab06f790a0a45 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 24 Aug 2022 11:45:23 -0400 Subject: [PATCH 053/302] avoid absolute paths --- scripts/predict_mp.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/scripts/predict_mp.py b/scripts/predict_mp.py index 17e3284c..649c78a5 100644 --- a/scripts/predict_mp.py +++ b/scripts/predict_mp.py @@ -6,7 +6,9 @@ import pandas as pd import scripts._mp_predict as predict from syn_net.utils.data_utils import SyntheticTreeSet +from pathlib import Path +from syn_net.config import DATA_RESULT_DIR if __name__ == '__main__': @@ -47,12 +49,15 @@ print(f"Average similarity {args.data}: {np.mean(np.array(similaritys))}") print('Saving ......') - save_path = '../results/' + args.rxn_template + '_' + args.featurize + '/' + out_dir = Path(DATA_RESULT_DIR) / f"{args.rxn_template}_{args.featurize}" + out_dir.mkdir(exist_ok=1,parent=1) df = pd.DataFrame({'query SMILES': smis_query, 'decode SMILES': smis_decoded, 'similarity': similaritys}) - df.to_csv(save_path + 'decode_result_' + args.data + '.csv.gz', compression='gzip', index=False) + file = out_dir / f'decode_result_{args.data}.csv.gz' + df.to_csv(file, compression='gzip', index=False) synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(save_path + 'decoded_st_' + args.data + '.json.gz') + file = out_dir / f'decoded_st_{args.data}.json.gz' + synthetic_tree_set.save(file) print('Finish!') From 4744b5c0f3bd1b813ec90deb18a68e01cb86fa4f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 25 Aug 2022 11:07:30 -0400 Subject: [PATCH 054/302] wip refactor: no abs path & use dicts instead of ifs --- src/syn_net/models/act.py | 86 +++++++++---------- src/syn_net/models/rt1.py | 82 +++++++++---------- src/syn_net/models/rt2.py | 115 ++++++++++---------------- src/syn_net/models/rxn.py | 168 ++++++++++++++++---------------------- 4 files changed, 187 insertions(+), 264 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index b2cd6c9f..2a899a7f 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -7,8 +7,8 @@ from pytorch_lightning import loggers as pl_loggers from syn_net.models.mlp import MLP, load_array from scipy import sparse - - +from syn_net.config import DATA_FEATURIZED_DIR +from pathlib import Path if __name__ == '__main__': import argparse @@ -21,9 +21,9 @@ help="Radius for Morgan fingerprint.") parser.add_argument("--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, + parser.add_argument("--out_dim", type=int, default=256, help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, + parser.add_argument("--ncpu", type=int, default=16, help="Number of cpus") parser.add_argument("--batch_size", type=int, default=64, help="Batch size") @@ -31,62 +31,54 @@ help="Maximum number of epoches.") args = parser.parse_args() - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError + # Helper to select validation func based on output dim + VALIDATION_OPTS = { + 300: "nn_accuracy_gin", + 4096: "nn_accuracy_fp_4096", + 256: "nn_accuracy_fp_256", + 200: "nn_accuracy_rdkit2d", + } + validation_option = VALIDATION_OPTS[args.out_dim] - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + main_dir = Path(DATA_FEATURIZED_DIR) / id batch_size = args.batch_size ncpu = args.ncpu - X = sparse.load_npz(main_dir + 'X_act_train.npz') - y = sparse.load_npz(main_dir + 'y_act_train.npz') + X = sparse.load_npz(main_dir / 'X_act_train.npz') + y = sparse.load_npz(main_dir / 'y_act_train.npz') X = torch.Tensor(X.A) y = torch.LongTensor(y.A.reshape(-1, )) train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - X = sparse.load_npz(main_dir + 'X_act_valid.npz') - y = sparse.load_npz(main_dir + 'y_act_valid.npz') + X = sparse.load_npz(main_dir / 'X_act_valid.npz') + y = sparse.load_npz(main_dir / 'y_act_valid.npz') X = torch.Tensor(X.A) y = torch.LongTensor(y.A.reshape(-1, )) valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) pl.seed_everything(0) - if args.featurize == 'fp': - mlp = MLP(input_dim=int(3 * args.nbits), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - mlp = MLP(input_dim=int(2 * args.nbits + args.out_dim), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) + INPUT_DIMS = { + "fp": int(3 * args.nbits), + "gin" : int(2 * args.nbits + args.out_dim) + } # somewhat constant... + + input_dims = INPUT_DIMS[args.featurize] + + mlp = MLP(input_dim=input_dims, + output_dim=4, + hidden_dim=1000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task='classification', + loss='cross_entropy', + valid_loss='accuracy', + optimizer='adam', + learning_rate=1e-4, + val_freq=10, + ncpu=ncpu) + tb_logger = pl_loggers.TensorBoardLogger(f'act_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_logs/') trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 1fe8026c..d622ce74 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -8,6 +8,8 @@ from pytorch_lightning import loggers as pl_loggers from syn_net.models.mlp import MLP, load_array from scipy import sparse +from syn_net.config import DATA_FEATURIZED_DIR +from pathlib import Path if __name__ == '__main__': @@ -24,7 +26,7 @@ help="Number of Bits for Morgan fingerprint.") parser.add_argument("--out_dim", type=int, default=256, help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, + parser.add_argument("--ncpu", type=int, default=16, help="Number of cpus") parser.add_argument("--batch_size", type=int, default=64, help="Batch size") @@ -32,63 +34,55 @@ help="Maximum number of epoches.") args = parser.parse_args() - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError + # Helper to select validation func based on output dim + VALIDATION_OPTS = { + 300: "nn_accuracy_gin", + 4096: "nn_accuracy_fp_4096", + 256: "nn_accuracy_fp_256", + 200: "nn_accuracy_rdkit2d", + } + validation_option = VALIDATION_OPTS[args.out_dim] - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + main_dir = Path(DATA_FEATURIZED_DIR) / id batch_size = args.batch_size ncpu = args.ncpu - X = sparse.load_npz(main_dir + 'X_rt1_train.npz') - y = sparse.load_npz(main_dir + 'y_rt1_train.npz') + X = sparse.load_npz(main_dir / 'X_rt1_train.npz') + y = sparse.load_npz(main_dir / 'y_rt1_train.npz') X = torch.Tensor(X.A) y = torch.Tensor(y.A) train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - X = sparse.load_npz(main_dir + 'X_rt1_valid.npz') - y = sparse.load_npz(main_dir + 'y_rt1_valid.npz') + X = sparse.load_npz(main_dir / 'X_rt1_valid.npz') + y = sparse.load_npz(main_dir / 'y_rt1_valid.npz') X = torch.Tensor(X.A) y = torch.Tensor(y.A) _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) pl.seed_everything(0) - if args.featurize == 'fp': - mlp = MLP(input_dim=int(3 * args.nbits), - output_dim=args.out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - mlp = MLP(input_dim=int(2 * args.nbits + args.out_dim), - output_dim=args.out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) + INPUT_DIMS = { + "fp": int(3 * args.nbits), + "gin" : int(2 * args.nbits + args.out_dim) + } # somewhat constant... + + input_dims = INPUT_DIMS[args.featurize] + + mlp = MLP(input_dim=input_dims, + output_dim=args.out_dim, + hidden_dim=1200, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task='regression', + loss='mse', + valid_loss=validation_option, + optimizer='adam', + learning_rate=1e-4, + val_freq=10, + ncpu=ncpu) + tb_logger = pl_loggers.TensorBoardLogger( f'rt1_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}_logs/' ) diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 40ca5237..19991e82 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -8,7 +8,8 @@ from pytorch_lightning import loggers as pl_loggers from syn_net.models.mlp import MLP, load_array from scipy import sparse - +from syn_net.config import DATA_FEATURIZED_DIR +from pathlib import Path if __name__ == '__main__': @@ -32,94 +33,60 @@ help="Maximum number of epoches.") args = parser.parse_args() - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError + # Helper to select validation func based on output dim + VALIDATION_OPTS = { + 300: "nn_accuracy_gin", + 4096: "nn_accuracy_fp_4096", + 256: "nn_accuracy_fp_256", + 200: "nn_accuracy_rdkit2d", + } + validation_option = VALIDATION_OPTS[args.out_dim] - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + main_dir = Path(DATA_FEATURIZED_DIR) / id batch_size = args.batch_size ncpu = args.ncpu - X = sparse.load_npz(main_dir + 'X_rt2_train.npz') - y = sparse.load_npz(main_dir + 'y_rt2_train.npz') + X = sparse.load_npz(main_dir / 'X_rt2_train.npz') + y = sparse.load_npz(main_dir / 'y_rt2_train.npz') X = torch.Tensor(X.A) y = torch.Tensor(y.A) train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - X = sparse.load_npz(main_dir + 'X_rt2_valid.npz') - y = sparse.load_npz(main_dir + 'y_rt2_valid.npz') + X = sparse.load_npz(main_dir / 'X_rt2_valid.npz') + y = sparse.load_npz(main_dir / 'y_rt2_valid.npz') X = torch.Tensor(X.A) y = torch.Tensor(y.A) _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) pl.seed_everything(0) - if args.featurize == 'fp': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(4 * args.nbits + 91), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(4 * args.nbits + 4700), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim + 91), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim + 4700), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) + INPUT_DIMS = { + "fp": { + "hb": int(4 * args.nbits + 91), + "gin": int(4 * args.nbits + 4700), + }, + "gin" : { + "hb": int(3 * args.nbits + args.out_dim + 91), + "gin": int(3 * args.nbits + args.out_dim + 4700), + } + } # somewhat constant... + input_dims = INPUT_DIMS[args.featurize][args.rxn_template] + + mlp = MLP(input_dim=input_dims, + output_dim=args.out_dim, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task='regression', + loss='mse', + valid_loss=validation_option, + optimizer='adam', + learning_rate=1e-4, + val_freq=10, + ncpu=ncpu) tb_logger = pl_loggers.TensorBoardLogger( f'rt2_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}_logs/' diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 69f7ce87..b659810e 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -7,7 +7,8 @@ from pytorch_lightning import loggers as pl_loggers from syn_net.models.mlp import MLP, load_array from scipy import sparse - +from syn_net.config import DATA_FEATURIZED_DIR, CHECKPOINTS_DIR +from pathlib import Path if __name__ == '__main__': @@ -35,119 +36,88 @@ help="Version") args = parser.parse_args() - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + # Helper to select validation func based on output dim + VALIDATION_OPTS = { + 300: "nn_accuracy_gin", + 4096: "nn_accuracy_fp_4096", + 256: "nn_accuracy_fp_256", + 200: "nn_accuracy_rdkit2d", + } + validation_option = VALIDATION_OPTS[args.out_dim] + + id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + main_dir = Path(DATA_FEATURIZED_DIR) / id batch_size = args.batch_size ncpu = args.ncpu - X = sparse.load_npz(main_dir + 'X_rxn_train.npz') - y = sparse.load_npz(main_dir + 'y_rxn_train.npz') + X = sparse.load_npz(main_dir / 'X_rxn_train.npz') + y = sparse.load_npz(main_dir / 'y_rxn_train.npz') X = torch.Tensor(X.A) y = torch.LongTensor(y.A.reshape(-1, )) train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - X = sparse.load_npz(main_dir + 'X_rxn_valid.npz') - y = sparse.load_npz(main_dir + 'y_rxn_valid.npz') + X = sparse.load_npz(main_dir / 'X_rxn_valid.npz') + y = sparse.load_npz(main_dir / 'y_rxn_valid.npz') X = torch.Tensor(X.A) y = torch.LongTensor(y.A.reshape(-1, )) valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) pl.seed_everything(0) - param_path = f'/pool001/rociomer/data/pre-trained-models/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/' + param_path = Path(CHECKPOINTS_DIR) / f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/" path_to_rxn = f'{param_path}rxn.ckpt' - if not args.restart: - if args.featurize == 'fp': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(4 * args.nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(4 * args.nbits), - output_dim=4700, - hidden_dim=4500, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim), - output_dim=4700, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - else: - if args.rxn_template == 'hb': - mlp = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(4 * args.nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu - ) - elif args.rxn_template == 'pis': - mlp = MLP.load_from_checkpoint( + INPUT_DIMS = { + "fp": { + "hb": int(4 * args.nbits), + "gin": int(4 * args.nbits), + }, + "gin" : { + "hb": int(3 * args.nbits + args.out_dim), + "gin": int(3 * args.nbits + args.out_dim), + } + } # somewhat constant... + input_dim = INPUT_DIMS[args.featurize][args.rxn_template] + + HIDDEN_DIMS = { + "fp": { + "hb": 3000, + "gin": 4500, + }, + "gin" : { + "hb": 3000, + "gin": 3000, + } + } + hidden_dim = HIDDEN_DIMS[args.featurize][args.rxn_template] + + OUTPUT_DIMS = { + "hb": 91, + "gin": 4700, + } + output_dim = OUTPUT_DIMS[args.rxn_template] + + + if not args.restart: + mlp = MLP(input_dim=input_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task='classification', + loss='cross_entropy', + valid_loss='accuracy', + optimizer='adam', + learning_rate=1e-4, + val_freq=10, + ncpu=ncpu, + ) + else: # load from checkpt -> only for fp, not gin + mlp = MLP.load_from_checkpoint( path_to_rxn, - input_dim=int(4 * args.nbits), - output_dim=4700, - hidden_dim=4500, + input_dim=input_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, num_layers=5, dropout=0.5, num_dropout_layers=1, From a62673c0dc03c1a6288992ece19c661cdc2416cd Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 10:52:21 -0400 Subject: [PATCH 055/302] refactor: move cli args to single function --- src/syn_net/models/act.py | 22 +++------------------- src/syn_net/models/common.py | 27 +++++++++++++++++++++++++++ src/syn_net/models/mlp.py | 16 +++++++++------- src/syn_net/models/rt1.py | 22 +++------------------- src/syn_net/models/rt2.py | 23 +++-------------------- src/syn_net/models/rxn.py | 26 ++------------------------ 6 files changed, 47 insertions(+), 89 deletions(-) create mode 100644 src/syn_net/models/common.py diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 2a899a7f..ee40fa61 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -9,27 +9,11 @@ from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR from pathlib import Path +from syn_net.models.common import get_args + if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - args = parser.parse_args() + args = get_args() # Helper to select validation func based on output dim VALIDATION_OPTS = { diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py new file mode 100644 index 00000000..affb28ca --- /dev/null +++ b/src/syn_net/models/common.py @@ -0,0 +1,27 @@ +"""Common methods and params shared by all models. +""" + +def get_args(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("-f", "--featurize", type=str, default='fp', + help="Choose from ['fp', 'gin']") + parser.add_argument("-r", "--rxn_template", type=str, default='hb', + help="Choose from ['hb', 'pis']") + parser.add_argument("--radius", type=int, default=2, + help="Radius for Morgan fingerprint.") + parser.add_argument("--nbits", type=int, default=4096, + help="Number of Bits for Morgan fingerprint.") + parser.add_argument("--out_dim", type=int, default=256, + help="Output dimension.") + parser.add_argument("--ncpu", type=int, default=16, + help="Number of cpus") + parser.add_argument("--batch_size", type=int, default=64, + help="Batch size") + parser.add_argument("--epoch", type=int, default=2000, + help="Maximum number of epoches.") + parser.add_argument("--restart", type=bool, default=False, + help="Indicates whether to restart training.") + parser.add_argument("-v", "--version", type=int, default=1, + help="Version") + return parser.parse_args() diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 2ad4d52e..048cc55c 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -1,16 +1,18 @@ """ Multi-layer perceptron (MLP) class. """ +import logging import time + +import numpy as np +import pytorch_lightning as pl import torch -from torch import nn import torch.nn.functional as F -import pytorch_lightning as pl from pytorch_lightning import loggers as pl_loggers from sklearn.neighbors import BallTree -import numpy as np - +from torch import nn +logger = logging.getLogger(__name__) class MLP(pl.LightningModule): def __init__(self, input_dim=3072, @@ -73,7 +75,7 @@ def training_step(self, batch, batch_idx): return loss def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: - """Helper function to load the pre-computed building block embeddings + """Helper function to load the pre-computed building block embeddings as a BallTree. TODO: Remove hard-coded paths. @@ -96,8 +98,8 @@ def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: emb = np.load("tests/data/building_blocks_emb.npy") kdtree = BallTree(emb,metric="euclidean") else: - raise ValueError - return kdtree + raise ValueError + return kdtree def validation_step(self, batch, batch_idx): if self.trainer.current_epoch % self.val_freq == 0: diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index d622ce74..c1de7e6a 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -9,30 +9,14 @@ from syn_net.models.mlp import MLP, load_array from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR +from syn_net.models.common import get_args + from pathlib import Path if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - args = parser.parse_args() + args = get_args() # Helper to select validation func based on output dim VALIDATION_OPTS = { diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 19991e82..eef33d16 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -9,30 +9,13 @@ from syn_net.models.mlp import MLP, load_array from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR +from syn_net.models.common import get_args + from pathlib import Path if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - args = parser.parse_args() - + args = get_args() # Helper to select validation func based on output dim VALIDATION_OPTS = { 300: "nn_accuracy_gin", diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index b659810e..242b0621 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -8,34 +8,12 @@ from syn_net.models.mlp import MLP, load_array from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR, CHECKPOINTS_DIR +from syn_net.models.common import get_args from pathlib import Path if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epochs.") - parser.add_argument("--restart", type=bool, default=False, - help="Indicates whether to restart training.") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") - args = parser.parse_args() - + args = get_args() # Helper to select validation func based on output dim VALIDATION_OPTS = { 300: "nn_accuracy_gin", From 770cc60f3da3f988b4d616c6275faa381ff7a999 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 10:52:49 -0400 Subject: [PATCH 056/302] isort --- src/syn_net/models/act.py | 8 +++++--- src/syn_net/models/prepare_data.py | 8 +++++--- src/syn_net/models/rt1.py | 10 +++++----- src/syn_net/models/rt2.py | 9 +++++---- src/syn_net/models/rxn.py | 10 ++++++---- 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index ee40fa61..975963aa 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -2,14 +2,16 @@ Action network. """ import time -import torch +from pathlib import Path + import pytorch_lightning as pl +import torch from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array from scipy import sparse + from syn_net.config import DATA_FEATURIZED_DIR -from pathlib import Path from syn_net.models.common import get_args +from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': diff --git a/src/syn_net/models/prepare_data.py b/src/syn_net/models/prepare_data.py index 0e0644a6..cd5a1cf5 100644 --- a/src/syn_net/models/prepare_data.py +++ b/src/syn_net/models/prepare_data.py @@ -3,10 +3,12 @@ and steps for the reaction data and re-writing it as separate one-hot encoded Action, Reactant 1, Reactant 2, and Reaction files. """ -from syn_net.utils.prep_utils import prep_data -from syn_net.config import DATA_FEATURIZED_DIR -from pathlib import Path import logging +from pathlib import Path + +from syn_net.config import DATA_FEATURIZED_DIR +from syn_net.utils.prep_utils import prep_data + logger = logging.getLogger(__file__) if __name__ == '__main__': diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index c1de7e6a..5c736e8d 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -2,17 +2,17 @@ Reactant1 network (for predicting 1st reactant). """ import time +from pathlib import Path + import numpy as np -import torch import pytorch_lightning as pl +import torch from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array from scipy import sparse + from syn_net.config import DATA_FEATURIZED_DIR from syn_net.models.common import get_args - -from pathlib import Path - +from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index eef33d16..6aaab176 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -2,16 +2,17 @@ Reactant2 network (for predicting 2nd reactant). """ import time +from pathlib import Path + import numpy as np -import torch import pytorch_lightning as pl +import torch from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array from scipy import sparse + from syn_net.config import DATA_FEATURIZED_DIR from syn_net.models.common import get_args - -from pathlib import Path +from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 242b0621..5a5c64c2 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -2,14 +2,16 @@ Reaction network. """ import time -import torch +from pathlib import Path + import pytorch_lightning as pl +import torch from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array from scipy import sparse -from syn_net.config import DATA_FEATURIZED_DIR, CHECKPOINTS_DIR + +from syn_net.config import CHECKPOINTS_DIR, DATA_FEATURIZED_DIR from syn_net.models.common import get_args -from pathlib import Path +from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': From 34ec985635c4cd7aa324d24f9dca1d8b176bacc0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 16:09:47 -0400 Subject: [PATCH 057/302] avoid warning: speficy dimension for softmax --- src/syn_net/models/act.py | 10 ++-------- src/syn_net/models/common.py | 10 ++++++++++ src/syn_net/models/mlp.py | 2 +- src/syn_net/models/rt1.py | 9 +-------- src/syn_net/models/rt2.py | 9 ++------- src/syn_net/models/rxn.py | 10 ++-------- 6 files changed, 18 insertions(+), 32 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 975963aa..74996ea0 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -10,20 +10,14 @@ from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR -from syn_net.models.common import get_args +from syn_net.models.common import VALIDATION_OPTS, get_args from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': args = get_args() - # Helper to select validation func based on output dim - VALIDATION_OPTS = { - 300: "nn_accuracy_gin", - 4096: "nn_accuracy_fp_4096", - 256: "nn_accuracy_fp_256", - 200: "nn_accuracy_rdkit2d", - } + validation_option = VALIDATION_OPTS[args.out_dim] id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index affb28ca..14f5c6e4 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -1,6 +1,14 @@ """Common methods and params shared by all models. """ +# Helper to select validation func based on output dim +VALIDATION_OPTS = { + 300: "nn_accuracy_gin", + 4096: "nn_accuracy_fp_4096", + 256: "nn_accuracy_fp_256", + 200: "nn_accuracy_rdkit2d", +} + def get_args(): import argparse parser = argparse.ArgumentParser() @@ -25,3 +33,5 @@ def get_args(): parser.add_argument("-v", "--version", type=int, default=1, help="Version") return parser.parse_args() + + diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 048cc55c..1cc537ee 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -51,7 +51,7 @@ def __init__(self, input_dim=3072, modules.append(nn.Linear(hidden_dim, output_dim)) if task == 'classification': - modules.append(nn.Softmax()) + modules.append(nn.Softmax(dim=1)) self.layers = nn.Sequential(*modules) diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 5c736e8d..a9cf274d 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -11,20 +11,13 @@ from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR -from syn_net.models.common import get_args +from syn_net.models.common import VALIDATION_OPTS, get_args from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': args = get_args() - # Helper to select validation func based on output dim - VALIDATION_OPTS = { - 300: "nn_accuracy_gin", - 4096: "nn_accuracy_fp_4096", - 256: "nn_accuracy_fp_256", - 200: "nn_accuracy_rdkit2d", - } validation_option = VALIDATION_OPTS[args.out_dim] id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 6aaab176..d2a9c21e 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -11,19 +11,14 @@ from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR -from syn_net.models.common import get_args +from syn_net.models.common import VALIDATION_OPTS, get_args from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': args = get_args() # Helper to select validation func based on output dim - VALIDATION_OPTS = { - 300: "nn_accuracy_gin", - 4096: "nn_accuracy_fp_4096", - 256: "nn_accuracy_fp_256", - 200: "nn_accuracy_rdkit2d", - } + validation_option = VALIDATION_OPTS[args.out_dim] id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 5a5c64c2..737d960b 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -10,19 +10,13 @@ from scipy import sparse from syn_net.config import CHECKPOINTS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import get_args +from syn_net.models.common import VALIDATION_OPTS, get_args from syn_net.models.mlp import MLP, load_array if __name__ == '__main__': args = get_args() - # Helper to select validation func based on output dim - VALIDATION_OPTS = { - 300: "nn_accuracy_gin", - 4096: "nn_accuracy_fp_4096", - 256: "nn_accuracy_fp_256", - 200: "nn_accuracy_rdkit2d", - } + validation_option = VALIDATION_OPTS[args.out_dim] id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' From 14dbabd365307c6c72335ad4342dfad0ba297241 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 16:49:37 -0400 Subject: [PATCH 058/302] decrease verbosity of unittest --- tests/test_Training.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/tests/test_Training.py b/tests/test_Training.py index 2db0b90b..200b8468 100644 --- a/tests/test_Training.py +++ b/tests/test_Training.py @@ -4,7 +4,7 @@ from pathlib import Path import unittest import shutil - +from multiprocessing import cpu_count import pytorch_lightning as pl from pytorch_lightning import loggers as pl_loggers from scipy import sparse @@ -18,7 +18,7 @@ REACTION_TEMPLATES_FILE = f"{TEST_DIR}/assets/rxn_set_hb_test.txt" class TestReactionTemplateFile(unittest.TestCase): - + def test_number_of_reaction_templates(self): """ Count number of lines in file, i.e. the number of reaction templates.""" with open(REACTION_TEMPLATES_FILE,"r") as f: @@ -32,6 +32,11 @@ class TestTraining(unittest.TestCase): reaction network, (4) reactant 2 network. """ + def setUp(self) -> None: + import warnings + warnings.filterwarnings("ignore", ".*does not have many workers.*") + warnings.filterwarnings("ignore", ".*GPU available but not used.*") + def test_action_network(self): """ Tests the Action Network. @@ -41,11 +46,11 @@ def test_action_network(self): nbits = 4096 batch_size = 10 epochs = 2 - ncpu = 2 + ncpu = min(2,cpu_count()) validation_option = "accuracy" ref_dir = f"{TEST_DIR}/data/ref/" - X = sparse.load_npz(ref_dir + "X_act_train.npz") + X = sparse.load_npz(ref_dir + "X_act_train.npz") assert X.shape==(4,3*nbits) # (4,12288) y = sparse.load_npz(ref_dir + "y_act_train.npz") assert y.shape==(4,1) # (4,1) @@ -81,7 +86,7 @@ def test_action_network(self): f"act_{embedding}_{radius}_{nbits}_logs/" ) trainer = pl.Trainer( - max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger + max_epochs=epochs, logger=tb_logger, weights_summary=None, ) trainer.fit(mlp, train_data_iter, valid_data_iter) @@ -101,11 +106,11 @@ def test_reactant1_network(self): out_dim = 300 # Note: out_dim 300 = gin embedding batch_size = 10 epochs = 2 - ncpu = 2 + ncpu = min(2,cpu_count()) validation_option = "nn_accuracy_gin_unittest" ref_dir = f"{TEST_DIR}/data/ref/" - # load the reaction data + # load the reaction data X = sparse.load_npz(ref_dir + "X_rt1_train.npz") assert X.shape==(2,3*nbits) # (4,12288) X = torch.Tensor(X.A) @@ -138,7 +143,7 @@ def test_reactant1_network(self): f"rt1_{embedding}_{radius}_{nbits}_logs/" ) trainer = pl.Trainer( - max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger + max_epochs=epochs, logger=tb_logger, weights_summary=None, ) trainer.fit(mlp, train_data_iter, valid_data_iter) @@ -157,7 +162,7 @@ def test_reaction_network(self): nbits = 4096 batch_size = 10 epochs = 2 - ncpu = 2 + ncpu = min(2,cpu_count()) n_templates = 3 # num templates in `REACTION_TEMPLATES_FILE` validation_option = "accuracy" ref_dir = f"{TEST_DIR}/data/ref/" @@ -198,7 +203,7 @@ def test_reaction_network(self): f"rxn_{embedding}_{radius}_{nbits}_logs/" ) trainer = pl.Trainer( - max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger + max_epochs=epochs, logger=tb_logger, weights_summary=None, ) trainer.fit(mlp, train_data_iter, valid_data_iter) @@ -218,7 +223,7 @@ def test_reactant2_network(self): out_dim = 300 # Note: out_dim 300 = gin embedding batch_size = 10 epochs = 2 - ncpu = 2 + ncpu = min(2,cpu_count()) n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt' validation_option = "nn_accuracy_gin_unittest" ref_dir = f"{TEST_DIR}/data/ref/" @@ -255,7 +260,7 @@ def test_reactant2_network(self): f"rt2_{embedding}_{radius}_{nbits}_logs/" ) trainer = pl.Trainer( - max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger + max_epochs=epochs, logger=tb_logger, weights_summary=None, ) trainer.fit(mlp, train_data_iter, valid_data_iter) From 9d08956a5ba9f6fe74bca4680e046adfd62296cc Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 17:21:17 -0400 Subject: [PATCH 059/302] log hyperparams to file --- src/syn_net/models/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 1cc537ee..f3dd8a84 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -29,7 +29,7 @@ def __init__(self, input_dim=3072, val_freq=10, ncpu=16): super().__init__() - + self.save_hyperparameters() self.loss = loss self.valid_loss = valid_loss self.optimizer = optimizer From d5e9bfdf89e7af931af1e9e027f7339bdbf3f6b0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 17:21:38 -0400 Subject: [PATCH 060/302] correct name of train loss metric --- src/syn_net/models/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index f3dd8a84..96ebc039 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -71,7 +71,7 @@ def training_step(self, batch, batch_idx): loss = F.huber_loss(y_hat, y) else: raise ValueError('Not specified loss function') - self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + self.log(f'train_{self.loss}', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: From 2e2d1abfcd373cc0fa4c0b1862fd3df83e24eb9f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 26 Aug 2022 17:25:29 -0400 Subject: [PATCH 061/302] remove hard coded path for embedding --- src/syn_net/models/mlp.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 96ebc039..f55ea387 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -78,19 +78,20 @@ def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: """Helper function to load the pre-computed building block embeddings as a BallTree. - TODO: Remove hard-coded paths. """ + from syn_net.config import DATA_EMBEDDINGS_DIR + from pathlib import Path if out_feat == 'gin': - bb_emb_gin = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_gin.npy') + bb_emb_gin = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') kdtree = BallTree(bb_emb_gin, metric='euclidean') elif out_feat == 'fp_4096': - bb_emb_fp_4096 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_4096.npy') + bb_emb_fp_4096 = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') kdtree = BallTree(bb_emb_fp_4096, metric='euclidean') elif out_feat == 'fp_256': - bb_emb_fp_256 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') + bb_emb_fp_256 = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') kdtree = BallTree(bb_emb_fp_256, metric=cosine_distance) elif out_feat == 'rdkit2d': - bb_emb_rdkit2d = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_rdkit2d.npy') + bb_emb_rdkit2d = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') kdtree = BallTree(bb_emb_rdkit2d, metric='euclidean') elif out_feat == "gin_unittest": # The embeddings are pre-computed based on the building blocks From 7926a4cff59a9dba431bb57a9ddc11e74538b346 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 29 Aug 2022 11:03:39 -0400 Subject: [PATCH 062/302] add code for visualisation with mermaid.js --- src/syn_net/visualize/drawers.py | 48 ++++++++++++ src/syn_net/visualize/visualizer.py | 113 ++++++++++++++++++++++++++++ src/syn_net/visualize/writers.py | 104 +++++++++++++++++++++++++ 3 files changed, 265 insertions(+) create mode 100644 src/syn_net/visualize/drawers.py create mode 100644 src/syn_net/visualize/visualizer.py create mode 100644 src/syn_net/visualize/writers.py diff --git a/src/syn_net/visualize/drawers.py b/src/syn_net/visualize/drawers.py new file mode 100644 index 00000000..c8e04acb --- /dev/null +++ b/src/syn_net/visualize/drawers.py @@ -0,0 +1,48 @@ +import uuid +from pathlib import Path +from typing import Union + +import rdkit.Chem as Chem +from rdkit.Chem import Draw + + +class MolDrawer: + def __init__(self): + self.lookup: dict = None + self.path: Union[None, str] = None + + def _hash(self, smiles: list[str]) -> dict[str, str]: + """Hashing for amateurs. + Goal: Get a short, valid, and hopefully unique filename for each molecule.""" + self.lookup = {smile: str(uuid.uuid4())[:8] for smile in smiles} + return self + + def get_path(self) -> str: + return self.path + + def get_molecule_filesnames(self): + return self.lookup + + def plot(self, smiles: Union[list[str], str], path: str = "./"): + """Plot smiles as 2d molecules and save to `path`.""" + self._hash(smiles) + self.path = path + + for k, v in self.lookup.items(): + fname = str((Path(path) / f"{v}.svg").resolve()) + mol = Chem.MolFromSmiles(k) + # Plot + drawer = Draw.rdMolDraw2D.MolDraw2DSVG(300, 150) + opts = drawer.drawOptions() + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + p = drawer.GetDrawingText() + + with open(fname, "w") as f: + f.write(p) + + return self + + +if __name__ == "__main__": + pass diff --git a/src/syn_net/visualize/visualizer.py b/src/syn_net/visualize/visualizer.py new file mode 100644 index 00000000..e81fa656 --- /dev/null +++ b/src/syn_net/visualize/visualizer.py @@ -0,0 +1,113 @@ +from typing import Union + +from syn_net.utils.data_utils import NodeChemical, NodeRxn, SyntheticTree +from syn_net.visualize.writers import subgraph + + +class SynTreeVisualizer: + actions_taken: dict[int, str] + CHEMICALS: dict[str, NodeChemical] + + ACTIONS = { + 0: "Add", + 1: "Expand", + 2: "Merge", + 3: "End", + } + + def __init__(self, syntree: SyntheticTree): + self.syntree = syntree + self.actions_taken = { + depth: self.ACTIONS[action] for depth, action in enumerate(syntree.actions) + } + self.CHEMICALS = {node.smiles: node for node in syntree.chemicals} + + # Placeholder for images for molecues. + self.path: Union[None, str] = None + self.molecule_filesnames: Union[None, dict[str, str]] = None + return None + + def with_drawings(self, drawer): + """Plot images of the molecules in the nodes.""" + self.path = drawer.get_path() + self.molecule_filesnames = drawer.get_molecule_filesnames() + + return self + + def _define_chemicals( + self, + chemicals: dict[str, NodeChemical] = None, + ) -> list[str]: + chemicals = self.CHEMICALS if chemicals is None else chemicals + + if self.path is None or self.molecule_filesnames is None: + raise NotImplementedError("Must provide drawer via `_with_drawings()` before plotting.") + + out: list[str] = [] + + for node in chemicals.values(): + name = f'"node.smiles"' + name = f'' + classdef = self._map_node_type_to_classdef(node) + info = f"n{node.index}[{name}]:::{classdef}" + out += [info] + return out + + def _map_node_type_to_classdef(self, node: NodeChemical) -> str: + """Map a node to pre-defined mermaid class for styling.""" + if node.is_leaf: + classdef = "buildingblock" + elif node.is_root: + classdef = "final" + else: + classdef = "intermediate" + return classdef + + def _write_reaction_connectivity( + self, reactants: list[NodeChemical], product: NodeChemical + ) -> list[str]: + """Write the connectivity of the graph. + Unimolecular reactions have one edge, bimolecular two. + + Examples: + n1 --> n3 + n2 --> n3 + """ + NODE_PREFIX = "n" + r1, r2 = reactants + out = [f"{NODE_PREFIX}{r1.index} --> {NODE_PREFIX}{product.index}"] + if r2 is not None: + out += [f"{NODE_PREFIX}{r2.index} --> {NODE_PREFIX}{product.index}"] + return out + + def write(self) -> list[str]: + """Write.""" + rxns: list[NodeRxn] = self.syntree.reactions + text = [] + + # Add node definitions + text.extend(self._define_chemicals(self.CHEMICALS)) + + # Add paragraphs (<=> actions taken) + for i, action in self.actions_taken.items(): + if action == "End": + continue + rxn = rxns[i] + product: str = rxn.parent + reactant1: str = rxn.child[0] + reactant2: str = rxn.child[1] if rxn.rtype == 2 else None + + @subgraph(f'"{i:>2d} : {action}"') + def __printer(): + return self._write_reaction_connectivity( + [self.CHEMICALS.get(reactant1), self.CHEMICALS.get(reactant2)], + self.CHEMICALS.get(product), + ) + + out = __printer() + text.extend(out) + return text + + +if __name__ == "__main__": + pass diff --git a/src/syn_net/visualize/writers.py b/src/syn_net/visualize/writers.py new file mode 100644 index 00000000..260cec59 --- /dev/null +++ b/src/syn_net/visualize/writers.py @@ -0,0 +1,104 @@ +from functools import wraps +from typing import Callable + + +class PrefixWriter: + def __init__(self, file: str = None): + self.prefix = self._default_prefix() if file is None else self._load(file) + + def _default_prefix(self): + md = [ + "# Synthetic Tree Visualisation", + "", + "Legend", + "- :green_square: Building Block", + "- :orange_square: Intermediate", + "- :blue_square: Final Molecule", + "- :red_square: Target Molecule", + "", + ] + start = ["```mermaid"] + theming = [ + "%%{init: {", + " 'theme': 'base',", + " 'themeVariables': {", + " 'backgroud': '#ffffff',", + " 'primaryColor': '#ffffff',", + " 'clusterBkg': '#ffffff',", + " 'clusterBorder': '#000000',", + " 'edgeLabelBackground':'#dbe1e1',", + " 'fontSize': '20px'", + " }", + " }", + "}%%", + ] + diagram_id = ["graph BT"] + style = [ + "classDef buildingblock stroke:#00d26a,stroke-width:2px", + "classDef intermediate stroke:#ff6723,stroke-width:2px", + "classDef final stroke:#0074ba,stroke-width:2px", + "classDef target stroke:#f8312f,stroke-width:2px", + ] + return md + start + theming + diagram_id + style + + def _load(self, file): + with open(file, "rt") as f: + out = [l.removesuffix("\n") for l in f] + return out + + def write(self) -> list[str]: + return self.prefix + + +class PostfixWriter: + def write(self) -> list[str]: + return ["```"] + + +class SynTreeWriter: + def __init__(self, prefixer=None, postfixer=None): + self.prefixer = prefixer + self.postfixer = postfixer + self._text: list[str] = None + + def write(self, out) -> list[str]: + out = self.prefixer.write() + out + self.postfixer.write() + self._text = out + return self + + def to_file(self, file: str, text: list[str] = None): + if text is None: + text = self._text + + with open(file, "wt") as f: + f.writelines((l.rstrip() + "\n" for l in text)) + + @property + def text(self) -> list[str]: + return self.text + + +def subgraph(argument: str = "") -> Callable: + """Decorator that writes a named mermaid subparagraph. + + Example output: + ``` + subparagraph argument + + end + ``` + """ + + def _subgraph(func) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs) -> list[str]: + out = f"subgraph {argument}" + inner = func(*args, **kwargs) + # add a tab to inner + TAB_CHAR = " " * 4 + inner = [f"{TAB_CHAR}{line}" for line in inner] + return [out] + inner + ["end"] + + return wrapper + + return _subgraph From 4ef8f07d9ad4ce7932f7d7a9ff304700aebf6d84 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 09:34:42 -0400 Subject: [PATCH 063/302] adds `--debug` flag to default args for models --- src/syn_net/models/common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 14f5c6e4..391955b4 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -32,6 +32,12 @@ def get_args(): help="Indicates whether to restart training.") parser.add_argument("-v", "--version", type=int, default=1, help="Version") + parser.add_argument("--debug", default=False, action="store_true") return parser.parse_args() +if __name__=="__main__": + import json + args = get_args() + print("Default Arguments are:") + print(json.dumps(args.__dict__,indent=2)) From 7c60b9282e4e848eea3c0185097a7a34e8dbab24 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 10:47:26 -0400 Subject: [PATCH 064/302] avoid re-casting to ndarray every iteration --- src/syn_net/utils/prep_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 04a2b0aa..6d22cf42 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -172,8 +172,8 @@ def synthetic_tree_generator( # Initialization tree = SyntheticTree() mol_recent = None + building_blocks = np.asarray(building_blocks) - # Start iteration try: for i in range(max_step): # Encode current state @@ -191,7 +191,7 @@ def synthetic_tree_generator( break elif action == 0: # Add - mol1 = np.random.choice(building_blocks) # TODO: convert to nparray to avoid costly conversion upon each function call + mol1 = np.random.choice(building_blocks) else: # Expand or Merge mol1 = mol_recent From adbe9575e3e8f2de6095fa96a70bdfcf3de99fb4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 10:48:06 -0400 Subject: [PATCH 065/302] add comments --- src/syn_net/utils/prep_utils.py | 61 +++++++++++++++------------------ 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 6d22cf42..10fa6cda 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -9,7 +9,7 @@ from sklearn.preprocessing import OneHotEncoder from syn_net.utils.data_utils import Reaction, SyntheticTree from syn_net.utils.predict_utils import (can_react, get_action_mask, - get_reaction_mask, mol_fp, + get_reaction_mask, mol_fp, get_mol_embedding) from pathlib import Path from rdkit import Chem @@ -43,10 +43,10 @@ def _fetch_gin_pretrained_model(model_name: str): return model -def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, +def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, output_embedding='gin'): """ - Organizes the states and steps from the input synthetic tree into sparse + Organizes the states and steps from the input synthetic tree into sparse matrices. Args: @@ -113,26 +113,26 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, if output_embedding == 'gin': step = ([action] + get_mol_embedding(mol1, model=model).tolist() - + [r.rxn_id] - + get_mol_embedding(mol2, model=model).tolist() + + [r.rxn_id] + + get_mol_embedding(mol2, model=model).tolist() + mol_fp(mol1, radius, nBits).tolist()) elif output_embedding == 'fp_4096': - step = ([action] - + mol_fp(mol1, 2, 4096).tolist() - + [r.rxn_id] - + mol_fp(mol2, 2, 4096).tolist() + step = ([action] + + mol_fp(mol1, 2, 4096).tolist() + + [r.rxn_id] + + mol_fp(mol2, 2, 4096).tolist() + mol_fp(mol1, radius, nBits).tolist()) elif output_embedding == 'fp_256': - step = ([action] + step = ([action] + mol_fp(mol1, 2, 256).tolist() + [r.rxn_id] + mol_fp(mol2, 2, 256).tolist() + mol_fp(mol1, radius, nBits).tolist()) elif output_embedding == 'rdkit2d': - step = ([action] - + rdkit2d_embedding(mol1).tolist() - + [r.rxn_id] - + rdkit2d_embedding(mol2).tolist() + step = ([action] + + rdkit2d_embedding(mol1).tolist() + + [r.rxn_id] + + rdkit2d_embedding(mol2).tolist() + mol_fp(mol1, radius, nBits).tolist()) if action == 2: @@ -177,7 +177,7 @@ def synthetic_tree_generator( try: for i in range(max_step): # Encode current state - state = tree.get_state() # a set + state = tree.get_state() # Predict action type, masked selection # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) @@ -186,14 +186,11 @@ def synthetic_tree_generator( action = np.argmax(action_proba * action_mask) # Select first molecule - if action == 3: - # End + if action == 3: # End break - elif action == 0: - # Add + elif action == 0: # Add mol1 = np.random.choice(building_blocks) - else: - # Expand or Merge + else: # Expand or Merge mol1 = mol_recent # Select reaction @@ -216,14 +213,12 @@ def synthetic_tree_generator( rxn_id = np.argmax(reaction_proba * rxn_mask) rxn = reaction_templates[rxn_id] + # Select second molecule if rxn.num_reactant == 2: - # Select second molecule - if action == 2: - # Merge + if action == 2: # Merge temp = set(state) - set([mol1]) mol2 = temp.pop() - else: - # Add or Expand + else: # Add or Expand mol2 = np.random.choice(available[rxn_id]) else: mol2 = None @@ -267,7 +262,7 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): print(f'Reading {dataset} data ...') states_list = [] steps_list = [] - + states_list.append(sparse.load_npz(main_dir / f'states_{dataset}.npz')) steps_list.append(sparse.load_npz(main_dir / f'steps_{dataset}.npz')) @@ -283,7 +278,7 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) print(f' saved data for "Action"') - + # extract Reaction data X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) y = steps[:, out_dim + 1] @@ -300,8 +295,8 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): # extract Reactant 2 data X = sparse.hstack( - [states, - steps[:, (2 * out_dim + 2):], + [states, + steps[:, (2 * out_dim + 2):], sparse.csc_matrix(enc.transform(steps[:, out_dim+1].A.reshape((-1, 1))).toarray())] ) y = steps[:, (out_dim+2): (2 * out_dim + 2)] @@ -329,11 +324,11 @@ def __init__(self) -> None: def from_sdf(self, file: Union[str, Path]): """Extract chemicals as SMILES from `*.sdf` file. - - See also: + + See also: https://www.rdkit.org/docs/GettingStartedInPython.html#reading-sets-of-molecules """ - file = str(Path(file).resolve()) + file = str(Path(file).resolve()) suppl = Chem.SDMolSupplier(file) self.smiles = (Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False) for mol in suppl) logger.info(f"Read data from {file}") From ad3cc0db58dcd763e0c84e12ec6bd05e1b2356c7 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 11:33:09 -0400 Subject: [PATCH 066/302] refactor computation of embeddings --- scripts/compute_embedding.py | 60 ---------------------- scripts/compute_embedding_mp.py | 53 ++++++++++---------- src/syn_net/MolEmbedder.py | 88 +++++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 87 deletions(-) delete mode 100644 scripts/compute_embedding.py create mode 100644 src/syn_net/MolEmbedder.py diff --git a/scripts/compute_embedding.py b/scripts/compute_embedding.py deleted file mode 100644 index 3def3ffb..00000000 --- a/scripts/compute_embedding.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -This file contains functions for generating molecular embeddings from SMILES using GIN. -""" -import pandas as pd -import numpy as np -from tqdm import tqdm -from syn_net.utils.predict_utils import mol_embedding, fp_embedding, rdkit2d_embedding - - -def get_mol_embedding_func(feature): - """ - Returns the molecular embedding function. - - Args: - feature (str): Indicates the type of featurization to use (GIN or Morgan - fingerprint), and the size. - - Returns: - Callable: The embedding function. - """ - if feature == 'gin': - embedding_func = lambda smi: mol_embedding(smi, device='cpu') - elif feature == 'fp_4096': - embedding_func = lambda smi: fp_embedding(smi, _nBits=4096) - elif feature == 'fp_2048': - embedding_func = lambda smi: fp_embedding(smi, _nBits=2048) - elif feature == 'fp_1024': - embedding_func = lambda smi: fp_embedding(smi, _nBits=1024) - elif feature == 'fp_512': - embedding_func = lambda smi: fp_embedding(smi, _nBits=512) - elif feature == 'fp_256': - embedding_func = lambda smi: fp_embedding(smi, _nBits=256) - elif feature == 'rdkit2d': - embedding_func = rdkit2d_embedding - return embedding_func - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--feature", type=str, default="gin", - help="Objective function to optimize") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - args = parser.parse_args() - - path = '/pool001/whgao/data/synth_net/st_hb/' - ## path = './tests/data/' ## for debugging - data = pd.read_csv(path + 'enamine_us_matched.csv.gz', compression='gzip')['SMILES'].tolist() - ## data = pd.read_csv(path + 'building_blocks_matched.csv.gz', compression='gzip')['SMILES'].tolist() ## for debugging - print('Total data: ', len(data)) - - embeddings = [] - for smi in tqdm(data): - embeddings.append(mol_embedding(smi)) - - embedding = np.array(embeddings) - np.save(path + 'enamine_us_emb_' + args.feature + '.npy', embeddings) - - print('Finish!') diff --git a/scripts/compute_embedding_mp.py b/scripts/compute_embedding_mp.py index 59a35eca..5e7d33ee 100644 --- a/scripts/compute_embedding_mp.py +++ b/scripts/compute_embedding_mp.py @@ -1,16 +1,15 @@ """ Computes the molecular embeddings of the purchasable building blocks. -The embeddings are also referred to as "output embedding". +The embeddings are also referred to as "output embedding". In the embedding space, a kNN-search will identify the 1st or 2nd reactant. """ import logging -import multiprocessing as mp from pathlib import Path -import numpy as np import pandas as pd +from syn_net.MolEmbedder import MolEmbedder from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_PREPROCESS_DIR from syn_net.utils.predict_utils import fp_256, fp_512, fp_1024, fp_2048, fp_4096, mol_embedding, rdkit2d_embedding @@ -30,37 +29,37 @@ def _load_building_blocks(file: Path) -> list[str]: return pd.read_csv(file)["SMILES"].to_list() -def _save_embedding(file: str, embeddings: list[list[float]]): - embeddings = np.array(embeddings) - - np.save(file, embeddings) - logger.info(f"Successfully saved to {file}.") - -if __name__ == "__main__": - +def get_args(): import argparse - parser = argparse.ArgumentParser() + parser.add_argument("--building-blocks-file", type=str, help="Input file with SMILES strings (First row `SMILES`, then one per line).") + parser.add_argument("--output-file", type=str, help="Output file for the computed embeddings.") parser.add_argument("--feature", type=str, default="fp_256", choices=FUNCTIONS.keys(), help="Objective function to optimize") - parser.add_argument("--ncpu", type=int, default=64, help="Number of cpus") - parser.add_argument("-rxn", "--rxn-template", type=str, default="hb", choices=["hb", "pis"], help="Choose from ['hb', 'pis']") - parser.add_argument("--input", type=str, help="Input file with SMILES strings (One per line).") - args = parser.parse_args() + parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") + # Command line args to be deprecated, only support input/output file in future. + parser.add_argument("--rxn-template", type=str, default="hb", choices=["hb", "pis"], help="Choose from ['hb', 'pis']") + parser.add_argument("--building-blocks-id", type=str, default="enamine_us-2021-smiles") + return parser.parse_args() + +if __name__ == "__main__": - reaction_template_id = args.rxn_template - building_blocks_id = "enamine_us-2021-smiles" + args = get_args() # Load building blocks - file = Path(DATA_PREPROCESS_DIR) / f"{reaction_template_id}-{building_blocks_id}-matched.csv.gz" - data = _load_building_blocks(file) + if (file := args.building_blocks_file) is None: + # Try to construct filename + file = Path(DATA_PREPROCESS_DIR) / f"{args.rxn_template}-{args.building_blocks_id}-matched.csv.gz" + bblocks = _load_building_blocks(file) logger.info(f"Successfully read {file}.") - logger.info(f"Total number of building blocks: {len(data)}.") + logger.info(f"Total number of building blocks: {len(bblocks)}.") + # Compute embeddings func = FUNCTIONS[args.feature] - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(func, data) + molembedder = MolEmbedder(processes=args.ncpu).compute_embeddings(func,bblocks) + + # Save? + if (outfile := args.output_file) is None: + # Try to construct filename + outfile = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{args.building_blocks_id}-{args.feature}.npy" + molembedder.save_precomputed(outfile) - path = Path(DATA_EMBEDDINGS_DIR) - path.mkdir(exist_ok=1, parents=1) - outfile = path / f"{reaction_template_id}-{building_blocks_id}-embeddings.npy" - _save_embedding(file,embeddings) diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py new file mode 100644 index 00000000..84496a30 --- /dev/null +++ b/src/syn_net/MolEmbedder.py @@ -0,0 +1,88 @@ +import logging +from pathlib import Path +from typing import Callable, Union + +import numpy as np +from sklearn.neighbors import BallTree + +logger = logging.getLogger(__name__) + + +class MolEmbedder: + def __init__(self, processes: int = 1) -> None: + self.processes = processes + self.func: Callable + self.building_blocks: Union[list[str], np.ndarray] + self.embeddings: np.ndarray + self.kdtree: BallTree + self.kdtree_metric: str + + def get_embeddings(self) -> np.ndarray: + """Returns `self.embeddings` as 2d-array.""" + return np.atleast_2d(self.embeddings) + + def _compute_mp(self, data): + from pathos import multiprocessing as mp + + with mp.Pool(processes=self.processes) as pool: + embeddings = pool.map(self.func, data) + return embeddings + + def compute_embeddings(self, func: Callable, building_blocks: list[str]): + logger.info(f"Will compute embedding with {self.processes} processes.") + logger.info(f"Embedding function: {func.__name__}") + self.func = func + if self.processes == 1: + embeddings = list(map(self.func, building_blocks)) + else: + embeddings = self._compute_mp(building_blocks) + logger.info(f"Computed embeddings.") + self.embeddings = embeddings + return self + + def _save_npy(self, file: str): + if self.embeddings is None: + raise ValueError("Must have computed embeddings to save.") + + embeddings = np.asarray(self.embeddings) # assume at least 2d + np.save(file, embeddings) + logger.info(f"Successfully saved to {file}.") + return self + + def save_precomputed(self, file: str): + """Saves pre-computed molecule embeddings to `*.npy`""" + file = Path(file) + file.parent.mkdir(parents=True, exist_ok=True) + if file.suffixes == [".npy"]: + self._save_npy(file) + else: + raise NotImplementedError(f"File have 'npy' extension, not {file.suffixes}") + return self + + def _load_npy(self, file: Path): + return np.load(file) + + def load_precomputed(self, file: str): + """Loads a pre-computed molecule embeddings from `*.npy`""" + file = Path(file) + if file.suffixes == [".npy"]: + self.embeddings = self._load_npy(file) + self.kdtree = None + else: + raise NotImplementedError + return self + + def init_balltree(self, metric: Union[Callable, str]): + """Initializes a `BallTree`. + + Note: + Can take a couple of minutes.""" + if self.embeddings is None: + raise ValueError("Neeed emebddings to compute kdtree.") + X = self.embeddings + self.kdtree_metric = metric.__name__ + self.kdtree = BallTree(X, metric=metric) + + return self + + From c07681d4c8981b64752168b2d8bba717fa9e2337 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 11:53:26 -0400 Subject: [PATCH 067/302] make `SyntheticTreeSet` iterarable --- src/syn_net/utils/data_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 8467eb55..84bc3ec1 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -768,6 +768,10 @@ def __init__(self, sts=None): def __len__(self): return len(self.sts) + def __getitem__(self,index): + if self.sts is None: raise IndexError("No Synthetic Trees.") + return self.sts[index] + def load(self, json_file): """ A function that loads a JSON-formatted synthetic tree file. @@ -799,9 +803,6 @@ def save(self, json_file): with gzip.open(json_file, 'w') as f: f.write(json.dumps(st_list).encode('utf-8')) - def __len__(self): - return len(self.sts) - def _print(self, x=3): # For debugging for i, r in enumerate(self.sts): From 1bba0aa213502b0f81103dd2e02a9f1658ec0044 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:04:50 -0400 Subject: [PATCH 068/302] refactor: inject `MolEmbedder` into `MLP` --- src/syn_net/models/mlp.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index f55ea387..662b0387 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -12,6 +12,8 @@ from sklearn.neighbors import BallTree from torch import nn +from syn_net.MolEmbedder import MolEmbedder + logger = logging.getLogger(__name__) class MLP(pl.LightningModule): @@ -27,15 +29,18 @@ def __init__(self, input_dim=3072, optimizer='adam', learning_rate=1e-4, val_freq=10, - ncpu=16): + ncpu=16, + molembedder: MolEmbedder = None, + ): super().__init__() - self.save_hyperparameters() + self.save_hyperparameters(ignore="molembedder") self.loss = loss self.valid_loss = valid_loss self.optimizer = optimizer self.learning_rate = learning_rate self.ncpu = ncpu self.val_freq = val_freq + self.molembedder = molembedder modules = [] modules.append(nn.Linear(input_dim, hidden_dim)) @@ -112,12 +117,17 @@ def validation_step(self, batch, batch_idx): y_hat = torch.argmax(y_hat, axis=1) loss = 1 - (sum(y_hat == y) / len(y)) elif self.valid_loss[:11] == 'nn_accuracy': + # NOTE: Very slow! + # Performing the knn-search can easily take a couple of minutes, + # even for small datasets. out_feat = self.valid_loss[12:] + if self.molembedder is None: # legacy kdtree = self._load_building_blocks_kdtree(out_feat) - y = nn_search_list(y.detach().cpu().numpy(), out_feat=out_feat, kdtree=kdtree) - y_hat = nn_search_list(y_hat.detach().cpu().numpy(), out_feat=out_feat, kdtree=kdtree) + else: + kdtree = self.molembedder.kdtree + y = nn_search_list(y.detach().cpu().numpy(), None, kdtree) + y_hat = nn_search_list(y_hat.detach().cpu().numpy(), None, kdtree) loss = 1 - (sum(y_hat == y) / len(y)) - # import ipdb; ipdb.set_trace(context=11) elif self.valid_loss == 'mse': loss = F.mse_loss(y_hat, y) elif self.valid_loss == 'l1': From 6afcf491482bd6071836500d3d358038cfaf2b1b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:05:48 -0400 Subject: [PATCH 069/302] vectorize `mlp.nn_search_list` --- src/syn_net/models/mlp.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 662b0387..54c6b3e6 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -154,12 +154,11 @@ def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): def cosine_distance(v1, v2, eps=1e-15): return 1 - np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps) -def nn_search(_e, _tree, _k=1): - dist, ind = _tree.query(_e, k=_k) - return ind[0][0] def nn_search_list(y, out_feat, kdtree): - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) + y = np.atleast_2d(y) # (n_samples, n_features) + ind = kdtree.query(y,k=1,return_distance=False) # (n_samples, 1) + return ind if __name__ == '__main__': From 5370c42da79ca5c39917fbed9be6527b15fce0cd Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:06:29 -0400 Subject: [PATCH 070/302] - revert name chage for loss - comments --- src/syn_net/models/mlp.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 54c6b3e6..20ad82f1 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -75,8 +75,8 @@ def training_step(self, batch, batch_idx): elif self.loss == 'huber': loss = F.huber_loss(y_hat, y) else: - raise ValueError('Not specified loss function') - self.log(f'train_{self.loss}', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + raise ValueError('Not specified loss function: % s' % self.loss) + self.log(f'train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: @@ -115,19 +115,22 @@ def validation_step(self, batch, batch_idx): loss = F.cross_entropy(y_hat, y) elif self.valid_loss == 'accuracy': y_hat = torch.argmax(y_hat, axis=1) - loss = 1 - (sum(y_hat == y) / len(y)) + accuracy = (y_hat==y).sum()/len(y) + loss = 1 - accuracy elif self.valid_loss[:11] == 'nn_accuracy': # NOTE: Very slow! # Performing the knn-search can easily take a couple of minutes, # even for small datasets. out_feat = self.valid_loss[12:] if self.molembedder is None: # legacy - kdtree = self._load_building_blocks_kdtree(out_feat) + kdtree = self._load_building_blocks_kdtree(out_feat) else: kdtree = self.molembedder.kdtree y = nn_search_list(y.detach().cpu().numpy(), None, kdtree) y_hat = nn_search_list(y_hat.detach().cpu().numpy(), None, kdtree) loss = 1 - (sum(y_hat == y) / len(y)) + accuracy = (y_hat==y).sum()/len(y) + loss = 1 - accuracy elif self.valid_loss == 'mse': loss = F.mse_loss(y_hat, y) elif self.valid_loss == 'l1': @@ -154,6 +157,9 @@ def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): def cosine_distance(v1, v2, eps=1e-15): return 1 - np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps) +def nn_search(_e, _tree, _k=1): + dist, ind = _tree.query(_e, k=_k) + return ind[0][0] def nn_search_list(y, out_feat, kdtree): y = np.atleast_2d(y) # (n_samples, n_features) From d5f8131cd2888d261d52552dd5a5985ff214a481 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:10:40 -0400 Subject: [PATCH 071/302] delete unused method --- src/syn_net/models/mlp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 20ad82f1..06f68f2a 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -157,9 +157,6 @@ def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): def cosine_distance(v1, v2, eps=1e-15): return 1 - np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps) -def nn_search(_e, _tree, _k=1): - dist, ind = _tree.query(_e, k=_k) - return ind[0][0] def nn_search_list(y, out_feat, kdtree): y = np.atleast_2d(y) # (n_samples, n_features) From f615430372b330422655888ae7ff05fe866b9026 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:21:53 -0400 Subject: [PATCH 072/302] add EarlyStopping, refine ModelCheckpoint, match code --- src/syn_net/models/act.py | 37 +++++++++++++++++++++------ src/syn_net/models/rt1.py | 53 ++++++++++++++++++++++++++++++-------- src/syn_net/models/rt2.py | 54 +++++++++++++++++++++++++++++++-------- src/syn_net/models/rxn.py | 42 +++++++++++++++++++++++------- 4 files changed, 148 insertions(+), 38 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 74996ea0..a2f76bb1 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -1,23 +1,27 @@ """ Action network. """ -import time +import logging from pathlib import Path import pytorch_lightning as pl import torch from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args from syn_net.models.mlp import MLP, load_array +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + if __name__ == '__main__': args = get_args() - validation_option = VALIDATION_OPTS[args.out_dim] id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' @@ -59,11 +63,28 @@ val_freq=10, ncpu=ncpu) + # Set up Trainer + save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") + save_dir.mkdir(exist_ok=True,parents=True) - tb_logger = pl_loggers.TensorBoardLogger(f'act_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_logs/') - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, valid_data_iter) - print(time.time() - t, 's') + tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") - print('Finish!') + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath= tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, + ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + + max_epochs = args.epoch if not args.debug else 2 + # Create trainer + trainer = pl.Trainer(gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate = int(len(train_data_iter)*0.05), + callbacks=[checkpoint_callback], + logger=[tb_logger]) + + logger.info(f"Start training") + trainer.fit(mlp, train_data_iter, valid_data_iter) + logger.info(f"Training completed.") diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index a9cf274d..7913a14a 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -1,25 +1,39 @@ """ Reactant1 network (for predicting 1st reactant). """ -import time +import logging from pathlib import Path import numpy as np import pytorch_lightning as pl import torch from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from scipy import sparse -from syn_net.config import DATA_FEATURIZED_DIR +from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args -from syn_net.models.mlp import MLP, load_array +from syn_net.models.mlp import MLP, cosine_distance, load_array +from syn_net.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + if __name__ == '__main__': args = get_args() + logger.info(f"Start.") validation_option = VALIDATION_OPTS[args.out_dim] + knn_embedding_id = validation_option[12:] + file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" + logger.info(f"Try to load precomputed MolEmbedder from {file}.") + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + logger.info(f"Loaded MolEmbedder from {file}.") + id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' main_dir = Path(DATA_FEATURIZED_DIR) / id batch_size = args.batch_size @@ -35,8 +49,10 @@ y = sparse.load_npz(main_dir / 'y_rt1_valid.npz') X = torch.Tensor(X.A) y = torch.Tensor(y.A) + # Select random 10% of the valid data because "nn_accuracy" is very(!) slow _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) + logger.info(f"Set up dataloaders.") pl.seed_everything(0) INPUT_DIMS = { @@ -54,18 +70,35 @@ num_dropout_layers=1, task='regression', loss='mse', - valid_loss=validation_option, + valid_loss="mse", optimizer='adam', learning_rate=1e-4, val_freq=10, + molembedder=molembedder, ncpu=ncpu) - tb_logger = pl_loggers.TensorBoardLogger( - f'rt1_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}_logs/' + # Set up Trainer + save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") + save_dir.mkdir(exist_ok=True,parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath= tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + + max_epochs = args.epoch if not args.debug else 2 + # Create trainer + trainer = pl.Trainer(gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate = int(len(train_data_iter)*0.05), + callbacks=[checkpoint_callback], + logger=[tb_logger]) - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() + logger.info(f"Start training") trainer.fit(mlp, train_data_iter, valid_data_iter) - print(time.time() - t, 's') - print('Finish!') + logger.info(f"Training completed.") diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index d2a9c21e..d09b2183 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -1,32 +1,46 @@ """ Reactant2 network (for predicting 2nd reactant). """ -import time +import logging from pathlib import Path import numpy as np import pytorch_lightning as pl import torch from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from scipy import sparse -from syn_net.config import DATA_FEATURIZED_DIR +from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args -from syn_net.models.mlp import MLP, load_array +from syn_net.models.mlp import MLP, cosine_distance, load_array +from syn_net.MolEmbedder import MolEmbedder + + +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem if __name__ == '__main__': args = get_args() - # Helper to select validation func based on output dim + args.debug = True + # Helper to select validation func based on output dim validation_option = VALIDATION_OPTS[args.out_dim] + knn_embedding_id = validation_option[12:] + file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" + logger.info(f"Try to load precomputed MolEmbedder from {file}.") + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' main_dir = Path(DATA_FEATURIZED_DIR) / id batch_size = args.batch_size ncpu = args.ncpu + X = sparse.load_npz(main_dir / 'X_rt2_train.npz') y = sparse.load_npz(main_dir / 'y_rt2_train.npz') X = torch.Tensor(X.A) @@ -37,6 +51,7 @@ y = sparse.load_npz(main_dir / 'y_rt2_valid.npz') X = torch.Tensor(X.A) y = torch.Tensor(y.A) + # Select random 10% of the valid data because "nn_accuracy" is very(!) slow _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) @@ -65,15 +80,32 @@ optimizer='adam', learning_rate=1e-4, val_freq=10, + molembedder=molembedder, ncpu=ncpu) - tb_logger = pl_loggers.TensorBoardLogger( - f'rt2_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}_logs/' + # Set up Trainer + save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") + save_dir.mkdir(exist_ok=True,parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath= tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, valid_data_iter) - print(time.time() - t, 's') + max_epochs = args.epoch if not args.debug else 2 + # Create trainer + trainer = pl.Trainer(gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate = int(len(train_data_iter)*0.05), + callbacks=[checkpoint_callback], + logger=[tb_logger], + fast_dev_run=True) - print('Finish!') + logger.info(f"Start training") + trainer.fit(mlp, train_data_iter, valid_data_iter) + logger.info(f"Training completed.") diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 737d960b..5add5bd2 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -1,21 +1,27 @@ """ Reaction network. """ -import time +import logging from pathlib import Path import pytorch_lightning as pl import torch from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from scipy import sparse from syn_net.config import CHECKPOINTS_DIR, DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args from syn_net.models.mlp import MLP, load_array +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + if __name__ == '__main__': args = get_args() + logger.info(f"Start.") validation_option = VALIDATION_OPTS[args.out_dim] @@ -35,6 +41,7 @@ X = torch.Tensor(X.A) y = torch.LongTensor(y.A.reshape(-1, )) valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) + logger.info(f"Set up dataloaders.") pl.seed_everything(0) param_path = Path(CHECKPOINTS_DIR) / f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/" @@ -103,12 +110,29 @@ ncpu=ncpu ) - tb_logger = pl_loggers.TensorBoardLogger(f'rxn_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_logs/') - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - + # Set up Trainer + # Set up Trainer + save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") + save_dir.mkdir(exist_ok=True,parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath= tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, + ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + + max_epochs = args.epoch if not args.debug else 2 + # Create trainer + trainer = pl.Trainer(gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate = int(len(train_data_iter)*0.05), + callbacks=[checkpoint_callback,earlystop_callback], + logger=[tb_logger]) + + logger.info(f"Start training") trainer.fit(mlp, train_data_iter, valid_data_iter) - - print(time.time() - t, 's') - - print('Finish!') + logger.info(f"Training completed.") From 79c3bd88f6ed5f6423d35233322da8edce3e7541 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:36:57 -0400 Subject: [PATCH 073/302] refactor: continue matching code in MLPs, use common "dataloader" --- src/syn_net/models/act.py | 44 ++++++++++++++----------- src/syn_net/models/common.py | 30 +++++++++++++++++ src/syn_net/models/rt1.py | 63 +++++++++++++++++++----------------- src/syn_net/models/rt2.py | 59 ++++++++++++++++++--------------- src/syn_net/models/rxn.py | 41 ++++++++++++----------- 5 files changed, 146 insertions(+), 91 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index a2f76bb1..b7bc00ce 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -12,7 +12,7 @@ from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args +from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader from syn_net.models.mlp import MLP, load_array logger = logging.getLogger(__name__) @@ -24,22 +24,30 @@ validation_option = VALIDATION_OPTS[args.out_dim] + # Get ID for the data to know what we're working with and find right files. id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - main_dir = Path(DATA_FEATURIZED_DIR) / id - batch_size = args.batch_size - ncpu = args.ncpu - - X = sparse.load_npz(main_dir / 'X_act_train.npz') - y = sparse.load_npz(main_dir / 'y_act_train.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir / 'X_act_valid.npz') - y = sparse.load_npz(main_dir / 'y_act_valid.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) + + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) + logger.info(f"Set up dataloaders.") + pl.seed_everything(0) INPUT_DIMS = { @@ -81,10 +89,10 @@ # Create trainer trainer = pl.Trainer(gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate = int(len(train_data_iter)*0.05), + progress_bar_refresh_rate = int(len(train_dataloader)*0.05), callbacks=[checkpoint_callback], logger=[tb_logger]) logger.info(f"Start training") - trainer.fit(mlp, train_data_iter, valid_data_iter) + trainer.fit(mlp, train_dataloader, valid_dataloader) logger.info(f"Training completed.") diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 391955b4..03462be8 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -2,6 +2,10 @@ """ # Helper to select validation func based on output dim +from typing import Union +from scipy import sparse +import torch + VALIDATION_OPTS = { 300: "nn_accuracy_gin", 4096: "nn_accuracy_fp_4096", @@ -35,6 +39,32 @@ def get_args(): parser.add_argument("--debug", default=False, action="store_true") return parser.parse_args() +def xy_to_dataloader(X_file: str = None,y_file: str = None, n: Union[int,float] = 1.0, **kwargs): + """Loads featurized X,y `*.npz`-data into a `DataLoader`""" + X = sparse.load_npz(X_file) + y = sparse.load_npz(y_file) + # Filer? + if isinstance(n,int): + n = min(n,min(X.shape[0],y.shape[0])) # ensure n does not exceed size of dataset + X = X[:n] + y = y[:n] + elif isinstance(n,float) and n < 1.0: + xn = X.shape[0]*n + yn = X.shape[0]*n + X = X[:xn] + y = y[:yn] + else: + pass # + dataset = torch.utils.data.TensorDataset( + torch.Tensor(X.A), + torch.Tensor(y.A.reshape(-1,)), + ) + return torch.utils.data.DataLoader(dataset,**kwargs) + + + + + if __name__=="__main__": import json args = get_args() diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 7913a14a..1d9a2cf9 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -4,56 +4,61 @@ import logging from pathlib import Path -import numpy as np import pytorch_lightning as pl -import torch from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from scipy import sparse from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args -from syn_net.models.mlp import MLP, cosine_distance, load_array +from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader +from syn_net.models.mlp import MLP, cosine_distance from syn_net.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem +def _fetch_molembedder(): + knn_embedding_id = validation_option[12:] + file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" + logger.info(f"Try to load precomputed MolEmbedder from {file}.") + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + logger.info(f"Loaded MolEmbedder from {file}.") + return molembedder + if __name__ == '__main__': args = get_args() - logger.info(f"Start.") validation_option = VALIDATION_OPTS[args.out_dim] - knn_embedding_id = validation_option[12:] - file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" - logger.info(f"Try to load precomputed MolEmbedder from {file}.") - molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) - logger.info(f"Loaded MolEmbedder from {file}.") - + # Get ID for the data to know what we're working with and find right files. id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - main_dir = Path(DATA_FEATURIZED_DIR) / id - batch_size = args.batch_size - ncpu = args.ncpu - - X = sparse.load_npz(main_dir / 'X_rt1_train.npz') - y = sparse.load_npz(main_dir / 'y_rt1_train.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir / 'X_rt1_valid.npz') - y = sparse.load_npz(main_dir / 'y_rt1_valid.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - # Select random 10% of the valid data because "nn_accuracy" is very(!) slow - _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) + + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) logger.info(f"Set up dataloaders.") + # Fetch Molembedder and init BallTree + molembedder = _fetch_molembedder() + pl.seed_everything(0) INPUT_DIMS = { "fp": int(3 * args.nbits), diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index d09b2183..ddcd4804 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -13,47 +13,54 @@ from scipy import sparse from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args -from syn_net.models.mlp import MLP, cosine_distance, load_array +from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader +from syn_net.models.mlp import MLP, cosine_distance from syn_net.MolEmbedder import MolEmbedder - logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem +def _fetch_molembedder(): + knn_embedding_id = validation_option[12:] + file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" + logger.info(f"Try to load precomputed MolEmbedder from {file}.") + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + logger.info(f"Loaded MolEmbedder from {file}.") + return molembedder + + if __name__ == '__main__': args = get_args() - args.debug = True - # Helper to select validation func based on output dim validation_option = VALIDATION_OPTS[args.out_dim] - knn_embedding_id = validation_option[12:] - file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" - logger.info(f"Try to load precomputed MolEmbedder from {file}.") - molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) - + # Get ID for the data to know what we're working with and find right files. id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - main_dir = Path(DATA_FEATURIZED_DIR) / id - batch_size = args.batch_size - ncpu = args.ncpu - + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) - X = sparse.load_npz(main_dir / 'X_rt2_train.npz') - y = sparse.load_npz(main_dir / 'y_rt2_train.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) + logger.info(f"Set up dataloaders.") - X = sparse.load_npz(main_dir / 'X_rt2_valid.npz') - y = sparse.load_npz(main_dir / 'y_rt2_valid.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - # Select random 10% of the valid data because "nn_accuracy" is very(!) slow - _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) + # Fetch Molembedder and init BallTree + molembedder = _fetch_molembedder() pl.seed_everything(0) INPUT_DIMS = { diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 5add5bd2..cf6f3453 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -12,8 +12,8 @@ from scipy import sparse from syn_net.config import CHECKPOINTS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args -from syn_net.models.mlp import MLP, load_array +from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader +from syn_net.models.mlp import MLP logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem @@ -21,26 +21,31 @@ if __name__ == '__main__': args = get_args() - logger.info(f"Start.") validation_option = VALIDATION_OPTS[args.out_dim] + # Get ID for the data to know what we're working with and find right files. id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - main_dir = Path(DATA_FEATURIZED_DIR) / id - batch_size = args.batch_size - ncpu = args.ncpu - - X = sparse.load_npz(main_dir / 'X_rxn_train.npz') - y = sparse.load_npz(main_dir / 'y_rxn_train.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir / 'X_rxn_valid.npz') - y = sparse.load_npz(main_dir / 'y_rxn_valid.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) + + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 1000, + batch_size = args.batch_size, + num_workers=args.ncpu, + shuffle = True if dataset == "train" else False, + ) logger.info(f"Set up dataloaders.") pl.seed_everything(0) From d9cb5edf1c68a1d29bba4787f7326a2ded8ba4c5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:38:55 -0400 Subject: [PATCH 074/302] black + isort --- src/syn_net/models/act.py | 79 +++++++++-------- src/syn_net/models/common.py | 79 ++++++++--------- src/syn_net/models/mlp.py | 133 ++++++++++++++++------------- src/syn_net/models/prepare_data.py | 44 +++++++--- src/syn_net/models/rt1.py | 83 ++++++++++-------- src/syn_net/models/rt2.py | 87 ++++++++++--------- src/syn_net/models/rxn.py | 127 ++++++++++++++------------- 7 files changed, 355 insertions(+), 277 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index b7bc00ce..546d771d 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -18,68 +18,75 @@ logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() validation_option = VALIDATION_OPTS[args.out_dim] # Get ID for the data to know what we're working with and find right files. - id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = ( + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" + ) dataset = "train" train_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") - pl.seed_everything(0) INPUT_DIMS = { "fp": int(3 * args.nbits), - "gin" : int(2 * args.nbits + args.out_dim) - } # somewhat constant... + "gin": int(2 * args.nbits + args.out_dim), + } # somewhat constant... input_dims = INPUT_DIMS[args.featurize] - mlp = MLP(input_dim=input_dims, - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) + mlp = MLP( + input_dim=input_dims, + output_dim=4, + hidden_dim=1000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + val_freq=10, + ncpu=ncpu, + ) # Set up Trainer - save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") - save_dir.mkdir(exist_ok=True,parents=True) + save_dir = Path( + "results/logs/" + + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + + f"/{MODEL_ID}" + ) + save_dir.mkdir(exist_ok=True, parents=True) - tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") checkpoint_callback = ModelCheckpoint( monitor="val_loss", - dirpath= tb_logger.log_dir, + dirpath=tb_logger.log_dir, filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) @@ -87,11 +94,13 @@ max_epochs = args.epoch if not args.debug else 2 # Create trainer - trainer = pl.Trainer(gpus=[0], - max_epochs=max_epochs, - progress_bar_refresh_rate = int(len(train_dataloader)*0.05), - callbacks=[checkpoint_callback], - logger=[tb_logger]) + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), + callbacks=[checkpoint_callback], + logger=[tb_logger], + ) logger.info(f"Start training") trainer.fit(mlp, train_dataloader, valid_dataloader) diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 03462be8..5ffffcc8 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -3,8 +3,9 @@ # Helper to select validation func based on output dim from typing import Union -from scipy import sparse + import torch +from scipy import sparse VALIDATION_OPTS = { 300: "nn_accuracy_gin", @@ -13,61 +14,63 @@ 200: "nn_accuracy_rdkit2d", } + def get_args(): import argparse + parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - parser.add_argument("--restart", type=bool, default=False, - help="Indicates whether to restart training.") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") + parser.add_argument( + "-f", "--featurize", type=str, default="fp", help="Choose from ['fp', 'gin']" + ) + parser.add_argument( + "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']" + ) + parser.add_argument("--radius", type=int, default=2, help="Radius for Morgan fingerprint.") + parser.add_argument( + "--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint." + ) + parser.add_argument("--out_dim", type=int, default=256, help="Output dimension.") + parser.add_argument("--ncpu", type=int, default=16, help="Number of cpus") + parser.add_argument("--batch_size", type=int, default=64, help="Batch size") + parser.add_argument("--epoch", type=int, default=2000, help="Maximum number of epoches.") + parser.add_argument( + "--restart", type=bool, default=False, help="Indicates whether to restart training." + ) + parser.add_argument("-v", "--version", type=int, default=1, help="Version") parser.add_argument("--debug", default=False, action="store_true") return parser.parse_args() -def xy_to_dataloader(X_file: str = None,y_file: str = None, n: Union[int,float] = 1.0, **kwargs): + +def xy_to_dataloader(X_file: str = None, y_file: str = None, n: Union[int, float] = 1.0, **kwargs): """Loads featurized X,y `*.npz`-data into a `DataLoader`""" X = sparse.load_npz(X_file) y = sparse.load_npz(y_file) # Filer? - if isinstance(n,int): - n = min(n,min(X.shape[0],y.shape[0])) # ensure n does not exceed size of dataset + if isinstance(n, int): + n = min(n, min(X.shape[0], y.shape[0])) # ensure n does not exceed size of dataset X = X[:n] y = y[:n] - elif isinstance(n,float) and n < 1.0: - xn = X.shape[0]*n - yn = X.shape[0]*n + elif isinstance(n, float) and n < 1.0: + xn = X.shape[0] * n + yn = X.shape[0] * n X = X[:xn] y = y[:yn] else: - pass # - dataset = torch.utils.data.TensorDataset( + pass # + dataset = torch.utils.data.TensorDataset( torch.Tensor(X.A), - torch.Tensor(y.A.reshape(-1,)), - ) - return torch.utils.data.DataLoader(dataset,**kwargs) + torch.Tensor( + y.A.reshape( + -1, + ) + ), + ) + return torch.utils.data.DataLoader(dataset, **kwargs) - - - -if __name__=="__main__": +if __name__ == "__main__": import json + args = get_args() print("Default Arguments are:") - print(json.dumps(args.__dict__,indent=2)) - + print(json.dumps(args.__dict__, indent=2)) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 06f68f2a..89f48e31 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -15,23 +15,26 @@ from syn_net.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) -class MLP(pl.LightningModule): - def __init__(self, input_dim=3072, - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=16, - molembedder: MolEmbedder = None, - ): + +class MLP(pl.LightningModule): + def __init__( + self, + input_dim=3072, + output_dim=4, + hidden_dim=1000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + val_freq=10, + ncpu=16, + molembedder: MolEmbedder = None, + ): super().__init__() self.save_hyperparameters(ignore="molembedder") self.loss = loss @@ -47,7 +50,7 @@ def __init__(self, input_dim=3072, modules.append(nn.BatchNorm1d(hidden_dim)) modules.append(nn.ReLU()) - for i in range(num_layers-2): + for i in range(num_layers - 2): modules.append(nn.Linear(hidden_dim, hidden_dim)) modules.append(nn.BatchNorm1d(hidden_dim)) modules.append(nn.ReLU()) @@ -55,7 +58,7 @@ def __init__(self, input_dim=3072, modules.append(nn.Dropout(dropout)) modules.append(nn.Linear(hidden_dim, output_dim)) - if task == 'classification': + if task == "classification": modules.append(nn.Softmax(dim=1)) self.layers = nn.Sequential(*modules) @@ -66,17 +69,17 @@ def forward(self, x): def training_step(self, batch, batch_idx): x, y = batch y_hat = self.layers(x) - if self.loss == 'cross_entropy': + if self.loss == "cross_entropy": loss = F.cross_entropy(y_hat, y) - elif self.loss == 'mse': + elif self.loss == "mse": loss = F.mse_loss(y_hat, y) - elif self.loss == 'l1': + elif self.loss == "l1": loss = F.l1_loss(y_hat, y) - elif self.loss == 'huber': + elif self.loss == "huber": loss = F.huber_loss(y_hat, y) else: - raise ValueError('Not specified loss function: % s' % self.loss) - self.log(f'train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + raise ValueError("Not specified loss function: % s" % self.loss) + self.log(f"train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: @@ -84,25 +87,27 @@ def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: as a BallTree. """ - from syn_net.config import DATA_EMBEDDINGS_DIR from pathlib import Path - if out_feat == 'gin': - bb_emb_gin = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') - kdtree = BallTree(bb_emb_gin, metric='euclidean') - elif out_feat == 'fp_4096': - bb_emb_fp_4096 = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') - kdtree = BallTree(bb_emb_fp_4096, metric='euclidean') - elif out_feat == 'fp_256': - bb_emb_fp_256 = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') + + from syn_net.config import DATA_EMBEDDINGS_DIR + + if out_feat == "gin": + bb_emb_gin = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") + kdtree = BallTree(bb_emb_gin, metric="euclidean") + elif out_feat == "fp_4096": + bb_emb_fp_4096 = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") + kdtree = BallTree(bb_emb_fp_4096, metric="euclidean") + elif out_feat == "fp_256": + bb_emb_fp_256 = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") kdtree = BallTree(bb_emb_fp_256, metric=cosine_distance) - elif out_feat == 'rdkit2d': - bb_emb_rdkit2d = np.load(Path(DATA_EMBEDDINGS_DIR) / f'enamine_us_emb_{out_feat}.npy') - kdtree = BallTree(bb_emb_rdkit2d, metric='euclidean') + elif out_feat == "rdkit2d": + bb_emb_rdkit2d = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") + kdtree = BallTree(bb_emb_rdkit2d, metric="euclidean") elif out_feat == "gin_unittest": # The embeddings are pre-computed based on the building blocks # under 'tests/assets/building_blocks_matched.csv.gz'. emb = np.load("tests/data/building_blocks_emb.npy") - kdtree = BallTree(emb,metric="euclidean") + kdtree = BallTree(emb, metric="euclidean") else: raise ValueError return kdtree @@ -111,66 +116,78 @@ def validation_step(self, batch, batch_idx): if self.trainer.current_epoch % self.val_freq == 0: x, y = batch y_hat = self.layers(x) - if self.valid_loss == 'cross_entropy': + if self.valid_loss == "cross_entropy": loss = F.cross_entropy(y_hat, y) - elif self.valid_loss == 'accuracy': + elif self.valid_loss == "accuracy": y_hat = torch.argmax(y_hat, axis=1) - accuracy = (y_hat==y).sum()/len(y) + accuracy = (y_hat == y).sum() / len(y) loss = 1 - accuracy - elif self.valid_loss[:11] == 'nn_accuracy': + elif self.valid_loss[:11] == "nn_accuracy": # NOTE: Very slow! # Performing the knn-search can easily take a couple of minutes, # even for small datasets. out_feat = self.valid_loss[12:] - if self.molembedder is None: # legacy + if self.molembedder is None: # legacy kdtree = self._load_building_blocks_kdtree(out_feat) else: kdtree = self.molembedder.kdtree - y = nn_search_list(y.detach().cpu().numpy(), None, kdtree) + y = nn_search_list(y.detach().cpu().numpy(), None, kdtree) y_hat = nn_search_list(y_hat.detach().cpu().numpy(), None, kdtree) loss = 1 - (sum(y_hat == y) / len(y)) - accuracy = (y_hat==y).sum()/len(y) + accuracy = (y_hat == y).sum() / len(y) loss = 1 - accuracy - elif self.valid_loss == 'mse': + elif self.valid_loss == "mse": loss = F.mse_loss(y_hat, y) - elif self.valid_loss == 'l1': + elif self.valid_loss == "l1": loss = F.l1_loss(y_hat, y) - elif self.valid_loss == 'huber': + elif self.valid_loss == "huber": loss = F.huber_loss(y_hat, y) else: - raise ValueError('Not specified validation loss function') - self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + raise ValueError("Not specified validation loss function") + self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) else: pass def configure_optimizers(self): - if self.optimizer == 'adam': + if self.optimizer == "adam": optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) - elif self.optimizer == 'sgd': + elif self.optimizer == "sgd": optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate) return optimizer + def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): dataset = torch.utils.data.TensorDataset(*data_arrays) return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train, num_workers=ncpu) + def cosine_distance(v1, v2, eps=1e-15): return 1 - np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps) def nn_search_list(y, out_feat, kdtree): - y = np.atleast_2d(y) # (n_samples, n_features) - ind = kdtree.query(y,k=1,return_distance=False) # (n_samples, 1) + y = np.atleast_2d(y) # (n_samples, n_features) + ind = kdtree.query(y, k=1, return_distance=False) # (n_samples, 1) return ind -if __name__ == '__main__': +if __name__ == "__main__": states_list = [] steps_list = [] for i in range(1): - states_list.append(np.load('/home/rociomer/data/synth_net/pis_fp/states_' + str(i) + '_valid.npz', allow_pickle=True)) - steps_list.append(np.load('/home/rociomer/data/synth_net/pis_fp/steps_' + str(i) + '_valid.npz', allow_pickle=True)) + states_list.append( + np.load( + "/home/rociomer/data/synth_net/pis_fp/states_" + str(i) + "_valid.npz", + allow_pickle=True, + ) + ) + steps_list.append( + np.load( + "/home/rociomer/data/synth_net/pis_fp/steps_" + str(i) + "_valid.npz", + allow_pickle=True, + ) + ) states = np.concatenate(states_list, axis=0) steps = np.concatenate(steps_list, axis=0) @@ -186,9 +203,9 @@ def nn_search_list(y, out_feat, kdtree): pl.seed_everything(0) mlp = MLP() - tb_logger = pl_loggers.TensorBoardLogger('temp_logs/') + tb_logger = pl_loggers.TensorBoardLogger("temp_logs/") trainer = pl.Trainer(gpus=[0], max_epochs=30, progress_bar_refresh_rate=20, logger=tb_logger) t = time.time() trainer.fit(mlp, train_data_iter, train_data_iter) - print(time.time() - t, 's') + print(time.time() - t, "s") diff --git a/src/syn_net/models/prepare_data.py b/src/syn_net/models/prepare_data.py index cd5a1cf5..d34cdb57 100644 --- a/src/syn_net/models/prepare_data.py +++ b/src/syn_net/models/prepare_data.py @@ -11,30 +11,46 @@ logger = logging.getLogger(__file__) -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument("-e", "--targetembedding", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") - parser.add_argument("-r", "--radius", type=int, default=2, - help="Radius for Morgan Fingerprint") - parser.add_argument("-b", "--nbits", type=int, default=4096, - help="Number of Bits for Morgan Fingerprint") - parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], - help="Choose from ['hb', 'pis']") + parser.add_argument( + "-e", "--targetembedding", type=str, default="fp", help="Choose from ['fp', 'gin']" + ) + parser.add_argument( + "-o", + "--outputembedding", + type=str, + default="fp_256", + help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", + ) + parser.add_argument("-r", "--radius", type=int, default=2, help="Radius for Morgan Fingerprint") + parser.add_argument( + "-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint" + ) + parser.add_argument( + "-rxn", + "--rxn_template", + type=str, + default="hb", + choices=["hb", "pis"], + help="Choose from ['hb', 'pis']", + ) args = parser.parse_args() reaction_template_id = args.rxn_template embedding = args.targetembedding output_emb = args.outputembedding - main_dir = Path(DATA_FEATURIZED_DIR) / f'{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/' # must match with dir in `st2steps.py` - if reaction_template_id == 'hb': + main_dir = ( + Path(DATA_FEATURIZED_DIR) + / f"{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/" + ) # must match with dir in `st2steps.py` + if reaction_template_id == "hb": num_rxn = 91 - elif reaction_template_id == 'pis': + elif reaction_template_id == "pis": num_rxn = 4700 # Get dimension of output embedding diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 1d9a2cf9..e853bf78 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -17,42 +17,45 @@ logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem + def _fetch_molembedder(): knn_embedding_id = validation_option[12:] file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" logger.info(f"Try to load precomputed MolEmbedder from {file}.") - molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) logger.info(f"Loaded MolEmbedder from {file}.") return molembedder -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() validation_option = VALIDATION_OPTS[args.out_dim] # Get ID for the data to know what we're working with and find right files. - id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = ( + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" + ) dataset = "train" train_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") @@ -62,35 +65,41 @@ def _fetch_molembedder(): pl.seed_everything(0) INPUT_DIMS = { "fp": int(3 * args.nbits), - "gin" : int(2 * args.nbits + args.out_dim) - } # somewhat constant... + "gin": int(2 * args.nbits + args.out_dim), + } # somewhat constant... input_dims = INPUT_DIMS[args.featurize] - mlp = MLP(input_dim=input_dims, - output_dim=args.out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss="mse", - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - molembedder=molembedder, - ncpu=ncpu) + mlp = MLP( + input_dim=input_dims, + output_dim=args.out_dim, + hidden_dim=1200, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + val_freq=10, + molembedder=molembedder, + ncpu=ncpu, + ) # Set up Trainer - save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") - save_dir.mkdir(exist_ok=True,parents=True) + save_dir = Path( + "results/logs/" + + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + + f"/{MODEL_ID}" + ) + save_dir.mkdir(exist_ok=True, parents=True) - tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") checkpoint_callback = ModelCheckpoint( monitor="val_loss", - dirpath= tb_logger.log_dir, + dirpath=tb_logger.log_dir, filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) @@ -98,11 +107,13 @@ def _fetch_molembedder(): max_epochs = args.epoch if not args.debug else 2 # Create trainer - trainer = pl.Trainer(gpus=[0], - max_epochs=max_epochs, - progress_bar_refresh_rate = int(len(train_data_iter)*0.05), - callbacks=[checkpoint_callback], - logger=[tb_logger]) + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate=int(len(train_data_iter) * 0.05), + callbacks=[checkpoint_callback], + logger=[tb_logger], + ) logger.info(f"Start training") trainer.fit(mlp, train_data_iter, valid_data_iter) diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index ddcd4804..6662975d 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -20,42 +20,45 @@ logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem + def _fetch_molembedder(): knn_embedding_id = validation_option[12:] file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" logger.info(f"Try to load precomputed MolEmbedder from {file}.") - molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) logger.info(f"Loaded MolEmbedder from {file}.") return molembedder -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() validation_option = VALIDATION_OPTS[args.out_dim] # Get ID for the data to know what we're working with and find right files. - id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = ( + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" + ) dataset = "train" train_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") @@ -68,37 +71,43 @@ def _fetch_molembedder(): "hb": int(4 * args.nbits + 91), "gin": int(4 * args.nbits + 4700), }, - "gin" : { + "gin": { "hb": int(3 * args.nbits + args.out_dim + 91), "gin": int(3 * args.nbits + args.out_dim + 4700), - } - } # somewhat constant... + }, + } # somewhat constant... input_dims = INPUT_DIMS[args.featurize][args.rxn_template] - mlp = MLP(input_dim=input_dims, - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - molembedder=molembedder, - ncpu=ncpu) + mlp = MLP( + input_dim=input_dims, + output_dim=args.out_dim, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss=validation_option, + optimizer="adam", + learning_rate=1e-4, + val_freq=10, + molembedder=molembedder, + ncpu=ncpu, + ) # Set up Trainer - save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") - save_dir.mkdir(exist_ok=True,parents=True) + save_dir = Path( + "results/logs/" + + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + + f"/{MODEL_ID}" + ) + save_dir.mkdir(exist_ok=True, parents=True) - tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") checkpoint_callback = ModelCheckpoint( monitor="val_loss", - dirpath= tb_logger.log_dir, + dirpath=tb_logger.log_dir, filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) @@ -106,12 +115,14 @@ def _fetch_molembedder(): max_epochs = args.epoch if not args.debug else 2 # Create trainer - trainer = pl.Trainer(gpus=[0], - max_epochs=max_epochs, - progress_bar_refresh_rate = int(len(train_data_iter)*0.05), - callbacks=[checkpoint_callback], - logger=[tb_logger], - fast_dev_run=True) + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate=int(len(train_data_iter) * 0.05), + callbacks=[checkpoint_callback], + logger=[tb_logger], + fast_dev_run=True, + ) logger.info(f"Start training") trainer.fit(mlp, train_data_iter, valid_data_iter) diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index cf6f3453..35b2a4a4 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -18,50 +18,55 @@ logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() validation_option = VALIDATION_OPTS[args.out_dim] # Get ID for the data to know what we're working with and find right files. - id = f'{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' + id = ( + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" + ) dataset = "train" train_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file = Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file = Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", + y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - batch_size = args.batch_size, + batch_size=args.batch_size, num_workers=args.ncpu, - shuffle = True if dataset == "train" else False, + shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") pl.seed_everything(0) - param_path = Path(CHECKPOINTS_DIR) / f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/" - path_to_rxn = f'{param_path}rxn.ckpt' + param_path = ( + Path(CHECKPOINTS_DIR) + / f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/" + ) + path_to_rxn = f"{param_path}rxn.ckpt" INPUT_DIMS = { "fp": { "hb": int(4 * args.nbits), "gin": int(4 * args.nbits), }, - "gin" : { + "gin": { "hb": int(3 * args.nbits + args.out_dim), "gin": int(3 * args.nbits + args.out_dim), - } - } # somewhat constant... + }, + } # somewhat constant... input_dim = INPUT_DIMS[args.featurize][args.rxn_template] HIDDEN_DIMS = { @@ -69,62 +74,66 @@ "hb": 3000, "gin": 4500, }, - "gin" : { + "gin": { "hb": 3000, "gin": 3000, - } + }, } hidden_dim = HIDDEN_DIMS[args.featurize][args.rxn_template] OUTPUT_DIMS = { - "hb": 91, - "gin": 4700, + "hb": 91, + "gin": 4700, } output_dim = OUTPUT_DIMS[args.rxn_template] - if not args.restart: - mlp = MLP(input_dim=input_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu, - ) - else: # load from checkpt -> only for fp, not gin + mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + val_freq=10, + ncpu=ncpu, + ) + else: # load from checkpt -> only for fp, not gin mlp = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=input_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu - ) + path_to_rxn, + input_dim=input_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) # Set up Trainer # Set up Trainer - save_dir = Path("results/logs/" + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + f"/{MODEL_ID}") - save_dir.mkdir(exist_ok=True,parents=True) + save_dir = Path( + "results/logs/" + + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" + + f"/{MODEL_ID}" + ) + save_dir.mkdir(exist_ok=True, parents=True) - tb_logger = pl_loggers.TensorBoardLogger(save_dir,name="") + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") checkpoint_callback = ModelCheckpoint( monitor="val_loss", - dirpath= tb_logger.log_dir, + dirpath=tb_logger.log_dir, filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) @@ -132,11 +141,13 @@ max_epochs = args.epoch if not args.debug else 2 # Create trainer - trainer = pl.Trainer(gpus=[0], - max_epochs=max_epochs, - progress_bar_refresh_rate = int(len(train_data_iter)*0.05), - callbacks=[checkpoint_callback,earlystop_callback], - logger=[tb_logger]) + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + progress_bar_refresh_rate=int(len(train_data_iter) * 0.05), + callbacks=[checkpoint_callback, earlystop_callback], + logger=[tb_logger], + ) logger.info(f"Start training") trainer.fit(mlp, train_data_iter, valid_data_iter) From 2c8815a68f8a03e04e6e6e07716c369f079145d3 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:40:36 -0400 Subject: [PATCH 075/302] fix not re-namend vars --- src/syn_net/models/act.py | 2 +- src/syn_net/models/rt1.py | 6 +++--- src/syn_net/models/rt2.py | 7 +++---- src/syn_net/models/rxn.py | 8 ++++---- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 546d771d..23af682d 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -71,7 +71,7 @@ optimizer="adam", learning_rate=1e-4, val_freq=10, - ncpu=ncpu, + ncpu=args.ncpu, ) # Set up Trainer diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index e853bf78..ead051dd 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -84,7 +84,7 @@ def _fetch_molembedder(): learning_rate=1e-4, val_freq=10, molembedder=molembedder, - ncpu=ncpu, + ncpu=args.ncpu, ) # Set up Trainer @@ -110,11 +110,11 @@ def _fetch_molembedder(): trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_data_iter) * 0.05), + progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], logger=[tb_logger], ) logger.info(f"Start training") - trainer.fit(mlp, train_data_iter, valid_data_iter) + trainer.fit(mlp, train_dataloader, valid_dataloader) logger.info(f"Training completed.") diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 6662975d..ca7f90d8 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -92,7 +92,7 @@ def _fetch_molembedder(): learning_rate=1e-4, val_freq=10, molembedder=molembedder, - ncpu=ncpu, + ncpu=args.ncpu, ) # Set up Trainer @@ -118,12 +118,11 @@ def _fetch_molembedder(): trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_data_iter) * 0.05), + progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], logger=[tb_logger], - fast_dev_run=True, ) logger.info(f"Start training") - trainer.fit(mlp, train_data_iter, valid_data_iter) + trainer.fit(mlp, train_dataloader, valid_dataloader) logger.info(f"Training completed.") diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 35b2a4a4..2d54d6e4 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -101,7 +101,7 @@ optimizer="adam", learning_rate=1e-4, val_freq=10, - ncpu=ncpu, + ncpu=args.ncpu, ) else: # load from checkpt -> only for fp, not gin mlp = MLP.load_from_checkpoint( @@ -144,11 +144,11 @@ trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_data_iter) * 0.05), - callbacks=[checkpoint_callback, earlystop_callback], + progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), + callbacks=[checkpoint_callback], logger=[tb_logger], ) logger.info(f"Start training") - trainer.fit(mlp, train_data_iter, valid_data_iter) + trainer.fit(mlp, train_dataloader, valid_dataloader) logger.info(f"Training completed.") From dacee4919c2e9c4524ec7851738f233d12998afe Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 17:41:34 -0400 Subject: [PATCH 076/302] delete unused imports --- src/syn_net/models/act.py | 2 -- src/syn_net/models/rt2.py | 3 --- src/syn_net/models/rxn.py | 2 -- 3 files changed, 7 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 23af682d..8c99d31c 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -5,11 +5,9 @@ from pathlib import Path import pytorch_lightning as pl -import torch from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from scipy import sparse from syn_net.config import DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index ca7f90d8..c7a57e97 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -4,13 +4,10 @@ import logging from pathlib import Path -import numpy as np import pytorch_lightning as pl -import torch from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from scipy import sparse from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 2d54d6e4..b6cae03e 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -5,11 +5,9 @@ from pathlib import Path import pytorch_lightning as pl -import torch from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from scipy import sparse from syn_net.config import CHECKPOINTS_DIR, DATA_FEATURIZED_DIR from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader From 8483494941ca81769129792a7e0b289681719557 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 31 Aug 2022 20:15:36 -0400 Subject: [PATCH 077/302] fix copy-paste errors --- src/syn_net/models/act.py | 4 +++- src/syn_net/models/common.py | 13 ++++++------- src/syn_net/models/mlp.py | 4 ++-- src/syn_net/models/rt1.py | 2 +- src/syn_net/models/rt2.py | 4 ++-- src/syn_net/models/rxn.py | 5 ++++- 6 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 8c99d31c..6236f287 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -32,6 +32,7 @@ X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, + task = "classification", batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -42,6 +43,7 @@ X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, + task = "classification", batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -90,7 +92,7 @@ ) earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) - max_epochs = args.epoch if not args.debug else 2 + max_epochs = args.epoch if not args.debug else 20 # Create trainer trainer = pl.Trainer( gpus=[0], diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 5ffffcc8..66199eb6 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -5,6 +5,7 @@ from typing import Union import torch +import numpy as np from scipy import sparse VALIDATION_OPTS = { @@ -41,7 +42,7 @@ def get_args(): return parser.parse_args() -def xy_to_dataloader(X_file: str = None, y_file: str = None, n: Union[int, float] = 1.0, **kwargs): +def xy_to_dataloader(X_file: str, y_file: str, task: str = "regression", n: Union[int, float] = 1.0, **kwargs): """Loads featurized X,y `*.npz`-data into a `DataLoader`""" X = sparse.load_npz(X_file) y = sparse.load_npz(y_file) @@ -57,13 +58,11 @@ def xy_to_dataloader(X_file: str = None, y_file: str = None, n: Union[int, float y = y[:yn] else: pass # + X = np.atleast_2d(np.asarray(X.todense())) + y = np.atleast_2d(np.asarray(y.todense())) if task == "regression" else np.asarray(y.todense()).squeeze() dataset = torch.utils.data.TensorDataset( - torch.Tensor(X.A), - torch.Tensor( - y.A.reshape( - -1, - ) - ), + torch.Tensor(X), + torch.Tensor(y), ) return torch.utils.data.DataLoader(dataset, **kwargs) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 89f48e31..7aa4e605 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -70,7 +70,7 @@ def training_step(self, batch, batch_idx): x, y = batch y_hat = self.layers(x) if self.loss == "cross_entropy": - loss = F.cross_entropy(y_hat, y) + loss = F.cross_entropy(y_hat, y.long()) elif self.loss == "mse": loss = F.mse_loss(y_hat, y) elif self.loss == "l1": @@ -117,7 +117,7 @@ def validation_step(self, batch, batch_idx): x, y = batch y_hat = self.layers(x) if self.valid_loss == "cross_entropy": - loss = F.cross_entropy(y_hat, y) + loss = F.cross_entropy(y_hat, y.long()) elif self.valid_loss == "accuracy": y_hat = torch.argmax(y_hat, axis=1) accuracy = (y_hat == y).sum() / len(y) diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index ead051dd..f512abd2 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -20,7 +20,7 @@ def _fetch_molembedder(): knn_embedding_id = validation_option[12:] - file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" + file = Path(DATA_EMBEDDINGS_DIR) / f"hb-enamine_us-2021-smiles-{knn_embedding_id}.npy" logger.info(f"Try to load precomputed MolEmbedder from {file}.") molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) logger.info(f"Loaded MolEmbedder from {file}.") diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index c7a57e97..b9df4a7a 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -20,7 +20,7 @@ def _fetch_molembedder(): knn_embedding_id = validation_option[12:] - file = Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{knn_embedding_id}.npy" + file = Path(DATA_EMBEDDINGS_DIR) / f"hb-enamine_us-2021-smiles-{knn_embedding_id}.npy" logger.info(f"Try to load precomputed MolEmbedder from {file}.") molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) logger.info(f"Loaded MolEmbedder from {file}.") @@ -84,7 +84,7 @@ def _fetch_molembedder(): num_dropout_layers=1, task="regression", loss="mse", - valid_loss=validation_option, + valid_loss="mse", optimizer="adam", learning_rate=1e-4, val_freq=10, diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index b6cae03e..c319370b 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -32,6 +32,7 @@ X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, + task="classification", batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -42,6 +43,7 @@ X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, + task="classification", batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -102,6 +104,7 @@ ncpu=args.ncpu, ) else: # load from checkpt -> only for fp, not gin + # TODO: Use `ckpt_path`, c.f. https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html#pytorch_lightning.trainer.trainer.Trainer.fit mlp = MLP.load_from_checkpoint( path_to_rxn, input_dim=input_dim, @@ -115,7 +118,7 @@ valid_loss="accuracy", optimizer="adam", learning_rate=1e-4, - ncpu=ncpu, + ncpu=args.ncpu, ) # Set up Trainer From 126943c770c5a33521df1b5fa4deeb789291f9a5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 6 Sep 2022 12:15:53 -0400 Subject: [PATCH 078/302] complete refactor to inject `MolEmbedder` in `MLP` --- src/syn_net/MolEmbedder.py | 2 +- src/syn_net/models/mlp.py | 78 ++------------------------------------ tests/test_Training.py | 8 ++++ 3 files changed, 13 insertions(+), 75 deletions(-) diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py index 84496a30..bab38c34 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/syn_net/MolEmbedder.py @@ -80,7 +80,7 @@ def init_balltree(self, metric: Union[Callable, str]): if self.embeddings is None: raise ValueError("Neeed emebddings to compute kdtree.") X = self.embeddings - self.kdtree_metric = metric.__name__ + self.kdtree_metric = metric.__name__ if not isinstance(metric,str) else metric self.kdtree = BallTree(X, metric=metric) return self diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 7aa4e605..ddf351b6 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -82,36 +82,6 @@ def training_step(self, batch, batch_idx): self.log(f"train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss - def _load_building_blocks_kdtree(self, out_feat: str) -> np.ndarray: - """Helper function to load the pre-computed building block embeddings - as a BallTree. - - """ - from pathlib import Path - - from syn_net.config import DATA_EMBEDDINGS_DIR - - if out_feat == "gin": - bb_emb_gin = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") - kdtree = BallTree(bb_emb_gin, metric="euclidean") - elif out_feat == "fp_4096": - bb_emb_fp_4096 = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") - kdtree = BallTree(bb_emb_fp_4096, metric="euclidean") - elif out_feat == "fp_256": - bb_emb_fp_256 = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") - kdtree = BallTree(bb_emb_fp_256, metric=cosine_distance) - elif out_feat == "rdkit2d": - bb_emb_rdkit2d = np.load(Path(DATA_EMBEDDINGS_DIR) / f"enamine_us_emb_{out_feat}.npy") - kdtree = BallTree(bb_emb_rdkit2d, metric="euclidean") - elif out_feat == "gin_unittest": - # The embeddings are pre-computed based on the building blocks - # under 'tests/assets/building_blocks_matched.csv.gz'. - emb = np.load("tests/data/building_blocks_emb.npy") - kdtree = BallTree(emb, metric="euclidean") - else: - raise ValueError - return kdtree - def validation_step(self, batch, batch_idx): if self.trainer.current_epoch % self.val_freq == 0: x, y = batch @@ -126,14 +96,10 @@ def validation_step(self, batch, batch_idx): # NOTE: Very slow! # Performing the knn-search can easily take a couple of minutes, # even for small datasets. - out_feat = self.valid_loss[12:] - if self.molembedder is None: # legacy - kdtree = self._load_building_blocks_kdtree(out_feat) - else: - kdtree = self.molembedder.kdtree + kdtree = self.molembedder.kdtree y = nn_search_list(y.detach().cpu().numpy(), None, kdtree) y_hat = nn_search_list(y_hat.detach().cpu().numpy(), None, kdtree) - loss = 1 - (sum(y_hat == y) / len(y)) + accuracy = (y_hat == y).sum() / len(y) loss = 1 - accuracy elif self.valid_loss == "mse": @@ -143,7 +109,7 @@ def validation_step(self, batch, batch_idx): elif self.valid_loss == "huber": loss = F.huber_loss(y_hat, y) else: - raise ValueError("Not specified validation loss function") + raise ValueError("Not specified validation loss function for '%s'" % self.valid_loss) self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) else: pass @@ -172,40 +138,4 @@ def nn_search_list(y, out_feat, kdtree): if __name__ == "__main__": - - states_list = [] - steps_list = [] - for i in range(1): - states_list.append( - np.load( - "/home/rociomer/data/synth_net/pis_fp/states_" + str(i) + "_valid.npz", - allow_pickle=True, - ) - ) - steps_list.append( - np.load( - "/home/rociomer/data/synth_net/pis_fp/steps_" + str(i) + "_valid.npz", - allow_pickle=True, - ) - ) - - states = np.concatenate(states_list, axis=0) - steps = np.concatenate(steps_list, axis=0) - - X = states - y = steps[:, 0] - - X_train = torch.Tensor(X) - y_train = torch.LongTensor(y) - - batch_size = 64 - train_data_iter = load_array((X_train, y_train), batch_size, is_train=True) - - pl.seed_everything(0) - mlp = MLP() - tb_logger = pl_loggers.TensorBoardLogger("temp_logs/") - - trainer = pl.Trainer(gpus=[0], max_epochs=30, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, train_data_iter) - print(time.time() - t, "s") + pass diff --git a/tests/test_Training.py b/tests/test_Training.py index 200b8468..3765f5a6 100644 --- a/tests/test_Training.py +++ b/tests/test_Training.py @@ -11,12 +11,18 @@ import torch from syn_net.models.mlp import MLP, load_array +from syn_net.MolEmbedder import MolEmbedder TEST_DIR = Path(__file__).parent REACTION_TEMPLATES_FILE = f"{TEST_DIR}/assets/rxn_set_hb_test.txt" +def _fetch_molembedder(): + file = "tests/data/building_blocks_emb.npy" + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric="euclidean") + return molembedder + class TestReactionTemplateFile(unittest.TestCase): def test_number_of_reaction_templates(self): @@ -136,6 +142,7 @@ def test_reactant1_network(self): optimizer="adam", learning_rate=1e-4, val_freq=10, + molembedder=_fetch_molembedder(), ncpu=ncpu, ) @@ -253,6 +260,7 @@ def test_reactant2_network(self): optimizer="adam", learning_rate=1e-4, val_freq=10, + molembedder=_fetch_molembedder(), ncpu=ncpu, ) From 715bbafc7a0dc417f236d64a94148d6007c9b75d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 6 Sep 2022 13:44:53 -0400 Subject: [PATCH 079/302] adds cli arg `building-blocks-file` --- src/syn_net/data_generation/process_rxn_mp.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py index 0c190f8e..38368903 100644 --- a/src/syn_net/data_generation/process_rxn_mp.py +++ b/src/syn_net/data_generation/process_rxn_mp.py @@ -3,7 +3,7 @@ reactants from a list of purchasable building blocks. Usage: - python process__rxnmp.py + python process_rxn.py """ import multiprocessing as mp from functools import partial @@ -33,7 +33,15 @@ def _match_building_blocks_to_rxn(building_blocks: list[str], _rxn: Reaction): from syn_net.config import (BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, REACTION_TEMPLATE_DIR) +def get_args(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--building-blocks-file", type=str, help="Input file with SMILES strings (First row `SMILES`, then one per line).") + return parser.parse_args() + if __name__ == "__main__": + + args = get_args() reaction_template_id = "hb" # "pis" or "hb" building_blocks_id = "enamine_us-2021-smiles" From fa189c2c3460d484536bb89169229b236d817f3b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 6 Sep 2022 13:45:24 -0400 Subject: [PATCH 080/302] adds cli arg `fast-dev-run` --- src/syn_net/models/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 66199eb6..fb96164f 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -39,6 +39,7 @@ def get_args(): ) parser.add_argument("-v", "--version", type=int, default=1, help="Version") parser.add_argument("--debug", default=False, action="store_true") + parser.add_argument("--fast-dev-run", default=False, action="store_true") return parser.parse_args() @@ -48,7 +49,7 @@ def xy_to_dataloader(X_file: str, y_file: str, task: str = "regression", n: Unio y = sparse.load_npz(y_file) # Filer? if isinstance(n, int): - n = min(n, min(X.shape[0], y.shape[0])) # ensure n does not exceed size of dataset + n = min(n, X.shape[0]) # ensure n does not exceed size of dataset X = X[:n] y = y[:n] elif isinstance(n, float) and n < 1.0: From 82fe9b61fdd02b6f87502f685762d28a5685c528 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 6 Sep 2022 13:45:43 -0400 Subject: [PATCH 081/302] add comments --- src/syn_net/models/mlp.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index ddf351b6..45e05b8f 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -8,8 +8,6 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F -from pytorch_lightning import loggers as pl_loggers -from sklearn.neighbors import BallTree from torch import nn from syn_net.MolEmbedder import MolEmbedder @@ -64,9 +62,11 @@ def __init__( self.layers = nn.Sequential(*modules) def forward(self, x): + """Forward step for inference only.""" return self.layers(x) def training_step(self, batch, batch_idx): + """The complete training loop.""" x, y = batch y_hat = self.layers(x) if self.loss == "cross_entropy": @@ -83,6 +83,7 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """The complete validation loop.""" if self.trainer.current_epoch % self.val_freq == 0: x, y = batch y_hat = self.layers(x) @@ -115,6 +116,7 @@ def validation_step(self, batch, batch_idx): pass def configure_optimizers(self): + """Define Optimerzers and LR schedulers.""" if self.optimizer == "adam": optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) elif self.optimizer == "sgd": From d3dafc91be48e78227af0f81a1bb95f0c8eb8d8d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 6 Sep 2022 19:00:54 -0400 Subject: [PATCH 082/302] refactor: split up code from `predict_utils` --- src/syn_net/encoders/distances.py | 62 ++ src/syn_net/encoders/fingerprints.py | 59 ++ src/syn_net/encoders/gins.py | 147 +++++ src/syn_net/encoders/utils.py | 17 + src/syn_net/models/chkpt_loader.py | 292 +++++++++ src/syn_net/utils/predict_utils.py | 884 ++++++--------------------- src/syn_net/utils/prep_utils.py | 14 +- tests/test_DataPreparation.py | 12 +- tests/test_Predict.py | 3 +- 9 files changed, 790 insertions(+), 700 deletions(-) create mode 100644 src/syn_net/encoders/distances.py create mode 100644 src/syn_net/encoders/fingerprints.py create mode 100644 src/syn_net/encoders/gins.py create mode 100644 src/syn_net/encoders/utils.py create mode 100644 src/syn_net/models/chkpt_loader.py diff --git a/src/syn_net/encoders/distances.py b/src/syn_net/encoders/distances.py new file mode 100644 index 00000000..7abbeac0 --- /dev/null +++ b/src/syn_net/encoders/distances.py @@ -0,0 +1,62 @@ +import numpy as np +from syn_net.encoders.fingerprints import mol_fp + +def cosine_distance(v1, v2, eps=1e-15): + """Computes the cosine similarity between two vectors. + + Args: + v1 (np.ndarray): First vector. + v2 (np.ndarray): Second vector. + eps (float, optional): Small value, for numerical stability. Defaults + to 1e-15. + + Returns: + float: The cosine similarity. + """ + return (1 - np.dot(v1, v2) + / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps)) + +def ce_distance(y, y_pred, eps=1e-15): + """Computes the cross-entropy between two vectors. + + Args: + y (np.ndarray): First vector. + y_pred (np.ndarray): Second vector. + eps (float, optional): Small value, for numerical stability. Defaults + to 1e-15. + + Returns: + float: The cross-entropy. + """ + y_pred = np.clip(y_pred, eps, 1 - eps) + return - np.sum((y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))) + + +def _tanimoto_similarity(fp1: np.ndarray, fp2: np.ndarray): + """ + Returns the Tanimoto similarity between two molecular fingerprints. + + Args: + fp1 (np.ndarray): Molecular fingerprint 1. + fp2 (np.ndarray): Molecular fingerprint 2. + + Returns: + np.float: Tanimoto similarity. + """ + return np.sum(fp1 * fp2) / (np.sum(fp1) + np.sum(fp2) - np.sum(fp1 * fp2)) + +def tanimoto_similarity(target_fp: np.ndarray, smis: list[str]): + """ + Returns the Tanimoto similarities between a target fingerprint and molecules + in an input list of SMILES. + + Args: + target_fp (np.ndarray): Contains the reference (target) fingerprint. + smis (list of str): Contains SMILES to compute similarity to. + + Returns: + list of np.ndarray: Contains Tanimoto similarities. + """ + fps = [mol_fp(smi, 2, 4096) for smi in smis] + return [_tanimoto_similarity(target_fp, fp) for fp in fps] + diff --git a/src/syn_net/encoders/fingerprints.py b/src/syn_net/encoders/fingerprints.py new file mode 100644 index 00000000..15872fd3 --- /dev/null +++ b/src/syn_net/encoders/fingerprints.py @@ -0,0 +1,59 @@ +import numpy as np +from rdkit import Chem, DataStructs + +## Morgan fingerprints +def mol_fp(smi, _radius=2, _nBits=4096): + """ + Computes the Morgan fingerprint for the input SMILES. + + Args: + smi (str): SMILES for molecule to compute fingerprint for. + _radius (int, optional): Fingerprint radius to use. Defaults to 2. + _nBits (int, optional): Length of fingerprint. Defaults to 1024. + + Returns: + features (np.ndarray): For valid SMILES, this is the fingerprint. + Otherwise, if the input SMILES is bad, this will be a zero vector. + """ + if smi is None: + return np.zeros(_nBits) + else: + mol = Chem.MolFromSmiles(smi) + features_vec = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) + return np.array(features_vec) + +def fp_embedding(smi, _radius=2, _nBits=4096): + """ + General function for building variable-size & -radius Morgan fingerprints. + + Args: + smi (str): The SMILES to encode. + _radius (int, optional): Morgan fingerprint radius. Defaults to 2. + _nBits (int, optional): Morgan fingerprint length. Defaults to 4096. + + Returns: + np.ndarray: A Morgan fingerprint generated using the specified parameters. + """ + if smi is None: + return np.zeros(_nBits).reshape((-1, )).tolist() + else: + mol = Chem.MolFromSmiles(smi) + features_vec = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) + features = np.zeros((1,)) + DataStructs.ConvertToNumpyArray(features_vec, features) + return features.reshape((-1, )).tolist() + +def fp_4096(smi): + return fp_embedding(smi, _radius=2, _nBits=4096) + +def fp_2048(smi): + return fp_embedding(smi, _radius=2, _nBits=2048) + +def fp_1024(smi): + return fp_embedding(smi, _radius=2, _nBits=1024) + +def fp_512(smi): + return fp_embedding(smi, _radius=2, _nBits=512) + +def fp_256(smi): + return fp_embedding(smi, _radius=2, _nBits=256) \ No newline at end of file diff --git a/src/syn_net/encoders/gins.py b/src/syn_net/encoders/gins.py new file mode 100644 index 00000000..82f19845 --- /dev/null +++ b/src/syn_net/encoders/gins.py @@ -0,0 +1,147 @@ +import functools + +import numpy as np +import torch +import tqdm +from dgl.nn.pytorch.glob import AvgPooling +from dgllife.model import load_pretrained +from dgllife.utils import (PretrainAtomFeaturizer, PretrainBondFeaturizer, + mol_to_bigraph) +from rdkit import Chem + + +@functools.lru_cache(1) +def _fetch_gin_pretrained_model(model_name: str): + """Get a GIN pretrained model to use for creating molecular embeddings""" + device = 'cpu' + model = load_pretrained(model_name).to(device) # used to learn embedding + model.eval() + return model + + +def graph_construction_and_featurization(smiles): + """ + Constructs graphs from SMILES and featurizes them. + + Args: + smiles (list of str): Contains SMILES of molecules to embed. + + Returns: + graphs (list of DGLGraph): List of graphs constructed and featurized. + success (list of bool): Indicators for whether the SMILES string can be + parsed by RDKit. + """ + graphs = [] + success = [] + for smi in tqdm(smiles): + try: + mol = Chem.MolFromSmiles(smi) + if mol is None: + success.append(False) + continue + g = mol_to_bigraph(mol, add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False) + graphs.append(g) + success.append(True) + except: + success.append(False) + + return graphs, success + +def mol_embedding(smi, device='cpu', readout=AvgPooling()): + """ + Constructs a graph embedding using the GIN network for an input SMILES. + + Args: + smi (str): A SMILES string. + device (str): Indicates the device to run on ('cpu' or 'cuda:0'). Default 'cpu'. + + Returns: + np.ndarray: Either a zeros array or the graph embedding. + """ + name = 'gin_supervised_contextpred' + gin_pretrained_model = _fetch_gin_pretrained_model(name) + + # get the embedding + if smi is None: + return np.zeros(300) + else: + mol = Chem.MolFromSmiles(smi) + # convert RDKit.Mol into featurized bi-directed DGLGraph + g = mol_to_bigraph(mol, add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False) + bg = g.to(device) + nfeats = [bg.ndata.pop('atomic_number').to(device), + bg.ndata.pop('chirality_type').to(device)] + efeats = [bg.edata.pop('bond_type').to(device), + bg.edata.pop('bond_direction_type').to(device)] + with torch.no_grad(): + node_repr = gin_pretrained_model(bg, nfeats, efeats) + return readout(bg, node_repr).detach().cpu().numpy().reshape(-1, ).tolist() + + +def get_mol_embedding(smi, model, device='cpu', readout=AvgPooling()): + """ + Computes the molecular graph embedding for the input SMILES. + + Args: + smi (str): SMILES of molecule to embed. + model (dgllife.model, optional): Pre-trained NN model to use for + computing the embedding. + device (str, optional): Indicates the device to run on. Defaults to 'cpu'. + readout (dgl.nn.pytorch.glob, optional): Readout function to use for + computing the graph embedding. Defaults to readout. + + Returns: + torch.Tensor: Learned embedding for the input molecule. + """ + mol = Chem.MolFromSmiles(smi) + g = mol_to_bigraph(mol, add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False) + bg = g.to(device) + nfeats = [bg.ndata.pop('atomic_number').to(device), + bg.ndata.pop('chirality_type').to(device)] + efeats = [bg.edata.pop('bond_type').to(device), + bg.edata.pop('bond_direction_type').to(device)] + with torch.no_grad(): + node_repr = model(bg, nfeats, efeats) + return readout(bg, node_repr).detach().cpu().numpy()[0] + + + +def graph_construction_and_featurization(smiles): + """ + Constructs graphs from SMILES and featurizes them. + + Args: + smiles (list of str): SMILES of molecules, for embedding computation. + + Returns: + graphs (list of DGLGraph): List of graphs constructed and featurized. + success (list of bool): Indicators for whether the SMILES string can be + parsed by RDKit. + """ + graphs = [] + success = [] + for smi in tqdm(smiles): + try: + mol = Chem.MolFromSmiles(smi) + if mol is None: + success.append(False) + continue + g = mol_to_bigraph(mol, add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False) + graphs.append(g) + success.append(True) + except: + success.append(False) + + return graphs, success diff --git a/src/syn_net/encoders/utils.py b/src/syn_net/encoders/utils.py new file mode 100644 index 00000000..1b06828d --- /dev/null +++ b/src/syn_net/encoders/utils.py @@ -0,0 +1,17 @@ +import numpy as np + +def one_hot_encoder(dim, space): + """ + Create a one-hot encoded vector of length=`space`, with a non-zero element + at the index given by `dim`. + + Args: + dim (int): Non-zero bit in one-hot vector. + space (int): Length of one-hot encoded vector. + + Returns: + vec (np.ndarray): One-hot encoded vector. + """ + vec = np.zeros((1, space)) + vec[0, dim] = 1 + return vec \ No newline at end of file diff --git a/src/syn_net/models/chkpt_loader.py b/src/syn_net/models/chkpt_loader.py new file mode 100644 index 00000000..c3f3231b --- /dev/null +++ b/src/syn_net/models/chkpt_loader.py @@ -0,0 +1,292 @@ +from typing import Tuple +from syn_net.models.mlp import MLP +import pytorch_lightning as pl +from typing import List + +def load_modules_from_checkpoint( + path_to_act: str, + path_to_rt1: str, + path_to_rxn: str, + path_to_rt2: str, + featurize: str, + rxn_template: str, + out_dim: int, + nbits: int, + ncpu: int, +) -> List[pl.LightningModule]: + + if rxn_template == "unittest": + + act_net = MLP.load_from_checkpoint( + path_to_act, + input_dim=int(3 * nbits), + output_dim=4, + hidden_dim=100, + num_layers=3, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt1_net = MLP.load_from_checkpoint( + path_to_rt1, + input_dim=int(3 * nbits), + output_dim=out_dim, + hidden_dim=100, + num_layers=3, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rxn_net = MLP.load_from_checkpoint( + path_to_rxn, + input_dim=int(4 * nbits), + output_dim=3, + hidden_dim=100, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt2_net = MLP.load_from_checkpoint( + path_to_rt2, + input_dim=int(4 * nbits + 3), + output_dim=out_dim, + hidden_dim=100, + num_layers=3, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + elif featurize == "fp": + + act_net = MLP.load_from_checkpoint( + path_to_act, + input_dim=int(3 * nbits), + output_dim=4, + hidden_dim=1000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt1_net = MLP.load_from_checkpoint( + path_to_rt1, + input_dim=int(3 * nbits), + output_dim=int(out_dim), + hidden_dim=1200, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + if rxn_template == "hb": + + rxn_net = MLP.load_from_checkpoint( + path_to_rxn, + input_dim=int(4 * nbits), + output_dim=91, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt2_net = MLP.load_from_checkpoint( + path_to_rt2, + input_dim=int(4 * nbits + 91), + output_dim=int(out_dim), + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + elif rxn_template == "pis": + + rxn_net = MLP.load_from_checkpoint( + path_to_rxn, + input_dim=int(4 * nbits), + output_dim=4700, + hidden_dim=4500, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt2_net = MLP.load_from_checkpoint( + path_to_rt2, + input_dim=int(4 * nbits + 4700), + output_dim=out_dim, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + elif featurize == "gin": + + act_net = MLP.load_from_checkpoint( + path_to_act, + input_dim=int(2 * nbits + out_dim), + output_dim=4, + hidden_dim=1000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt1_net = MLP.load_from_checkpoint( + path_to_rt1, + input_dim=int(2 * nbits + out_dim), + output_dim=out_dim, + hidden_dim=1200, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + if rxn_template == "hb": + + rxn_net = MLP.load_from_checkpoint( + path_to_rxn, + input_dim=int(3 * nbits + out_dim), + output_dim=91, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt2_net = MLP.load_from_checkpoint( + path_to_rt2, + input_dim=int(3 * nbits + out_dim + 91), + output_dim=out_dim, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + elif rxn_template == "pis": + + rxn_net = MLP.load_from_checkpoint( + path_to_rxn, + input_dim=int(3 * nbits + out_dim), + output_dim=4700, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + rt2_net = MLP.load_from_checkpoint( + path_to_rt2, + input_dim=int(3 * nbits + out_dim + 4700), + output_dim=out_dim, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) + + act_net.eval() + rt1_net.eval() + rxn_net.eval() + rt2_net.eval() + + return act_net, rt1_net, rxn_net, rt2_net diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index a6f27bd2..9ef65c1a 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -2,38 +2,25 @@ This file contains various utils for creating molecular embeddings and for decoding synthetic trees. """ +from typing import Callable, Tuple + import numpy as np +import pytorch_lightning as pl import rdkit -from tqdm import tqdm import torch from rdkit import Chem -from rdkit import DataStructs -from rdkit.Chem import AllChem from sklearn.neighbors import BallTree -from dgl.nn.pytorch.glob import AvgPooling -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from dgllife.utils import mol_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer +from syn_net.encoders.distances import cosine_distance, tanimoto_similarity +from syn_net.encoders.fingerprints import mol_fp +from syn_net.encoders.utils import one_hot_encoder +from syn_net.utils.data_utils import Reaction, SyntheticTree from tdc.chem_utils import MolConvert -from syn_net.models.mlp import MLP -from syn_net.utils.data_utils import SyntheticTree -import functools # create a random seed for NumPy np.random.seed(6) - -@functools.lru_cache(1) -def _fetch_gin_pretrained_model(model_name: str): - """Get a GIN pretrained model to use for creating molecular embeddings""" - device = 'cpu' - model = load_pretrained(model_name).to(device) # used to learn embedding - model.eval() - return model - - # general functions -def can_react(state, rxns): +def can_react(state, rxns: list[Reaction]) -> Tuple[int, list[bool]]: """ Determines if two molecules can react using any of the input reactions. @@ -52,7 +39,8 @@ def can_react(state, rxns): reaction_mask = [int(rxn.run_reaction([mol1, mol2]) is not None) for rxn in rxns] return sum(reaction_mask), reaction_mask -def get_action_mask(state, rxns): + +def get_action_mask(state: list, rxns: list[Reaction]) -> np.ndarray: """ Determines which actions can apply to a given state in the synthetic tree and returns a mask for which actions can apply. @@ -71,19 +59,21 @@ def get_action_mask(state, rxns): """ # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) if len(state) == 0: - return np.array([1, 0, 0, 0]) + mask = [1, 0, 0, 0] elif len(state) == 1: - return np.array([1, 1, 0, 1]) + mask = [1, 1, 0, 1] elif len(state) == 2: can_react_, _ = can_react(state, rxns) if can_react_: - return np.array([0, 1, 1, 0]) + mask = [0, 1, 1, 0] else: - return np.array([0, 1, 0, 0]) + mask = [0, 1, 0, 0] else: - raise ValueError('Problem with state.') + raise ValueError("Problem with state.") + return np.asarray(mask, dtype=bool) + -def get_reaction_mask(smi, rxns): +def get_reaction_mask(smi: str, rxns: list[Reaction]): """ Determines which reaction templates can apply to the input molecule. @@ -109,6 +99,7 @@ def get_reaction_mask(smi, rxns): if sum(reaction_mask) == 0: return None, None + available_list = [] mol = rdkit.Chem.MolFromSmiles(smi) for i, rxn in enumerate(rxns): @@ -119,7 +110,7 @@ def get_reaction_mask(smi, rxns): elif rxn.is_reactant_second(mol): available_list.append(rxn.available_reactants[0]) else: - raise ValueError('Check the reactants') + raise ValueError("Check the reactants") if len(available_list[-1]) == 0: reaction_mask[i] = 0 @@ -129,170 +120,8 @@ def get_reaction_mask(smi, rxns): return reaction_mask, available_list -def graph_construction_and_featurization(smiles): - """ - Constructs graphs from SMILES and featurizes them. - - Args: - smiles (list of str): Contains SMILES of molecules to embed. - - Returns: - graphs (list of DGLGraph): List of graphs constructed and featurized. - success (list of bool): Indicators for whether the SMILES string can be - parsed by RDKit. - """ - graphs = [] - success = [] - for smi in tqdm(smiles): - try: - mol = Chem.MolFromSmiles(smi) - if mol is None: - success.append(False) - continue - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - graphs.append(g) - success.append(True) - except: - success.append(False) - - return graphs, success - -def one_hot_encoder(dim, space): - """ - Create a one-hot encoded vector of length=`space`, with a non-zero element - at the index given by `dim`. - - Args: - dim (int): Non-zero bit in one-hot vector. - space (int): Length of one-hot encoded vector. - - Returns: - vec (np.ndarray): One-hot encoded vector. - """ - vec = np.zeros((1, space)) - vec[0, dim] = 1 - return vec - -def mol_embedding(smi, device='cpu', readout=AvgPooling()): - """ - Constructs a graph embedding using the GIN network for an input SMILES. - - Args: - smi (str): A SMILES string. - device (str): Indicates the device to run on ('cpu' or 'cuda:0'). Default 'cpu'. - - Returns: - np.ndarray: Either a zeros array or the graph embedding. - """ - name = 'gin_supervised_contextpred' - gin_pretrained_model = _fetch_gin_pretrained_model(name) - - # get the embedding - if smi is None: - return np.zeros(300) - else: - mol = Chem.MolFromSmiles(smi) - # convert RDKit.Mol into featurized bi-directed DGLGraph - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - bg = g.to(device) - nfeats = [bg.ndata.pop('atomic_number').to(device), - bg.ndata.pop('chirality_type').to(device)] - efeats = [bg.edata.pop('bond_type').to(device), - bg.edata.pop('bond_direction_type').to(device)] - with torch.no_grad(): - node_repr = gin_pretrained_model(bg, nfeats, efeats) - return readout(bg, node_repr).detach().cpu().numpy().reshape(-1, ).tolist() - - -def get_mol_embedding(smi, model, device='cpu', readout=AvgPooling()): - """ - Computes the molecular graph embedding for the input SMILES. - - Args: - smi (str): SMILES of molecule to embed. - model (dgllife.model, optional): Pre-trained NN model to use for - computing the embedding. - device (str, optional): Indicates the device to run on. Defaults to 'cpu'. - readout (dgl.nn.pytorch.glob, optional): Readout function to use for - computing the graph embedding. Defaults to readout. - - Returns: - torch.Tensor: Learned embedding for the input molecule. - """ - mol = Chem.MolFromSmiles(smi) - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - bg = g.to(device) - nfeats = [bg.ndata.pop('atomic_number').to(device), - bg.ndata.pop('chirality_type').to(device)] - efeats = [bg.edata.pop('bond_type').to(device), - bg.edata.pop('bond_direction_type').to(device)] - with torch.no_grad(): - node_repr = model(bg, nfeats, efeats) - return readout(bg, node_repr).detach().cpu().numpy()[0] - -def mol_fp(smi, _radius=2, _nBits=4096): - """ - Computes the Morgan fingerprint for the input SMILES. - - Args: - smi (str): SMILES for molecule to compute fingerprint for. - _radius (int, optional): Fingerprint radius to use. Defaults to 2. - _nBits (int, optional): Length of fingerprint. Defaults to 1024. - - Returns: - features (np.ndarray): For valid SMILES, this is the fingerprint. - Otherwise, if the input SMILES is bad, this will be a zero vector. - """ - if smi is None: - return np.zeros(_nBits) - else: - mol = Chem.MolFromSmiles(smi) - features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) - return np.array(features_vec) - -def cosine_distance(v1, v2, eps=1e-15): - """ - Computes the cosine similarity between two vectors. - - Args: - v1 (np.ndarray): First vector. - v2 (np.ndarray): Second vector. - eps (float, optional): Small value, for numerical stability. Defaults - to 1e-15. - - Returns: - float: The cosine similarity. - """ - return (1 - np.dot(v1, v2) - / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps)) - -def ce_distance(y, y_pred, eps=1e-15): - """ - Computes the cross-entropy between two vectors. - - Args: - y (np.ndarray): First vector. - y_pred (np.ndarray): Second vector. - eps (float, optional): Small value, for numerical stability. Defaults - to 1e-15. - - Returns: - float: The cross-entropy. - """ - y_pred = np.clip(y_pred, eps, 1 - eps) - return - np.sum((y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))) - -def nn_search(_e, _tree, _k=1): +def nn_search(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[float, float]: """ Conducts a nearest neighbor search to find the molecule from the tree most simimilar to the input embedding. @@ -310,80 +139,57 @@ def nn_search(_e, _tree, _k=1): dist, ind = _tree.query(_e, k=_k) return dist[0][0], ind[0][0] -def graph_construction_and_featurization(smiles): - """ - Constructs graphs from SMILES and featurizes them. - Args: - smiles (list of str): SMILES of molecules, for embedding computation. +def nn_search_rt1(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[np.ndarray, np.ndarray]: + dist, ind = _tree.query(_e, k=_k) + return dist[0], ind[0] - Returns: - graphs (list of DGLGraph): List of graphs constructed and featurized. - success (list of bool): Indicators for whether the SMILES string can be - parsed by RDKit. - """ - graphs = [] - success = [] - for smi in tqdm(smiles): - try: - mol = Chem.MolFromSmiles(smi) - if mol is None: - success.append(False) - continue - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - graphs.append(g) - success.append(True) - except: - success.append(False) - - return graphs, success - -def set_embedding(z_target, state, nbits, _mol_embedding=get_mol_embedding): + +def set_embedding(z_target: np.ndarray, state: list[str], nbits: int, _mol_embedding: Callable): """ Computes embeddings for all molecules in the input space. + Embedding = [z_mol1, z_mol2, z_target] Args: z_target (np.ndarray): Embedding for the target molecule. - state (list): Contains molecules in the current state, if not the - initial state. + state (list): Contains molecules in the current state, if not the initial state. nbits (int): Length of fingerprint. - _mol_embedding (Callable, optional): Function to use for computing the - embeddings of the first and second molecules in the state. Defaults - to `get_mol_embedding`. + _mol_embedding (Callable): Function to use for computing the + embeddings of the first and second molecules in the state. Returns: np.ndarray: Embedding consisting of the concatenation of the target molecule with the current molecules (if available) in the input state. """ if len(state) == 0: - return np.concatenate([np.zeros((1, 2 * nbits)), z_target], axis=1) + embedding = np.concatenate([np.zeros((1, 2 * nbits)), z_target], axis=1) else: e1 = np.expand_dims(_mol_embedding(state[0]), axis=0) if len(state) == 1: e2 = np.zeros((1, nbits)) else: e2 = _mol_embedding(state[1]) - return np.concatenate([e1, e2, z_target], axis=1) - -def synthetic_tree_decoder(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, - max_step=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search + embedding = np.concatenate([e1, e2, z_target], axis=1) + return embedding + +def synthetic_tree_decoder( + z_target: np.ndarray, + building_blocks: list[str], + bb_dict: dict[str, int], + reaction_templates: list[Reaction], + mol_embedder, + action_net: pl.LightningModule, + reactant1_net: pl.LightningModule, + rxn_net: pl.LightningModule, + reactant2_net: pl.LightningModule, + bb_emb: np.ndarray, + rxn_template: str, + n_bits: int, + max_step: int = 15, +) -> Tuple[SyntheticTree, int]: + """ + Computes a synthetic tree given an input molecule embedding. + Uses the Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. Args: z_target (np.ndarray): Embedding for the target molecule @@ -408,15 +214,15 @@ def synthetic_tree_decoder(z_target, terminated). """ # Initialization - tree = SyntheticTree() - kdtree = BallTree(bb_emb, metric=cosine_distance) + tree = SyntheticTree() mol_recent = None + kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class + z_target = np.atleast_2d(z_target) # Start iteration - # try: for i in range(max_step): # Encode current state - state = tree.get_state() # a set + state = tree.get_state() # a list z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) # Predict action type, masked selection @@ -426,38 +232,39 @@ def synthetic_tree_decoder(z_target, action_mask = get_action_mask(tree.get_state(), reaction_templates) act = np.argmax(action_proba * action_mask) + # Continue growing tree? + if act == 3: # End + break + z_mol1 = reactant1_net(torch.Tensor(z_state)) z_mol1 = z_mol1.detach().numpy() # Select first molecule - if act == 3: - # End - break - elif act == 0: + if act == 0: # Add dist, ind = nn_search(z_mol1, _tree=kdtree) mol1 = building_blocks[ind] - else: + elif act == 1 or act == 2: # Expand or Merge mol1 = mol_recent + else: + raise ValueError(f"Unexpected action {act}.") z_mol1 = mol_fp(mol1) + z_mol1 = np.atleast_2d(z_mol1) # (1,4096) # Select reaction - try: - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - except: - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 - - if act != 2: - reaction_mask, available_list = get_reaction_mask(smi=mol1, - rxns=reaction_templates) - else: + z = np.concatenate([z_state, z_mol1], axis=1) + reaction_proba = rxn_net(torch.Tensor(z)) + reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate) + + if act != 2: # add or expand + reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) + else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) available_list = [[] for rxn in reaction_templates] + # If we ended up in a state where no reaction is possible, end this iteration. if reaction_mask is None: if len(state) == 1: act = 3 @@ -465,33 +272,32 @@ def synthetic_tree_decoder(z_target, else: break + # Select reaction template rxn_id = np.argmax(reaction_proba * reaction_mask) rxn = reaction_templates[rxn_id] + NUMBER_OF_REACTION_TEMPLATES = { + "hb": 91, + "pis": 4700, + "unittest": 3, + } # TODO: Refactor / use class + + # Select 2nd reactant if rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge + if act == 2: # Merge temp = set(state) - set([mol1]) mol2 = temp.pop() - else: - # Add or Expand - if rxn_template == 'hb': - num_rxns = 91 - elif rxn_template == 'pis': - num_rxns = 4700 - else: - num_rxns = 3 # unit testing uses only 3 reaction templates - reactant2_net_input = torch.Tensor( - np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, num_rxns)], - axis=1) - ) - z_mol2 = reactant2_net(reactant2_net_input) + else: # Add or Expand + x_rxn = one_hot_encoder(rxn_id, NUMBER_OF_REACTION_TEMPLATES[rxn_template]) + x_rct2 = np.concatenate([z_state, z_mol1, x_rxn], axis=1) + z_mol2 = reactant2_net(torch.Tensor(x_rct2)) z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] + available = available_list[rxn_id] # list[str], list of reactants for this rxn + available = [bb_dict[available[i]] for i in range(len(available))] # list[int] temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) + available_tree = BallTree( + temp_emb, metric=cosine_distance + ) # TODO: evaluate if distance matrix is faster/feasible as this BallTree is discarded immediately. dist, ind = nn_search(z_mol2, _tree=available_tree) mol2 = building_blocks[available[ind]] else: @@ -517,300 +323,23 @@ def synthetic_tree_decoder(z_target, return tree, act -def load_modules_from_checkpoint(path_to_act, path_to_rt1, path_to_rxn, path_to_rt2, featurize, rxn_template, out_dim, nbits, ncpu): - - if rxn_template == 'unittest': - - act_net = MLP.load_from_checkpoint(path_to_act, - input_dim=int(3 * nbits), - output_dim=4, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(3 * nbits), - output_dim=out_dim, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(4 * nbits), - output_dim=3, - hidden_dim=100, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(4 * nbits + 3), - output_dim=out_dim, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - elif featurize == 'fp': - - act_net = MLP.load_from_checkpoint(path_to_act, - input_dim=int(3 * nbits), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(3 * nbits), - output_dim=int(out_dim), - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - if rxn_template == 'hb': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(4 * nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(4 * nbits + 91), - output_dim=int(out_dim), - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - elif rxn_template == 'pis': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(4 * nbits), - output_dim=4700, - hidden_dim=4500, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(4 * nbits + 4700), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - elif featurize == 'gin': - - act_net = MLP.load_from_checkpoint(path_to_act, - input_dim=int(2 * nbits + out_dim), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(2 * nbits + out_dim), - output_dim=out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - if rxn_template == 'hb': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(3 * nbits + out_dim), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(3 * nbits + out_dim + 91), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - elif rxn_template == 'pis': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(3 * nbits + out_dim), - output_dim=4700, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(3 * nbits + out_dim + 4700), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - act_net.eval() - rt1_net.eval() - rxn_net.eval() - rt2_net.eval() - - return act_net, rt1_net, rxn_net, rt2_net - -def _tanimoto_similarity(fp1, fp2): - """ - Returns the Tanimoto similarity between two molecular fingerprints. - - Args: - fp1 (np.ndarray): Molecular fingerprint 1. - fp2 (np.ndarray): Molecular fingerprint 2. - - Returns: - np.float: Tanimoto similarity. - """ - return np.sum(fp1 * fp2) / (np.sum(fp1) + np.sum(fp2) - np.sum(fp1 * fp2)) - -def tanimoto_similarity(target_fp, smis): - """ - Returns the Tanimoto similarities between a target fingerprint and molecules - in an input list of SMILES. - - Args: - target_fp (np.ndarray): Contains the reference (target) fingerprint. - smis (list of str): Contains SMILES to compute similarity to. - - Returns: - list of np.ndarray: Contains Tanimoto similarities. - """ - fps = [mol_fp(smi, 2, 4096) for smi in smis] - return [_tanimoto_similarity(target_fp, fp) for fp in fps] - - -# functions used in the *_multireactant.py -def nn_search_rt1(_e, _tree, _k=1): - dist, ind = _tree.query(_e, k=_k) - return dist[0], ind[0] -def synthetic_tree_decoder_rt1(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, - max_step=15, - rt1_index=0): +def synthetic_tree_decoder_rt1( + z_target: np.ndarray, + building_blocks: list[str], + bb_dict: dict[str, int], + reaction_templates: list[Reaction], + mol_embedder, + action_net: pl.LightningModule, + reactant1_net: pl.LightningModule, + rxn_net: pl.LightningModule, + reactant2_net: pl.LightningModule, + bb_emb: np.ndarray, + rxn_template: str, + n_bits: int, + max_step: int = 15, + rt1_index=0, +) -> Tuple[SyntheticTree, int]: """ Computes the synthetic tree given an input molecule embedding, using the Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. @@ -835,7 +364,7 @@ def synthetic_tree_decoder_rt1(z_target, synthetic tree rt1_index (int, optional): Index for molecule in the building blocks corresponding to reactant 1. - + Returns: tree (SyntheticTree): The final synthetic tree act (int): The final action (to know if the tree was "properly" @@ -844,59 +373,61 @@ def synthetic_tree_decoder_rt1(z_target, # Initialization tree = SyntheticTree() mol_recent = None - kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class - z_target = np.atleast_2d(z_target) + kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class + z_target = np.atleast_2d(z_target) # Start iteration for i in range(max_step): # Encode current state - state = tree.get_state() # a list + state = tree.get_state() # a list z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) # Predict action type, masked selection # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) # (1,4) + action_proba = action_net(torch.Tensor(z_state)) # (1,4) action_proba = action_proba.squeeze().detach().numpy() + 1e-10 - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) + action_mask = get_action_mask(tree.get_state(), reaction_templates) + act = np.argmax(action_proba * action_mask) - # Continue growing tree? - if act == 3: # End + # Continue growing tree? + if act == 3: # End break z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) - + z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) # Select first molecule - if act == 0: # Add + if act == 0: # Add if mol_recent is not None: dist, ind = nn_search(z_mol1) mol1 = building_blocks[ind] - else: # no recent mol - dist, ind = nn_search_rt1(z_mol1, _tree=kdtree, _k=rt1_index+1) # TODO: why is there an option to select the k-th? rt1_index (???) + else: # no recent mol + dist, ind = nn_search_rt1( + z_mol1, _tree=kdtree, _k=rt1_index + 1 + ) # TODO: why is there an option to select the k-th? rt1_index (???) mol1 = building_blocks[ind[rt1_index]] - elif act==1 or act==2: + elif act == 1 or act == 2: # Expand or Merge mol1 = mol_recent - else: + else: raise ValueError(f"Unexpected action {act}.") - z_mol1 = mol_fp(mol1) # (dimension_input_embedding=d), default (4096,) - z_mol1 = np.atleast_2d(z_mol1) # (1,4096) + z_mol1 = mol_fp(mol1) # (dimension_input_embedding=d), default (4096,) + z_mol1 = np.atleast_2d(z_mol1) # (1,4096) # Select reaction - z = np.concatenate([z_state, z_mol1], axis=1) # (1,4d) + z = np.concatenate([z_state, z_mol1], axis=1) # (1,4d) reaction_proba = rxn_net(torch.Tensor(z)) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate) + reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate) - if act != 2: # add or expand + if act != 2: # add or expand reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: # merge + else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] # TODO: if act=merge, this is not used at all + available_list = [ + [] for rxn in reaction_templates + ] # TODO: if act=merge, this is not used at all - # If we ended up in a state where no reaction is possible, - # end this iteration. + # If we ended up in a state where no reaction is possible, end this iteration. if reaction_mask is None: if len(state) == 1: act = 3 @@ -906,30 +437,32 @@ def synthetic_tree_decoder_rt1(z_target, # Select reaction template rxn_id = np.argmax(reaction_proba * reaction_mask) - rxn = reaction_templates[rxn_id] + rxn = reaction_templates[rxn_id] NUMBER_OF_REACTION_TEMPLATES = { "hb": 91, "pis": 4700, "unittest": 3, - } # TODO: Refactor / use class + } # TODO: Refactor / use class # Select 2nd reactant if rxn.num_reactant == 2: - if act == 2: # Merge + if act == 2: # Merge temp = set(state) - set([mol1]) mol2 = temp.pop() - else: # Add or Expand - x_rxn = one_hot_encoder(rxn_id,NUMBER_OF_REACTION_TEMPLATES[rxn_template]) - x_rct2 = np.concatenate([z_state,z_mol1, x_rxn],axis=1) + else: # Add or Expand + x_rxn = one_hot_encoder(rxn_id, NUMBER_OF_REACTION_TEMPLATES[rxn_template]) + x_rct2 = np.concatenate([z_state, z_mol1, x_rxn], axis=1) z_mol2 = reactant2_net(torch.Tensor(x_rct2)) - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] # list[str], list of reactants for this rxn - available = [bb_dict[smiles] for smiles in available] # list[int] - temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) # TODO: evaluate if distance matrix is faster/feasible as this BallTree is discarded immediately. - dist, ind = nn_search(z_mol2, _tree=available_tree) - mol2 = building_blocks[available[ind]] + z_mol2 = z_mol2.detach().numpy() + available = available_list[rxn_id] # list[str], list of reactants for this rxn + available = [bb_dict[smiles] for smiles in available] # list[int] + temp_emb = bb_emb[available] + available_tree = BallTree( + temp_emb, metric=cosine_distance + ) # TODO: evaluate if distance matrix is faster/feasible as this BallTree is discarded immediately. + dist, ind = nn_search(z_mol2, _tree=available_tree) + mol2 = building_blocks[available[ind]] else: mol2 = None @@ -950,20 +483,23 @@ def synthetic_tree_decoder_rt1(z_target, return tree, act -def synthetic_tree_decoder_multireactant(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, - beam_width : int=3, - max_step : int=15): + +def synthetic_tree_decoder_multireactant( + z_target, + building_blocks, + bb_dict, + reaction_templates, + mol_embedder, + action_net, + reactant1_net, + rxn_net, + reactant2_net, + bb_emb, + rxn_template, + n_bits, + beam_width: int = 3, + max_step: int = 15, +): """ Computes the synthetic tree given an input molecule embedding, using the Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. @@ -994,24 +530,27 @@ def synthetic_tree_decoder_multireactant(z_target, acts = [] for i in range(beam_width): - tree, act = synthetic_tree_decoder_rt1(z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=reaction_templates, - mol_embedder=mol_embedder, - action_net=action_net, - reactant1_net=reactant1_net, - rxn_net=rxn_net, - reactant2_net=reactant2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=n_bits, - max_step=max_step, - rt1_index=i) - - - similarities_ = np.array(tanimoto_similarity(z_target, [node.smiles for node in tree.chemicals])) - max_simi_idx = np.where(similarities_ == np.max(similarities_))[0][0] + tree, act = synthetic_tree_decoder_rt1( + z_target=z_target, + building_blocks=building_blocks, + bb_dict=bb_dict, + reaction_templates=reaction_templates, + mol_embedder=mol_embedder, + action_net=action_net, + reactant1_net=reactant1_net, + rxn_net=rxn_net, + reactant2_net=reactant2_net, + bb_emb=bb_emb, + rxn_template=rxn_template, + n_bits=n_bits, + max_step=max_step, + rt1_index=i, + ) + + similarities_ = np.array( + tanimoto_similarity(z_target, [node.smiles for node in tree.chemicals]) + ) + max_simi_idx = np.where(similarities_ == np.max(similarities_))[0][0] similarities.append(np.max(similarities_)) smiles.append(tree.chemicals[max_simi_idx].smiles) @@ -1019,54 +558,19 @@ def synthetic_tree_decoder_multireactant(z_target, acts.append(act) max_simi_idx = np.where(similarities == np.max(similarities))[0][0] - similarity = similarities[max_simi_idx] - tree = trees[max_simi_idx] - smi = smiles[max_simi_idx] - act = acts[max_simi_idx] + similarity = similarities[max_simi_idx] + tree = trees[max_simi_idx] + smi = smiles[max_simi_idx] + act = acts[max_simi_idx] return smi, similarity, tree, act -def fp_embedding(smi, _radius=2, _nBits=4096): - """ - General function for building variable-size & -radius Morgan fingerprints. - - Args: - smi (str): The SMILES to encode. - _radius (int, optional): Morgan fingerprint radius. Defaults to 2. - _nBits (int, optional): Morgan fingerprint length. Defaults to 4096. - - Returns: - np.ndarray: A Morgan fingerprint generated using the specified parameters. - """ - if smi is None: - return np.zeros(_nBits).reshape((-1, )).tolist() - else: - mol = Chem.MolFromSmiles(smi) - features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) - features = np.zeros((1,)) - DataStructs.ConvertToNumpyArray(features_vec, features) - return features.reshape((-1, )).tolist() - -def fp_4096(smi): - return fp_embedding(smi, _radius=2, _nBits=4096) - -def fp_2048(smi): - return fp_embedding(smi, _radius=2, _nBits=2048) - -def fp_1024(smi): - return fp_embedding(smi, _radius=2, _nBits=1024) - -def fp_512(smi): - return fp_embedding(smi, _radius=2, _nBits=512) - -def fp_256(smi): - return fp_embedding(smi, _radius=2, _nBits=256) def rdkit2d_embedding(smi): # define the RDKit 2D descriptors conversion function - rdkit2d = MolConvert(src = 'SMILES', dst = 'RDKit2D') + rdkit2d = MolConvert(src="SMILES", dst="RDKit2D") if smi is None: - return np.zeros(200).reshape((-1, )).tolist() + return np.zeros(200).reshape((-1,)).tolist() else: - return rdkit2d(smi).tolist() \ No newline at end of file + return rdkit2d(smi).tolist() diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 10fa6cda..503f3131 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -4,13 +4,14 @@ from typing import Iterator, Union import numpy as np from scipy import sparse -from dgllife.model import load_pretrained from tdc.chem_utils import MolConvert from sklearn.preprocessing import OneHotEncoder from syn_net.utils.data_utils import Reaction, SyntheticTree from syn_net.utils.predict_utils import (can_react, get_action_mask, get_reaction_mask, mol_fp, - get_mol_embedding) + ) +from syn_net.encoders.gins import get_mol_embedding + from pathlib import Path from rdkit import Chem import logging @@ -36,6 +37,7 @@ def rdkit2d_embedding(smi): import functools @functools.lru_cache(maxsize=1) def _fetch_gin_pretrained_model(model_name: str): + from dgllife.model import load_pretrained """Get a GIN pretrained model to use for creating molecular embeddings""" device = 'cpu' model = load_pretrained(model_name).to(device) @@ -68,8 +70,7 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, sparse.csc_matrix: Node states pulled from the tree. sparse.csc_matrix: Actions pulled from the tree. """ - # define model to use for molecular embedding - model = _fetch_gin_pretrained_model("gin_supervised_contextpred") + states = [] steps = [] @@ -83,10 +84,15 @@ def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, d_mol = OUTPUT_EMBEDDINGS_DIMS[output_embedding] + # Do we need a gin embedder? + if output_embedding == "gin" or target_embedding == "gin": + model = _fetch_gin_pretrained_model("gin_supervised_contextpred") + # Compute embedding of target molecule, i.e. the root of the synthetic tree if target_embedding == 'fp': target = mol_fp(st.root.smiles, radius, nBits).tolist() elif target_embedding == 'gin': + # define model to use for molecular embedding target = get_mol_embedding(st.root.smiles, model=model).tolist() else: raise ValueError('Target embedding only supports fp and gin.') diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index 95debcc9..6ed4f0e4 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -12,7 +12,7 @@ from scipy import sparse from tqdm import tqdm -from syn_net.utils.predict_utils import get_mol_embedding +from syn_net.encoders.gins import get_mol_embedding from syn_net.utils.prep_utils import organize, synthetic_tree_generator, prep_data from syn_net.utils.data_utils import SyntheticTreeSet, Reaction, ReactionSet @@ -198,17 +198,17 @@ def test_dataprep(self): def _compare_to_reference(network_type: str): X = sparse.load_npz(f"{main_dir}X_{network_type}_train.npz") y = sparse.load_npz(f"{main_dir}y_{network_type}_train.npz") - + Xref = sparse.load_npz(f"{ref_dir}X_{network_type}_train.npz") yref = sparse.load_npz(f"{ref_dir}y_{network_type}_train.npz") self.assertEqual(X.toarray().all(), Xref.toarray().all(),msg=f"{network_type=}") - self.assertEqual(y.toarray().all(), yref.toarray().all(),msg=f"{network_type=}") + self.assertEqual(y.toarray().all(), yref.toarray().all(),msg=f"{network_type=}") for network in ["act", "rt1", "rxn", "rt2"]: _compare_to_reference(network) - + def test_bb_emb(self): """ @@ -242,3 +242,7 @@ def test_bb_emb(self): embeddings_ref = np.load(f"{ref_dir}building_blocks_emb.npy") self.assertEqual(embeddings.all(), embeddings_ref.all()) + + +if __name__=="__main__": + TestDataPrep() \ No newline at end of file diff --git a/tests/test_Predict.py b/tests/test_Predict.py index ada21154..1aaa564a 100644 --- a/tests/test_Predict.py +++ b/tests/test_Predict.py @@ -10,10 +10,9 @@ from syn_net.utils.predict_utils import ( synthetic_tree_decoder_multireactant, mol_fp, - load_modules_from_checkpoint, ) from syn_net.utils.data_utils import SyntheticTreeSet, ReactionSet - +from syn_net.models.chkpt_loader import load_modules_from_checkpoint TEST_DIR = Path(__file__).parent From e6c51199b21d3a4706c4924e3cafd02787ac9ebc Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 10:40:58 -0400 Subject: [PATCH 083/302] rename --- src/syn_net/{encoders => encoding}/distances.py | 2 +- src/syn_net/{encoders => encoding}/fingerprints.py | 0 src/syn_net/{encoders => encoding}/gins.py | 0 src/syn_net/{encoders => encoding}/utils.py | 0 src/syn_net/utils/predict_utils.py | 12 ++++++------ src/syn_net/utils/prep_utils.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) rename src/syn_net/{encoders => encoding}/distances.py (97%) rename src/syn_net/{encoders => encoding}/fingerprints.py (100%) rename src/syn_net/{encoders => encoding}/gins.py (100%) rename src/syn_net/{encoders => encoding}/utils.py (100%) diff --git a/src/syn_net/encoders/distances.py b/src/syn_net/encoding/distances.py similarity index 97% rename from src/syn_net/encoders/distances.py rename to src/syn_net/encoding/distances.py index 7abbeac0..55c6cf36 100644 --- a/src/syn_net/encoders/distances.py +++ b/src/syn_net/encoding/distances.py @@ -1,5 +1,5 @@ import numpy as np -from syn_net.encoders.fingerprints import mol_fp +from syn_net.encoding.fingerprints import mol_fp def cosine_distance(v1, v2, eps=1e-15): """Computes the cosine similarity between two vectors. diff --git a/src/syn_net/encoders/fingerprints.py b/src/syn_net/encoding/fingerprints.py similarity index 100% rename from src/syn_net/encoders/fingerprints.py rename to src/syn_net/encoding/fingerprints.py diff --git a/src/syn_net/encoders/gins.py b/src/syn_net/encoding/gins.py similarity index 100% rename from src/syn_net/encoders/gins.py rename to src/syn_net/encoding/gins.py diff --git a/src/syn_net/encoders/utils.py b/src/syn_net/encoding/utils.py similarity index 100% rename from src/syn_net/encoders/utils.py rename to src/syn_net/encoding/utils.py diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 9ef65c1a..55763c6e 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -10,11 +10,10 @@ import torch from rdkit import Chem from sklearn.neighbors import BallTree -from syn_net.encoders.distances import cosine_distance, tanimoto_similarity -from syn_net.encoders.fingerprints import mol_fp -from syn_net.encoders.utils import one_hot_encoder +from syn_net.encoding.distances import cosine_distance, tanimoto_similarity +from syn_net.encoding.fingerprints import mol_fp +from syn_net.encoding.utils import one_hot_encoder from syn_net.utils.data_utils import Reaction, SyntheticTree -from tdc.chem_utils import MolConvert # create a random seed for NumPy np.random.seed(6) @@ -567,10 +566,11 @@ def synthetic_tree_decoder_multireactant( def rdkit2d_embedding(smi): - # define the RDKit 2D descriptors conversion function - rdkit2d = MolConvert(src="SMILES", dst="RDKit2D") + from tdc.chem_utils import MolConvert if smi is None: return np.zeros(200).reshape((-1,)).tolist() else: + # define the RDKit 2D descriptor + rdkit2d = MolConvert(src="SMILES", dst="RDKit2D") return rdkit2d(smi).tolist() diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 503f3131..9811dd27 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -4,13 +4,12 @@ from typing import Iterator, Union import numpy as np from scipy import sparse -from tdc.chem_utils import MolConvert from sklearn.preprocessing import OneHotEncoder from syn_net.utils.data_utils import Reaction, SyntheticTree from syn_net.utils.predict_utils import (can_react, get_action_mask, get_reaction_mask, mol_fp, ) -from syn_net.encoders.gins import get_mol_embedding +from syn_net.encoding.gins import get_mol_embedding from pathlib import Path from rdkit import Chem @@ -27,6 +26,7 @@ def rdkit2d_embedding(smi): Returns: np.ndarray: A molecular embedding corresponding to the input molecule. """ + from tdc.chem_utils import MolConvert if smi is None: return np.zeros(200).reshape((-1, )) else: From 06d3452eeba9ca4c0f8dc31da6022350329a7299 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 10:42:13 -0400 Subject: [PATCH 084/302] shrink dependencies + removes dgl --- README.md | 6 +- environment.yml | 219 +++++------------------------------------------- 2 files changed, 23 insertions(+), 202 deletions(-) diff --git a/README.md b/README.md index b0563a21..1d1063a7 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ conda env create -f environment.yml Before running any SynNet code, activate the environment and install this package in development mode. This ensures the scripts can find the right files. You can do this by typing: ```bash -source activate synthenv +source activate synnet pip install -e . ``` @@ -121,7 +121,7 @@ In addition to the necessary data, see [Data](#data), we pre-compute an embeddin python scripts/compute_embedding_mp.py \ --feature "fp_256" \ --rxn-template "hb" \ - --ncpu 10 + --ncpu 10 ``` #### Synthesis Planning @@ -162,7 +162,7 @@ python scripts/optimize_ga.py \ -i path/to/population.npy \ --radius 2 --nbits 4096 \ --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk --restart \ - --ncpu 32 + --ncpu 32 ``` Note: the input file indicated by `-i` contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. diff --git a/environment.yml b/environment.yml index 7ad42d28..f9093563 100644 --- a/environment.yml +++ b/environment.yml @@ -1,204 +1,25 @@ -name: synthenv +name: synnet channels: - pytorch - - nvidia + # - dglteam # only needed for gin - conda-forge - - defaults dependencies: - - _libgcc_mutex=0.1 - - _openmp_mutex=4.5 - - blas=1.0 - - boost=1.74.0 - - boost-cpp=1.74.0 - - bzip2=1.0.8 - - ca-certificates=2021.5.30 - - cairo=1.16.0 - - certifi=2021.5.30 - - cudatoolkit=11.1.74 - - cycler=0.10.0 - - fontconfig=2.13.1 - - freetype=2.10.4 - - gettext=0.19.8.1 - - greenlet=1.1.0 - - icu=68.1 - - intel-openmp - - jbig=2.1 - - jpeg=9d - - kiwisolver=1.3.1 - - lcms2=2.12 - - ld_impl_linux-64=2.36.1 - - lerc=2.2.1 - - libdeflate=1.7 - - libffi=3.3 - - libgcc-ng=9.3.0 - - libgfortran-ng=9.3.0 - - libgfortran5=9.3.0 - - libglib=2.68.3 - - libgomp=9.3.0 - - libiconv=1.16 - - libopenblas=0.3.15 - - libpng=1.6.37 - - libstdcxx-ng=9.3.0 - - libtiff=4.3.0 - - libuuid=2.32.1 - # - libuv=1.40.0 - - libwebp-base=1.2.0 - - libxcb=1.13 - - libxml2=2.9.12 - - lz4-c=1.9.3 - - matplotlib-base=3.4.2 - - mkl=2021.3.0 - - mkl-service=2.4.0 - # - mkl_fft=1.3.0 - # - mkl_random=1.2.2 - - ncurses=6.2 - - ninja=1.10.2 - - numpy=1.20.3 - # - numpy-base=1.20.3 - - olefile=0.46 - - openjpeg=2.4.0 - - openssl=1.1.1k - - pandas=1.3.0 - - pcre=8.45 - - pillow=8.3.1 - - pip=21.1.3 - - pixman=0.40.0 - - pthread-stubs=0.4 - - pycairo=1.20.1 - - pyparsing=2.4.7 - - python=3.9.6 - - python-dateutil=2.8.2 - - python_abi=3.9 - - pytorch=1.9.0 - - pytz=2021.1 - - rdkit=2021.03.4 - - readline=8.1 - - reportlab=3.5.68 - - setuptools=49.6.0 - - six=1.16.0 - - sqlalchemy=1.4.21 - - sqlite=3.36.0 - - tk=8.6.10 - - torchaudio=0.9.0 - - torchvision=0.2.2 - - tornado=6.1 - - typing_extensions=3.10.0.0 - - tzdata=2021a - - wheel=0.36.2 - - xorg-kbproto=1.0.7 - - xorg-libice=1.0.10 - - xorg-libsm=1.2.3 - - xorg-libx11=1.7.2 - - xorg-libxau=1.0.9 - - xorg-libxdmcp=1.1.3 - - xorg-libxext=1.3.4 - - xorg-libxrender=0.9.10 - - xorg-renderproto=0.11.1 - - xorg-xextproto=7.3.0 - - xorg-xproto=7.0.31 - - xz=5.2.5 - - zlib=1.2.11 - - zstd=1.5.0 + - python=3.9.* + - pytorch::torchvision + - pytorch::pytorch=1.9.* + - pytorch-lightning + - rdkit=2021.03.* + # - dglteam::dgl-cuda11.1 # only needed for gin + - scikit-learn>=1.1.* + - ipykernel=6.15.* + - nb_conda_kernels + - black=22.6.* + - black-jupyter=22.6.* + - isort=5.10.* + - pip - pip: - - absl-py==0.13.0 - - aiohttp==3.7.4.post0 - - anyio==3.3.0 - - argon2-cffi==20.1.0 - - async-generator==1.10 - - async-timeout==3.0.1 - - attrs==21.2.0 - - babel==2.9.1 - - backcall==0.2.0 - - bleach==4.0.0 - - cachetools==4.2.2 - - cffi==1.14.6 - - chardet==4.0.0 - - charset-normalizer==2.0.3 - - cloudpickle==1.6.0 - - debugpy==1.4.1 - - decorator==4.4.2 - - defusedxml==0.7.1 - - dgl-cu110==0.6.1 - - dgllife==0.2.8 - - dill==0.3.4 - - entrypoints==0.3 - - fsspec==2021.7.0 - - future==0.18.2 - - fuzzywuzzy==0.18.0 - - google-auth==1.33.1 - - google-auth-oauthlib==0.4.4 - - grpcio==1.38.1 - - hyperopt==0.2.5 - - idna==3.2 - - ipdb==0.13.9 - - ipykernel==6.1.0 - - ipython==7.25.0 - - ipython-genutils==0.2.0 - - jedi==0.18.0 - - jinja2==3.0.1 - - joblib==1.0.1 - - json5==0.9.6 - - jsonschema==3.2.0 - - jupyter-client==6.1.12 - - jupyter-core==4.7.1 - - jupyter-server==1.10.2 - - jupyterlab==3.1.6 - - jupyterlab-pygments==0.1.2 - - jupyterlab-server==2.7.0 - - markdown==3.3.4 - - markupsafe==2.0.1 - - matplotlib-inline==0.1.2 - - mistune==0.8.4 - - multidict==5.1.0 - - nbclassic==0.3.1 - - nbclient==0.5.3 - - nbconvert==6.1.0 - - nbformat==5.1.3 - - nest-asyncio==1.5.1 - - networkx==2.5.1 - - notebook==6.4.3 - - oauthlib==3.1.1 - - packaging==21.0 - - pandocfilters==1.4.3 - - parso==0.8.2 - - pexpect==4.8.0 - - pickleshare==0.7.5 - - prometheus-client==0.11.0 - - prompt-toolkit==3.0.19 - - protobuf==3.17.3 - - ptyprocess==0.7.0 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - pycparser==2.20 - - pydeprecate==0.3.0 - - pygments==2.9.0 - - pyrsistent==0.18.0 - - pytdc==0.2.0 - - pytorch-lightning==1.3.8 - - pyyaml==5.4.1 - - pyzmq==22.2.1 - - requests==2.26.0 - - requests-oauthlib==1.3.0 - - requests-unixsocket==0.2.0 - - rsa==4.7.2 - - scikit-learn==0.24.2 - - scipy==1.7.0 - - seaborn==0.11.1 - - send2trash==1.8.0 - - shutup==0.1.1 - - sniffio==1.2.0 - - tensorboard==2.4.1 - - tensorboard-plugin-wit==1.8.0 - - terminado==0.11.0 - - testpath==0.5.0 - - threadpoolctl==2.2.0 - - toml==0.10.2 - - torchmetrics==0.4.1 - - tqdm==4.61.2 - - traitlets==5.0.5 - - urllib3==1.26.6 - - wcwidth==0.2.5 - - webencodings==0.5.1 - - websocket-client==1.2.0 - - werkzeug==2.0.1 - - yarl==1.6.3 + - setuptools==59.5.0 # https://github.com/pytorch/pytorch/issues/69894 +# - dgllife # only needed fro gin, will force scikit-learn < 1.0 + - pathos + - rich + - pyyaml From 2c41d95284a4c885d7d6d20d337a85f7406da3c4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 10:54:46 -0400 Subject: [PATCH 085/302] delete duplicate code (same as in `prep_utils.py`) --- src/syn_net/utils/predict_utils.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 55763c6e..03dd5844 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -563,14 +563,3 @@ def synthetic_tree_decoder_multireactant( act = acts[max_simi_idx] return smi, similarity, tree, act - - -def rdkit2d_embedding(smi): - from tdc.chem_utils import MolConvert - - if smi is None: - return np.zeros(200).reshape((-1,)).tolist() - else: - # define the RDKit 2D descriptor - rdkit2d = MolConvert(src="SMILES", dst="RDKit2D") - return rdkit2d(smi).tolist() From 56890e9a35aa23df4cc1b2b301857a3935489ab5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:07:12 -0400 Subject: [PATCH 086/302] fix imports --- tests/test_DataPreparation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_DataPreparation.py b/tests/test_DataPreparation.py index 6ed4f0e4..27342fc4 100644 --- a/tests/test_DataPreparation.py +++ b/tests/test_DataPreparation.py @@ -12,7 +12,7 @@ from scipy import sparse from tqdm import tqdm -from syn_net.encoders.gins import get_mol_embedding +from syn_net.encoding.gins import get_mol_embedding from syn_net.utils.prep_utils import organize, synthetic_tree_generator, prep_data from syn_net.utils.data_utils import SyntheticTreeSet, Reaction, ReactionSet From b8ec9b8e37db4cb5026a92bc40ebaebecd4d9b1a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:18:38 -0400 Subject: [PATCH 087/302] fix imports --- scripts/compute_embedding_mp.py | 9 ++++++--- scripts/predict_multireactant_mp.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/scripts/compute_embedding_mp.py b/scripts/compute_embedding_mp.py index 5e7d33ee..487406e1 100644 --- a/scripts/compute_embedding_mp.py +++ b/scripts/compute_embedding_mp.py @@ -11,19 +11,22 @@ from syn_net.MolEmbedder import MolEmbedder from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_PREPROCESS_DIR -from syn_net.utils.predict_utils import fp_256, fp_512, fp_1024, fp_2048, fp_4096, mol_embedding, rdkit2d_embedding +from syn_net.encoding.fingerprints import fp_256, fp_512, fp_1024, fp_2048, fp_4096 +# from syn_net.encoding.gins import mol_embedding +# from syn_net.utils.prep_utils import rdkit2d_embedding + logger = logging.getLogger(__file__) FUNCTIONS = { - "gin": mol_embedding, + # "gin": mol_embedding, "fp_4096": fp_4096, "fp_2048": fp_2048, "fp_1024": fp_1024, "fp_512": fp_512, "fp_256": fp_256, - "rdkit2d": rdkit2d_embedding, + # "rdkit2d": rdkit2d_embedding, } def _load_building_blocks(file: Path) -> list[str]: diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 3e70cc28..ff9ebc53 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -6,12 +6,12 @@ import numpy as np import pandas as pd - from syn_net.config import (CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, DATA_PREPARED_DIR, DATA_PREPROCESS_DIR, DATA_RESULT_DIR) +from syn_net.models.chkpt_loader import load_modules_from_checkpoint from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from syn_net.utils.predict_utils import (load_modules_from_checkpoint, mol_fp, +from syn_net.utils.predict_utils import (mol_fp, synthetic_tree_decoder_multireactant) Path(DATA_RESULT_DIR).mkdir(exist_ok=True) @@ -113,7 +113,7 @@ def func(smiles: str): smi = None similarity = 0 tree = None - + return smi, similarity, tree @@ -148,7 +148,7 @@ def func(smiles: str): featurize = args.featurize radius = args.radius ncpu = args.ncpu - param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" + param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" # Load data ... # ... query molecules (i.e. molecules to decode) @@ -195,13 +195,13 @@ def func(smiles: str): # Save to local dir output_dir = DATA_RESULT_DIR if args.output_dir is None else args.output_dir print('Saving results to {output_dir} ...') - df = pd.DataFrame({'query SMILES' : smiles_queries, - 'decode SMILES': smis_decoded, + df = pd.DataFrame({'query SMILES' : smiles_queries, + 'decode SMILES': smis_decoded, 'similarity' : similarities}) - df.to_csv(f'{output_dir}/decode_result_{args.data}.csv.gz', - compression='gzip', + df.to_csv(f'{output_dir}/decode_result_{args.data}.csv.gz', + compression='gzip', index=False,) - + synthetic_tree_set = SyntheticTreeSet(sts=trees) synthetic_tree_set.save(f'{output_dir}/decoded_st_{args.data}.json.gz') From a26a926e3d371bd6331b188c01eac2e34cc4b620 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:18:58 -0400 Subject: [PATCH 088/302] add `PyTDC` as dependency --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index f9093563..6858de93 100644 --- a/environment.yml +++ b/environment.yml @@ -10,6 +10,7 @@ dependencies: - pytorch-lightning - rdkit=2021.03.* # - dglteam::dgl-cuda11.1 # only needed for gin + - pytdc - scikit-learn>=1.1.* - ipykernel=6.15.* - nb_conda_kernels From 183b39489487ead170ffaa9278acbd385ca86c91 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 11:20:01 -0400 Subject: [PATCH 089/302] exclude file with `dgl`dependency from package --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 958cf08d..f68f64da 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,6 @@ "Operating System :: OS Independent", ], package_dir={"": "src"}, - packages=setuptools.find_packages(where="src"), + packages=setuptools.find_packages(where="src",exclude=["src/syn_net/encoding/gins.py"]), python_requires=">=3.9", ) \ No newline at end of file From 41732e6d65d0161d6985f8425acc2b81a98b76b4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 14:57:39 -0400 Subject: [PATCH 090/302] add type hints for attributes of `Reaction` --- src/syn_net/utils/data_utils.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 84bc3ec1..ff3107fb 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -9,12 +9,13 @@ """ import gzip import json +from typing import Any, Tuple, Union + import pandas as pd +from rdkit import Chem +from rdkit.Chem import AllChem, Draw, rdChemReactions from tqdm import tqdm -import rdkit.Chem as Chem -from rdkit.Chem import Draw -from rdkit.Chem import AllChem -from rdkit.Chem import rdChemReactions + # the definition of reaction classes below class Reaction: @@ -27,6 +28,19 @@ class Reaction: smiles: (str): A reaction SMILES string that macthes the SMARTS pattern. reference (str): Reference information for the reaction. """ + smirks: str # SMARTS pattern + rxn: Chem.rdChemReactions.ChemicalReaction + num_reactant: int + num_agent: int + num_product: int + reactant_template: Tuple[str,str] + product_template: str + agent_templat: str + available_reactants: list[Union[str,Chem.Mol]] + rxnname: str + smiles: Any + reference: Any + def __init__(self, template=None, rxnname=None, smiles=None, reference=None): if template is not None: From aa5c24f5745e0bc989a123981a0822337c71173a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 15:03:34 -0400 Subject: [PATCH 091/302] store `rxn` instead of recomputing it in `Reaction` --- src/syn_net/utils/data_utils.py | 89 +++++++++------------------------ 1 file changed, 23 insertions(+), 66 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index ff3107fb..49c1c2a1 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -51,13 +51,14 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): self.reference = reference # compute a few additional attributes - rxn = AllChem.ReactionFromSmarts(self.smirks) - rdChemReactions.ChemicalReaction.Initialize(rxn) - self.num_reactant = rxn.GetNumReactantTemplates() - if self.num_reactant == 0 or self.num_reactant > 2: - raise ValueError('This reaction is neither uni- nor bi-molecular.') - self.num_agent = rxn.GetNumAgentTemplates() - self.num_product = rxn.GetNumProductTemplates() + self.rxn = self.__init_reaction(self.smirks) + + # Extract number of ... + self.num_reactant = self.rxn.GetNumReactantTemplates() + if self.num_reactant not in (1,2): + raise ValueError('Reaction is neither uni- nor bi-molecular.') + self.num_agent = self.rxn.GetNumAgentTemplates() + self.num_product = self.rxn.GetNumProductTemplates() if self.num_reactant == 1: self.reactant_template = list((self.smirks.split('>')[0], )) else: @@ -69,6 +70,12 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): else: self.smirks = None + def __init_reaction(self,smirks: str) -> Chem.rdChemReactions.ChemicalReaction: + """Initializes a reaction by converting the SMARTS-pattern to an `rdkit` object.""" + rxn = AllChem.ReactionFromSmarts(smirks) + rdChemReactions.ChemicalReaction.Initialize(rxn) + return rxn + def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, product_template, agent_template, available_reactants, rxnname, smiles, reference): """ @@ -125,59 +132,20 @@ def visualize(self, name='./reaction1_highlight.o.png'): del rxn return name - def is_reactant(self, smi): - """ - A function that checks if a molecule is a reactant of the reaction - defined by the `Reaction` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is a reactant of the reaction. - """ - rxn = self.get_rxnobj() + def is_reactant(self, smi: Union[str,Chem.Molecule]) -> bool: + """Checks if `smi` is a reactant of this reaction.""" smi = self.get_mol(smi) - result = rxn.IsMoleculeReactant(smi) - del rxn - return result + return self.rxn.IsMoleculeReactant(smi) - def is_agent(self, smi): - """ - A function that checks if a molecule is an agent in the reaction defined - by the `Reaction` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is an agent in the reaction. - """ - rxn = self.get_rxnobj() + def is_agent(self, smi: Union[str,Chem.Molecule]) -> bool: + """Checks if `smi` is an agent of this reaction.""" smi = self.get_mol(smi) - result = rxn.IsMoleculeAgent(smi) - del rxn - return result + return self.rxn.IsMoleculeAgent(smi) def is_product(self, smi): - """ - A function that checks if a molecule is the product in the reaction defined - by the `Reaction` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is the product in the reaction. - """ - rxn = self.get_rxnobj() + """Checks if `smi` is a product of this reaction.""" smi = self.get_mol(smi) - result = rxn.IsMoleculeProduct(smi) - del rxn - return result + return self.rxn.IsMoleculeProduct(smi) def is_reactant_first(self, smi): """ @@ -228,17 +196,6 @@ def get_smirks(self): """ return self.smirks - def get_rxnobj(self): - """ - A function that returns the RDKit Reaction object. - - Returns: - rxn (rdChem.Reactions.ChemicalReaction): RDKit reaction object. - """ - rxn = AllChem.ReactionFromSmarts(self.smirks) - rdChemReactions.ChemicalReaction.Initialize(rxn) - return rxn - def get_reactant_template(self, ind=0): """ A function that returns the SMARTS pattern which represents the specified @@ -273,7 +230,7 @@ def run_reaction(self, reactants, keep_main=True): Returns: uniqps (str): SMILES string representing the product. """ - rxn = self.get_rxnobj() + rxn = self.rxn if self.num_reactant == 1: From abe19685cf91dbd93bd7fb91d40ea61e874f24be Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 15:04:26 -0400 Subject: [PATCH 092/302] shorten comments, add type hint --- src/syn_net/utils/data_utils.py | 60 +++++++++++++++------------------ 1 file changed, 27 insertions(+), 33 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 49c1c2a1..5563055a 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -59,14 +59,16 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): raise ValueError('Reaction is neither uni- nor bi-molecular.') self.num_agent = self.rxn.GetNumAgentTemplates() self.num_product = self.rxn.GetNumProductTemplates() + + # Extract reactants, agents, products + reactants, agents, products = self.smirks.split(">") + if self.num_reactant == 1: - self.reactant_template = list((self.smirks.split('>')[0], )) + self.reactant_template = list((reactants, )) else: - self.reactant_template = list((self.smirks.split('>')[0].split('.')[0], self.smirks.split('>')[0].split('.')[1])) - self.product_template = self.smirks.split('>')[2] - self.agent_template = self.smirks.split('>')[1] - - del rxn + self.reactant_template = list(reactants.split(".")) + self.product_template = products + self.agent_template = agents else: self.smirks = None @@ -109,8 +111,7 @@ def get_mol(self, smi): elif isinstance(smi, Chem.Mol): return smi else: - raise TypeError('The input should be either a SMILES string or an ' - 'RDKit.Chem.Mol object.') + raise TypeError(f"f{type(smi)} not supported, only `str` or `RDKit.Chem.Mol`") def visualize(self, name='./reaction1_highlight.o.png'): """ @@ -147,7 +148,7 @@ def is_product(self, smi): smi = self.get_mol(smi) return self.rxn.IsMoleculeProduct(smi) - def is_reactant_first(self, smi): + def is_reactant_first(self, smi: Union[str, Chem.Mol]) -> bool: """ A function that checks if a molecule is the first reactant in the reaction defined by the `Reaction` object, where the order of the reactants is @@ -167,7 +168,7 @@ def is_reactant_first(self, smi): else: return False - def is_reactant_second(self, smi): + def is_reactant_second(self, smi: Union[str,Chem.Mol]) -> bool: """ A function that checks if a molecule is the second reactant in the reaction defined by the `Reaction` object, where the order of the reactants is @@ -187,13 +188,8 @@ def is_reactant_second(self, smi): else: return False - def get_smirks(self): - """ - A function that returns the SMARTS pattern which represents the reaction. - - Returns: - self.smirks (str): SMARTS pattern representing the reaction. - """ + def get_smirks(self) -> str: + """Returns the SMARTS pattern which represents the reaction.""" return self.smirks def get_reactant_template(self, ind=0): @@ -290,24 +286,23 @@ def run_reaction(self, reactants, keep_main=True): else: return uniqps - def _filter_reactants(self, smi_list): + def _filter_reactants(self, smiles: list[str]) -> Tuple[list[str],list[str]]: """ Filters reactants which do not match the reaction. Args: - smi_list (list): Contains SMILES to search through for matches. - - Raises: - ValueError: Raised if the `Reaction` object does not describe a uni- - or bi-molecular reaction. + smiles: Possible reactants for this reaction. Returns: - tuple: Contains list(s) of SMILES which match either the first + :lists of SMILES which match either the first reactant, or, if applicable, the second reactant. + + Raises: + ValueError: If `self` is not a uni- or bi-molecular reaction. """ if self.num_reactant == 1: # uni-molecular reaction smi_w_patt = [] - for smi in tqdm(smi_list): + for smi in tqdm(smiles): if self.is_reactant_first(smi): smi_w_patt.append(smi) return (smi_w_patt, ) @@ -315,7 +310,7 @@ def _filter_reactants(self, smi_list): elif self.num_reactant == 2: # bi-molecular reaction smi_w_patt1 = [] smi_w_patt2 = [] - for smi in tqdm(smi_list): + for smi in tqdm(smiles): if self.is_reactant_first(smi): smi_w_patt1.append(smi) if self.is_reactant_second(smi): @@ -324,17 +319,16 @@ def _filter_reactants(self, smi_list): else: raise ValueError('This reaction is neither uni- nor bi-molecular.') - def set_available_reactants(self, building_block_list): + def set_available_reactants(self, building_blocks: list[str]): """ - A function that finds the applicable building blocks from a list of - purchasable building blocks. + Finds applicable reactants from a list of building blocks. + Sets `self.available_reactants`. Args: - building_block_list (list): The list of purchasable building blocks, - where building blocks are represented as SMILES strings. + building_blocks: Building blocks as SMILES strings. """ - self.available_reactants = list(self._filter_reactants(building_block_list)) - return None + self.available_reactants = list(self._filter_reactants(building_blocks)) + return self class ReactionSet: From 276436782a672b0ec35e3f1a7f6e9f667ee23087 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 15:08:25 -0400 Subject: [PATCH 093/302] `Chem.Molecule` -> `Chem.Mol` --- src/syn_net/utils/data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 5563055a..a8b71bdb 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -133,12 +133,12 @@ def visualize(self, name='./reaction1_highlight.o.png'): del rxn return name - def is_reactant(self, smi: Union[str,Chem.Molecule]) -> bool: + def is_reactant(self, smi: Union[str,Chem.Mol]) -> bool: """Checks if `smi` is a reactant of this reaction.""" smi = self.get_mol(smi) return self.rxn.IsMoleculeReactant(smi) - def is_agent(self, smi: Union[str,Chem.Molecule]) -> bool: + def is_agent(self, smi: Union[str,Chem.Mol]) -> bool: """Checks if `smi` is an agent of this reaction.""" smi = self.get_mol(smi) return self.rxn.IsMoleculeAgent(smi) From 3dd4e1cc9d4ba01b046cc01d280da072bdf1eef4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 7 Sep 2022 18:56:49 -0400 Subject: [PATCH 094/302] refactor: introduces `BuildingBlockFilter`and helper classes --- INSTRUCTIONS.md | 13 ++- scripts/01-filter-building-blocks.py | 66 +++++++++++ src/syn_net/data_generation/filter_unmatch.py | 43 ------- src/syn_net/data_generation/preprocessing.py | 105 ++++++++++++++++++ src/syn_net/data_generation/process_rxn_mp.py | 72 ------------ src/syn_net/utils/data_utils.py | 44 +++++--- 6 files changed, 204 insertions(+), 139 deletions(-) create mode 100644 scripts/01-filter-building-blocks.py delete mode 100644 src/syn_net/data_generation/filter_unmatch.py create mode 100644 src/syn_net/data_generation/preprocessing.py delete mode 100644 src/syn_net/data_generation/process_rxn_mp.py diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 45d9a4c6..8e22ab84 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -28,17 +28,18 @@ Let's start. In other words, filter out all building blocks that do not match any reaction template. There is no need to keep them, as they cannot act as reactant. In a first step, we match all building blocks with each reaction template. - In a second step, we save a set of all matched building blocks. + In a second step, we save all matched building blocks. ```bash # Match - python scripts/01-process_rxn.py - # Filter - python scripts/02-filter-unmatched.py + python scripts/01-filter-buildingblocks.py \ + --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ + --rxn-template-file "data/assets/reaction-templates/hb.txt" \ + --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" ``` > :bulb: All following steps use this matched building blocks <-> reaction template data. As of now, you still have to specify these parameters again for every script to that it can load the right data. - + 2. Generate *synthetic trees* Herein we generate the data used for training the networks. @@ -49,7 +50,7 @@ Let's start. ```bash # Generate synthetic trees python scripts/03-make_dataset_mp.py - # Filter + # Filter python scripts/04-sample_from_original.py ``` diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py new file mode 100644 index 00000000..ef32f0c9 --- /dev/null +++ b/scripts/01-filter-building-blocks.py @@ -0,0 +1,66 @@ +"""Filter out building blocks that cannot react with any template. +""" +import logging + +from rdkit import RDLogger +from syn_net.data_generation.preprocessing import BuildingBlockFileHandler, BuildingBlockFilter + +RDLogger.DisableLog("rdApp.*") +logger = logging.getLogger(__name__) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxn-templates-file", + type=str, + help="Input file with reaction templates as SMARTS(No header, one per line).", + ) + parser.add_argument( + "--output-file", + type=str, + help="Output file for the filtered building-blocks file.", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + logger.info("Start.") + + # Load assets + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + with open(args.rxn_templates_file, "rt") as f: + rxn_templates = f.readlines() + + bbf = BuildingBlockFilter( + building_blocks=bblocks, + rxn_templates=rxn_templates, + verbose=args.verbose, + processes=args.ncpu, + ) + # Time intensive task... + bbf.filter() + + # ... and save to disk + bblocks_filtered = bbf.building_blocks_filtered + BuildingBlockFileHandler().save(args.output_file, bblocks_filtered) + + logger.info(f"Total number of building blocks {len(bblocks):d}") + logger.info(f"Matched number of building blocks {len(bblocks_filtered):d}") + logger.info( + f"{len(bblocks_filtered)/len(bblocks):.2%} of building blocks applicable for the reaction template." + ) + + logger.info("Completed.") diff --git a/src/syn_net/data_generation/filter_unmatch.py b/src/syn_net/data_generation/filter_unmatch.py deleted file mode 100644 index c8272589..00000000 --- a/src/syn_net/data_generation/filter_unmatch.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Filters out purchasable building blocks which don't match a single template. -""" -from syn_net.utils.data_utils import * -import pandas as pd -from tqdm import tqdm -from pathlib import Path -from syn_net.data_generation.process_rxn_mp import _load_building_blocks # TODO: refactor -from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR -import logging - -logger = logging.getLogger(__name__) - -if __name__ == '__main__': - reaction_template_id = "hb" # "pis" or "hb" - building_blocks_id = "enamine_us-2021-smiles" - - # Load building blocks - building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" - building_blocks = _load_building_blocks(building_blocks_file) - - - # Load genearted reactions (matched reactions <=> building blocks) - reactions_dir = Path(DATA_PREPROCESS_DIR) - reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" - r_set = ReactionSet().load(reactions_dir / reactions_file) - - # Identify all used building blocks (via union of sets) - matched_bblocks = set() - for r in tqdm(r_set.rxns): - for reactants in r.available_reactants: - matched_bblocks = matched_bblocks.union(set(reactants)) - - - logger.info(f'Total number of building blocks {len(building_blocks):d}') - logger.info(f'Matched number of building blocks {len(matched_bblocks):d}') - logger.info(f"{len(matched_bblocks)/len(building_blocks):.2%} of building blocks are applicable for the reaction template set '{reaction_template_id}'.") - - # Save to local disk - df = pd.DataFrame({'SMILES': list(matched_bblocks)}) - outfile = f"{reaction_template_id}-{building_blocks_id}-matched.csv.gz" - file = Path(DATA_PREPROCESS_DIR) / outfile - df.to_csv(file, compression='gzip') diff --git a/src/syn_net/data_generation/preprocessing.py b/src/syn_net/data_generation/preprocessing.py new file mode 100644 index 00000000..bcd79512 --- /dev/null +++ b/src/syn_net/data_generation/preprocessing.py @@ -0,0 +1,105 @@ +from tqdm import tqdm + +from syn_net.utils.data_utils import Reaction + + +class BuildingBlockFilter: + """Filter building blocks.""" + + building_blocks: list[str] + building_blocks_filtered: list[str] + rxn_templates: list[str] + rxns: list[Reaction] + rxns_initialised: bool + + def __init__( + self, + *, + building_blocks: list[str], + rxn_templates: list[str], + processes: int = 1, + verbose: bool = False + ) -> None: + self.building_blocks = building_blocks + self.rxn_templates = rxn_templates + + # Init reactions + self.rxns = [Reaction(template=template.strip()) for template in self.rxn_templates] + # Init other stuff + self.processes = processes + self.verbose = verbose + self.rxns_initialised = False + + def _match_mp(self): + from functools import partial + + from pathos import multiprocessing as mp + + def __match(bblocks: list[str], _rxn: Reaction): + return _rxn.set_available_reactants(bblocks) + + func = partial(__match, self.building_blocks) + with mp.Pool(processes=self.processes) as pool: + self.rxns = pool.map(func, self.rxns) + return self + + def _init_rxns_with_reactants(self): + """Initializes a `Reaction` with a list of possible reactants. + + Info: This can take a while for lots of possible reactants.""" + self.rxns = tqdm(self.rxns) if self.verbose else self.rxns + if self.processes == 1: + [rxn.set_available_reactants(self.building_blocks) for rxn in self.rxns] + else: + self._match_mp() + + self.rxns_initialised = True + return self + + def filter(self): + """Filters out building blocks which do not match a reaction template.""" + if not self.rxns_initialised: + self = self._init_rxns_with_reactants() + matched_bblocks = {x for rxn in self.rxns for x in rxn.get_available_reactants} + self.building_blocks_filtered = list(matched_bblocks) + return self + + +from pathlib import Path + + +class BuildingBlockFileHandler: + def _load_csv(self, file: str) -> list[str]: + """Load building blocks as smiles from `*.csv` or `*.csv.gz`.""" + import pandas as pd + + return pd.read_csv(file)["SMILES"].to_list() + + def load(self, file: str) -> list[str]: + """Load building blocks from file.""" + file = Path(file) + if ".csv" in file.suffixes: + return self._load_csv(file) + else: + raise NotImplementedError + + def _save_csv(self, file: Path, building_blocks: list[str]): + """Save building blocks to `*.csv`""" + import pandas as pd + + # remove possible 1 or more extensions, i.e. + # .csv OR .csv.gz --> + file_no_ext = file.parent / file.stem.split(".")[0] + file = (file_no_ext).with_suffix(".csv.gz") + # Save + df = pd.DataFrame({"SMILES": building_blocks}) + df.to_csv(file, compression="gzip") + return None + + def save(self, file: str, building_blocks: list[str]): + """Save building blocks to file.""" + file = Path(file) + if ".csv" in file.suffixes: + self._save_csv(file, building_blocks) + else: + raise NotImplementedError diff --git a/src/syn_net/data_generation/process_rxn_mp.py b/src/syn_net/data_generation/process_rxn_mp.py deleted file mode 100644 index 38368903..00000000 --- a/src/syn_net/data_generation/process_rxn_mp.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -This file processes a set of reaction templates and finds applicable -reactants from a list of purchasable building blocks. - -Usage: - python process_rxn.py -""" -import multiprocessing as mp -from functools import partial -from pathlib import Path -from time import time - -# Silence RDKit loggers (https://github.com/rdkit/rdkit/issues/2683) -from rdkit import RDLogger - -from syn_net.utils.data_utils import Reaction, ReactionSet - -RDLogger.DisableLog("rdApp.*") - - -import pandas as pd - - -def _load_building_blocks(file: Path) -> list[str]: - return pd.read_csv(file)["SMILES"].to_list() - - -def _match_building_blocks_to_rxn(building_blocks: list[str], _rxn: Reaction): - _rxn.set_available_reactants(building_blocks) - return _rxn - - -from syn_net.config import (BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, - REACTION_TEMPLATE_DIR) - -def get_args(): - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--building-blocks-file", type=str, help="Input file with SMILES strings (First row `SMILES`, then one per line).") - return parser.parse_args() - -if __name__ == "__main__": - - args = get_args() - reaction_template_id = "hb" # "pis" or "hb" - building_blocks_id = "enamine_us-2021-smiles" - - # Load building blocks - building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" - building_blocks = _load_building_blocks(building_blocks_file) - - # Load reaction templates and parse - path_to__rxntemplates = Path(REACTION_TEMPLATE_DIR) / f"{reaction_template_id}.txt" - _rxntemplates = [] - for line in open(path_to__rxntemplates, "rt"): - template = line.strip() - rxn = Reaction(template) - _rxntemplates.append(rxn) - - # Filter building blocks on each reaction - t = time() - func = partial(_match_building_blocks_to_rxn, building_blocks) - with mp.Pool(processes=64) as pool: - rxns = pool.map(func, _rxntemplates) - print("Time: ", time() - t, "s") - - # Save data to local disk - r = ReactionSet(rxns) - out_dir = Path(DATA_PREPROCESS_DIR) - out_dir.mkdir(exist_ok=True, parents=True) - out_file = out_dir / f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" - r.save(out_file) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index a8b71bdb..5e5cf684 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -9,7 +9,7 @@ """ import gzip import json -from typing import Any, Tuple, Union +from typing import Any, Optional, Tuple, Union, Set import pandas as pd from rdkit import Chem @@ -36,7 +36,7 @@ class Reaction: reactant_template: Tuple[str,str] product_template: str agent_templat: str - available_reactants: list[Union[str,Chem.Mol]] + available_reactants: Tuple[list[str],Optional[list[str]]] rxnname: str smiles: Any reference: Any @@ -90,12 +90,12 @@ def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, self.reactant_template = list(reactant_template) self.product_template = product_template self.agent_template = agent_template - self.available_reactants = list(available_reactants) + self.available_reactants = list(available_reactants) # TODO: use Tuple[list,list] here self.rxnname = rxnname self.smiles = smiles self.reference = reference - def get_mol(self, smi): + def get_mol(self, smi: Union[str,Chem.Mol]) -> Chem.Mol: """ A internal function that returns an `RDKit.Chem.Mol` object. @@ -111,7 +111,8 @@ def get_mol(self, smi): elif isinstance(smi, Chem.Mol): return smi else: - raise TypeError(f"f{type(smi)} not supported, only `str` or `RDKit.Chem.Mol`") + raise TypeError(f"{type(smi)} not supported, only `str` or `rdkit.Chem.Mol`") + def visualize(self, name='./reaction1_highlight.o.png'): """ @@ -286,7 +287,7 @@ def run_reaction(self, reactants, keep_main=True): else: return uniqps - def _filter_reactants(self, smiles: list[str]) -> Tuple[list[str],list[str]]: + def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list[str],list[str]]: """ Filters reactants which do not match the reaction. @@ -300,26 +301,28 @@ def _filter_reactants(self, smiles: list[str]) -> Tuple[list[str],list[str]]: Raises: ValueError: If `self` is not a uni- or bi-molecular reaction. """ + smiles = tqdm(smiles) if verbose else smiles + if self.num_reactant == 1: # uni-molecular reaction - smi_w_patt = [] - for smi in tqdm(smiles): + reactants_1 = [] + for smi in smiles: if self.is_reactant_first(smi): - smi_w_patt.append(smi) - return (smi_w_patt, ) + reactants_1.append(smi) + return (reactants_1, ) elif self.num_reactant == 2: # bi-molecular reaction - smi_w_patt1 = [] - smi_w_patt2 = [] - for smi in tqdm(smiles): + reactants_1 = [] + reactants_2 = [] + for smi in smiles: if self.is_reactant_first(smi): - smi_w_patt1.append(smi) + reactants_1.append(smi) if self.is_reactant_second(smi): - smi_w_patt2.append(smi) - return (smi_w_patt1, smi_w_patt2) + reactants_2.append(smi) + return (reactants_1, reactants_2) else: raise ValueError('This reaction is neither uni- nor bi-molecular.') - def set_available_reactants(self, building_blocks: list[str]): + def set_available_reactants(self, building_blocks: list[str],verbose: bool=False): """ Finds applicable reactants from a list of building blocks. Sets `self.available_reactants`. @@ -327,9 +330,14 @@ def set_available_reactants(self, building_blocks: list[str]): Args: building_blocks: Building blocks as SMILES strings. """ - self.available_reactants = list(self._filter_reactants(building_blocks)) + self.available_reactants = self._filter_reactants(building_blocks,verbose=verbose) return self + @property + def get_available_reactants(self) -> Set[str]: + return {x for reactants in self.available_reactants for x in reactants} + + class ReactionSet: """ From f61c8f6938e35ae8364fa62a69087ddba1811fd1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 8 Sep 2022 15:43:31 -0400 Subject: [PATCH 095/302] refactor: use OOP for pre-computing embeddings --- INSTRUCTIONS.md | 16 +++++- scripts/02-compute-embeddings.py | 79 ++++++++++++++++++++++++++++ scripts/compute_embedding_mp.py | 68 ------------------------ src/syn_net/MolEmbedder.py | 2 +- src/syn_net/encoding/fingerprints.py | 5 +- 5 files changed, 97 insertions(+), 73 deletions(-) create mode 100644 scripts/02-compute-embeddings.py delete mode 100644 scripts/compute_embedding_mp.py diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 8e22ab84..67d6b735 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -38,9 +38,21 @@ Let's start. --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" ``` - > :bulb: All following steps use this matched building blocks <-> reaction template data. As of now, you still have to specify these parameters again for every script to that it can load the right data. + > :bulb: All following steps use this matched building blocks <-> reaction template data. You have to specify the correct files for every script to that it can load the right data. It can save some time to store these as environment variables. -2. Generate *synthetic trees* +2. Pre-compute embeddings + + We use the embedding space for the building blocks a lot. + Hence, we pre-compute and store the building blocks. + + ```bash + python scripts/02-compute-embeddings.py \ + --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" + --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" \ + ``` + +3. Generate *synthetic trees* Herein we generate the data used for training the networks. The data is generated by randomly selecting building blocks, reaction templates and directives to grow a synthetic tree. diff --git a/scripts/02-compute-embeddings.py b/scripts/02-compute-embeddings.py new file mode 100644 index 00000000..518675f2 --- /dev/null +++ b/scripts/02-compute-embeddings.py @@ -0,0 +1,79 @@ +""" +Computes the molecular embeddings of the purchasable building blocks. + +The embeddings are also referred to as "output embedding". +In the embedding space, a kNN-search will identify the 1st or 2nd reactant. +""" + +import logging + +from syn_net.data_generation.preprocessing import BuildingBlockFileHandler +from syn_net.encoding.fingerprints import fp_256, fp_512, fp_1024, fp_2048, fp_4096 +from syn_net.MolEmbedder import MolEmbedder + +# from syn_net.encoding.gins import mol_embedding +# from syn_net.utils.prep_utils import rdkit2d_embedding + + +logger = logging.getLogger(__file__) + + +FUNCTIONS = { + # "gin": mol_embedding, + "fp_4096": fp_4096, + "fp_2048": fp_2048, + "fp_1024": fp_1024, + "fp_512": fp_512, + "fp_256": fp_256, + # "rdkit2d": rdkit2d_embedding, +} + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxn-templates-file", + type=str, + help="Input file with reaction templates as SMARTS(No header, one per line).", + ) + parser.add_argument( + "--output-file", + type=str, + help="Output file for the computed embeddings file. (*.npy)", + ) + parser.add_argument( + "--featurization-fct", + type=str, + default="fp_256", + choices=FUNCTIONS.keys(), + help="Objective function to optimize", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + + args = get_args() + + # Load building blocks + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + logger.info(f"Successfully read {args.building_blocks_file}.") + logger.info(f"Total number of building blocks: {len(bblocks)}.") + + # Compute embeddings + func = FUNCTIONS[args.featurization_fct] + molembedder = MolEmbedder(processes=args.ncpu).compute_embeddings(func, bblocks) + + # Save? + molembedder.save_precomputed(args.output_file) diff --git a/scripts/compute_embedding_mp.py b/scripts/compute_embedding_mp.py deleted file mode 100644 index 487406e1..00000000 --- a/scripts/compute_embedding_mp.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Computes the molecular embeddings of the purchasable building blocks. - -The embeddings are also referred to as "output embedding". -In the embedding space, a kNN-search will identify the 1st or 2nd reactant. -""" -import logging -from pathlib import Path - -import pandas as pd - -from syn_net.MolEmbedder import MolEmbedder -from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_PREPROCESS_DIR -from syn_net.encoding.fingerprints import fp_256, fp_512, fp_1024, fp_2048, fp_4096 -# from syn_net.encoding.gins import mol_embedding -# from syn_net.utils.prep_utils import rdkit2d_embedding - - -logger = logging.getLogger(__file__) - - -FUNCTIONS = { - # "gin": mol_embedding, - "fp_4096": fp_4096, - "fp_2048": fp_2048, - "fp_1024": fp_1024, - "fp_512": fp_512, - "fp_256": fp_256, - # "rdkit2d": rdkit2d_embedding, -} - -def _load_building_blocks(file: Path) -> list[str]: - return pd.read_csv(file)["SMILES"].to_list() - -def get_args(): - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--building-blocks-file", type=str, help="Input file with SMILES strings (First row `SMILES`, then one per line).") - parser.add_argument("--output-file", type=str, help="Output file for the computed embeddings.") - parser.add_argument("--feature", type=str, default="fp_256", choices=FUNCTIONS.keys(), help="Objective function to optimize") - parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") - # Command line args to be deprecated, only support input/output file in future. - parser.add_argument("--rxn-template", type=str, default="hb", choices=["hb", "pis"], help="Choose from ['hb', 'pis']") - parser.add_argument("--building-blocks-id", type=str, default="enamine_us-2021-smiles") - return parser.parse_args() - -if __name__ == "__main__": - - args = get_args() - - # Load building blocks - if (file := args.building_blocks_file) is None: - # Try to construct filename - file = Path(DATA_PREPROCESS_DIR) / f"{args.rxn_template}-{args.building_blocks_id}-matched.csv.gz" - bblocks = _load_building_blocks(file) - logger.info(f"Successfully read {file}.") - logger.info(f"Total number of building blocks: {len(bblocks)}.") - - # Compute embeddings - func = FUNCTIONS[args.feature] - molembedder = MolEmbedder(processes=args.ncpu).compute_embeddings(func,bblocks) - - # Save? - if (outfile := args.output_file) is None: - # Try to construct filename - outfile = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{args.building_blocks_id}-{args.feature}.npy" - molembedder.save_precomputed(outfile) - diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py index bab38c34..3e4264cc 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/syn_net/MolEmbedder.py @@ -46,7 +46,7 @@ def _save_npy(self, file: str): embeddings = np.asarray(self.embeddings) # assume at least 2d np.save(file, embeddings) - logger.info(f"Successfully saved to {file}.") + logger.info(f"Successfully saved data (shape={embeddings.shape}) to {file}.") return self def save_precomputed(self, file: str): diff --git a/src/syn_net/encoding/fingerprints.py b/src/syn_net/encoding/fingerprints.py index 15872fd3..4c64e223 100644 --- a/src/syn_net/encoding/fingerprints.py +++ b/src/syn_net/encoding/fingerprints.py @@ -1,5 +1,6 @@ import numpy as np -from rdkit import Chem, DataStructs +from rdkit import Chem +from rdkit.Chem import AllChem, DataStructs ## Morgan fingerprints def mol_fp(smi, _radius=2, _nBits=4096): @@ -38,7 +39,7 @@ def fp_embedding(smi, _radius=2, _nBits=4096): return np.zeros(_nBits).reshape((-1, )).tolist() else: mol = Chem.MolFromSmiles(smi) - features_vec = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) + features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) features = np.zeros((1,)) DataStructs.ConvertToNumpyArray(features_vec, features) return features.reshape((-1, )).tolist() From df78417c9ba7b2982a63dcfc1d20029a0ee7d961 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 8 Sep 2022 17:03:41 -0400 Subject: [PATCH 096/302] correct import statement --- src/syn_net/data_generation/make_dataset_mp.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py index 92cf3571..03f354ac 100644 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ b/src/syn_net/data_generation/make_dataset_mp.py @@ -10,8 +10,8 @@ from pathlib import Path from syn_net.data_generation.make_dataset import synthetic_tree_generator from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from syn_net.data_generation.process_rxn_mp import _load_building_blocks # TODO: refactor from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR +from syn_net.data_generation.preprocessing import BuildingBlockFileHandler import logging logger = logging.getLogger(__name__) @@ -25,13 +25,13 @@ def func(_x): if __name__ == '__main__': - reaction_template_id = "hb" # "pis" or "hb" + reaction_template_id = "hb" # "pis" or "hb" building_blocks_id = "enamine_us-2021-smiles" NUM_TREES = 600_000 # Load building blocks building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" - building_blocks = _load_building_blocks(building_blocks_file) + building_blocks = BuildingBlockFileHandler.load(building_blocks_file) # Load genearted reactions (matched reactions <=> building blocks) reactions_dir = Path(DATA_PREPROCESS_DIR) @@ -43,7 +43,7 @@ def func(_x): with mp.Pool(processes=64) as pool: results = pool.map(func, np.arange(NUM_TREES).tolist()) - # Filter out trees that were completed with action="end" + # Filter out trees that were completed with action="end" trees = [r[0] for r in results if r[1] == 3] actions = [r[1] for r in results] From cadf43a136d72b4b227686aecbcba8afadbb77bd Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 9 Sep 2022 17:25:56 -0400 Subject: [PATCH 097/302] cache smiles -> Mol conversion --- src/syn_net/utils/data_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 5e5cf684..c0fcd8cf 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -7,6 +7,7 @@ * `SyntheticTree` * `SyntheticTreeSet` """ +import functools import gzip import json from typing import Any, Optional, Tuple, Union, Set @@ -95,6 +96,7 @@ def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, self.smiles = smiles self.reference = reference + @functools.lru_cache(maxsize=20) def get_mol(self, smi: Union[str,Chem.Mol]) -> Chem.Mol: """ A internal function that returns an `RDKit.Chem.Mol` object. From ed61b2415adbbe7a880146374d1c9cb030785fac Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 9 Sep 2022 17:30:46 -0400 Subject: [PATCH 098/302] refactor: shorten methods & delete simple getter methods --- src/syn_net/utils/data_utils.py | 70 ++++----------------------------- 1 file changed, 8 insertions(+), 62 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index c0fcd8cf..b9cf671a 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -152,70 +152,16 @@ def is_product(self, smi): return self.rxn.IsMoleculeProduct(smi) def is_reactant_first(self, smi: Union[str, Chem.Mol]) -> bool: - """ - A function that checks if a molecule is the first reactant in the reaction - defined by the `Reaction` object, where the order of the reactants is - determined by the SMARTS pattern. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is the first reactant in - the reaction. - """ - smi = self.get_mol(smi) - if smi.HasSubstructMatch(Chem.MolFromSmarts(self.get_reactant_template(0))): - return True - else: - return False + """Check if `smi` is the first reactant in this reaction """ + mol = self.get_mol(smi) + pattern = Chem.MolFromSmarts(self.reactant_template[0]) + return mol.HasSubstructMatch(pattern) def is_reactant_second(self, smi: Union[str,Chem.Mol]) -> bool: - """ - A function that checks if a molecule is the second reactant in the reaction - defined by the `Reaction` object, where the order of the reactants is - determined by the SMARTS pattern. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is the second reactant in - the reaction. - """ - smi = self.get_mol(smi) - if smi.HasSubstructMatch(Chem.MolFromSmarts(self.get_reactant_template(1))): - return True - else: - return False - - def get_smirks(self) -> str: - """Returns the SMARTS pattern which represents the reaction.""" - return self.smirks - - def get_reactant_template(self, ind=0): - """ - A function that returns the SMARTS pattern which represents the specified - reactant. - - Args: - ind (int): The index of the reactant. Defaults to 0. - - Returns: - reactant_template (str): SMARTS pattern representing the reactant. - """ - return self.reactant_template[ind] - - def get_product_template(self): - """ - A function that returns the SMARTS pattern which represents the product. - - Returns: - product_template (str): SMARTS pattern representing the product. - """ - return self.product_template + """Check if `smi` the second reactant in this reaction """ + mol = self.get_mol(smi) + pattern = Chem.MolFromSmarts(self.reactant_template[1]) + return mol.HasSubstructMatch(pattern) def run_reaction(self, reactants, keep_main=True): """ From ab0cc1a9c183a049f32e99133c8830690098ffa9 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 9 Sep 2022 17:32:08 -0400 Subject: [PATCH 099/302] refactor --- src/syn_net/utils/data_utils.py | 85 +++++++++++++-------------------- 1 file changed, 34 insertions(+), 51 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index b9cf671a..51240c95 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -8,6 +8,7 @@ * `SyntheticTreeSet` """ import functools +import itertools import gzip import json from typing import Any, Optional, Tuple, Union, Set @@ -163,78 +164,60 @@ def is_reactant_second(self, smi: Union[str,Chem.Mol]) -> bool: pattern = Chem.MolFromSmarts(self.reactant_template[1]) return mol.HasSubstructMatch(pattern) - def run_reaction(self, reactants, keep_main=True): - """ - A function that transform the reactants into the corresponding product. + def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bool=True) -> Union[str,None]: + """Run this reactions with reactants and return corresponding product. Args: - reactants (list): Contains SMILES strings for the reactants. - keep_main (bool): Indicates whether to return only the main product, - or all possible products. Defaults to True. + reactants (tuple): Contains SMILES strings for the reactants. + keep_main (bool): Return main product only or all possibel products. Defaults to True. Returns: - uniqps (str): SMILES string representing the product. + uniqps: SMILES string representing the product or `None` if not reaction possible """ - rxn = self.rxn + # Input validation. + if not isinstance(reactants, tuple): + raise TypeError(f"Unsupported type '{type(reactants)}' for `reactants`.") + if not len(reactants) in (1,2): + raise ValueError(f"Can only run reactions with 1 or 2 reactants, not {len(reactants)}.") - if self.num_reactant == 1: + rxn = self.rxn # TODO: investigate if this is necessary (if not, delete "delete rxn below") - if isinstance(reactants, (tuple, list)): - if len(reactants) == 1: - r = self.get_mol(reactants[0]) - elif len(reactants) == 2 and reactants[1] is None: - r = self.get_mol(reactants[0]) - else: - return None + # Convert all reactants to `Chem.Mol` + r: Tuple = tuple(self.get_mol(smiles) for smiles in reactants if smiles is not None) - elif isinstance(reactants, (str, Chem.Mol)): - r = self.get_mol(reactants) - else: - raise TypeError('The input of a uni-molecular reaction should ' - 'be a SMILES, an rdkit.Chem.Mol object, or a ' - 'tuple/list of length 1 or 2.') - if not self.is_reactant(r): + if self.num_reactant == 1: + if not self.is_reactant(r[0]): return None - - ps = rxn.RunReactants((r, )) - elif self.num_reactant == 2: - if isinstance(reactants, (tuple, list)) and len(reactants) == 2: - r1 = self.get_mol(reactants[0]) - r2 = self.get_mol(reactants[1]) - else: - raise TypeError('The input of a bi-molecular reaction should ' - 'be a tuple/list of length 2.') - - if self.is_reactant_first(r1) and self.is_reactant_second(r2): + # Match reactant order with reaction template + if self.is_reactant_first(r[0]) and self.is_reactant_second(r[1]): pass - elif self.is_reactant_first(r2) and self.is_reactant_second(r1): - r1, r2 = (r2, r1) - else: + elif self.is_reactant_first(r[1]) and self.is_reactant_second(r[0]): + r = tuple(reversed(r)) + else: # No reaction possible return None - - ps = rxn.RunReactants((r1, r2)) - else: raise ValueError('This reaction is neither uni- nor bi-molecular.') + # Run reaction with rdkit magic + ps = rxn.RunReactants(r) - uniqps = [] - for p in ps: - smi = Chem.MolToSmiles(p[0]) - uniqps.append(smi) + # Filter for unique products (less magic) + # Note: Use chain() to flatten the tuple of tuples + uniqps = list({Chem.MolToSmiles(p) for p in itertools.chain(*ps)}) - uniqps = list(set(uniqps)) - - assert len(uniqps) >= 1 + # Sanity check + if not len(uniqps) >= 1: + raise ValueError("Reaction did not yield any products.") del rxn if keep_main: - return uniqps[0] - else: - return uniqps - + uniqps = uniqps[:1] + # >>> TODO: Always return list[str] (currently depends on "keep_main") + uniqps = uniqps[0] + # <<< ^ delete this line if resolved. + return uniqps def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list[str],list[str]]: """ Filters reactants which do not match the reaction. From d25e0aa2b5c39262319effa933a2ea62a6105024 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Sat, 10 Sep 2022 19:14:41 -0400 Subject: [PATCH 100/302] use list comp --- src/syn_net/utils/data_utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 51240c95..ca12204a 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -468,10 +468,7 @@ def get_state(self): Returns: state (list): A list contains all root node molecules. """ - state = [] - for mol in self.chemicals: - if mol.is_root: - state.append(mol.smiles) + state = [mol for mol in self.chemicals if mol.is_root] return state[::-1] def update(self, action, rxn_id, mol1, mol2, mol_product): From 7d1a3b1424d1eba219c399183d1e5b355e364549 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Sat, 10 Sep 2022 19:14:57 -0400 Subject: [PATCH 101/302] add type hint --- src/syn_net/utils/data_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index ca12204a..c6c2eeeb 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -387,10 +387,10 @@ class SyntheticTree: type (uni- or bi-molecular). """ def __init__(self, tree=None): - self.chemicals = [] - self.reactions = [] + self.chemicals: list[NodeChemical] = [] + self.reactions:list [Reaction] = [] self.root = None - self.depth = 0 + self.depth: float= 0 self.actions = [] self.rxn_id2type = None From 2678cab1842f870dde738c9beaeb5f6632d8549e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Sat, 10 Sep 2022 19:15:23 -0400 Subject: [PATCH 102/302] use list comps --- src/syn_net/utils/data_utils.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index c6c2eeeb..5a0c70d2 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -218,6 +218,7 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo uniqps = uniqps[0] # <<< ^ delete this line if resolved. return uniqps + def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list[str],list[str]]: """ Filters reactants which do not match the reaction. @@ -235,20 +236,13 @@ def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list smiles = tqdm(smiles) if verbose else smiles if self.num_reactant == 1: # uni-molecular reaction - reactants_1 = [] - for smi in smiles: - if self.is_reactant_first(smi): - reactants_1.append(smi) + reactants_1 = [smi for smi in smiles if self.is_reactant_first(smi)] return (reactants_1, ) elif self.num_reactant == 2: # bi-molecular reaction - reactants_1 = [] - reactants_2 = [] - for smi in smiles: - if self.is_reactant_first(smi): - reactants_1.append(smi) - if self.is_reactant_second(smi): - reactants_2.append(smi) + reactants_1 = [smi for smi in smiles if self.is_reactant_first(smi)] + reactants_2 = [smi for smi in smiles if self.is_reactant_second(smi)] + return (reactants_1, reactants_2) else: raise ValueError('This reaction is neither uni- nor bi-molecular.') From feacbecab002c9abe1cb98c1268a9cf2509d70e5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 12 Sep 2022 15:22:34 -0400 Subject: [PATCH 103/302] adds a `ReactionTemplateFileHandler` cls --- src/syn_net/data_generation/preprocessing.py | 28 ++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/syn_net/data_generation/preprocessing.py b/src/syn_net/data_generation/preprocessing.py index bcd79512..d465e873 100644 --- a/src/syn_net/data_generation/preprocessing.py +++ b/src/syn_net/data_generation/preprocessing.py @@ -103,3 +103,31 @@ def save(self, file: str, building_blocks: list[str]): self._save_csv(file, building_blocks) else: raise NotImplementedError + +class ReactionTemplateFileHandler: + + def load(self, file: str) -> list[str]: + """Load reaction templates from file.""" + with open(file, "rt") as f: + rxn_templates = f.readlines() + + if not all([self._validate(t)] for t in rxn_templates): + raise ValueError("Not all reaction templates are valid.") + + return rxn_templates + + def _validate(self, rxn_template: str) -> bool: + """Validate reaction templates. + + Checks if: + - reaction is uni- or bimolecular + - has only a single product + + Note: + - only uses std-lib functions, very basic validation only + """ + reactants, agents, products = rxn_template.split(">") + is_uni_or_bimolecular = len(reactants) == 1 or len(reactants) == 2 + has_single_product = len(products) == 1 + + return is_uni_or_bimolecular and has_single_product From dc258fa37d18bf99d4a2351b0fb8a866ee99e137 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 12 Sep 2022 15:23:24 -0400 Subject: [PATCH 104/302] adds type hints + comments --- src/syn_net/utils/data_utils.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 5a0c70d2..3c1da712 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -208,6 +208,7 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo # Sanity check if not len(uniqps) >= 1: + # TODO: Raise (custom) exception? raise ValueError("Reaction did not yield any products.") del rxn @@ -454,10 +455,9 @@ def get_node_index(self, smi): return node.index return None - def get_state(self): - """ - Returns the state of the synthetic tree. The most recent root node has 0 - as its index. + def get_state(self) -> list[NodeChemical]: + """Get the state of this synthetic tree. + The most recent root node has 0 as its index. Returns: state (list): A list contains all root node molecules. @@ -465,9 +465,8 @@ def get_state(self): state = [mol for mol in self.chemicals if mol.is_root] return state[::-1] - def update(self, action, rxn_id, mol1, mol2, mol_product): - """ - A function that updates a synthetic tree by adding a reaction step. + def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str): + """Update this synthetic tree by adding a reaction step. Args: action (int): Action index, where the indices (0, 1, 2, 3) represent @@ -480,13 +479,11 @@ def update(self, action, rxn_id, mol1, mol2, mol_product): """ self.actions.append(int(action)) - if action == 3: - # End + if action == 3: # End self.root = self.chemicals[-1] self.depth = self.root.depth - elif action == 2: - # Merge with bi-mol rxn + elif action == 2: # Merge (with bi-mol rxn) node_mol1 = self.chemicals[self.get_node_index(mol1)] node_mol2 = self.chemicals[self.get_node_index(mol2)] node_rxn = NodeRxn(rxn_id=rxn_id, @@ -512,8 +509,7 @@ def update(self, action, rxn_id, mol1, mol2, mol_product): self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 1 and mol2 is None: - # Expand with uni-mol rxn + elif action == 1 and mol2 is None: # Expand with uni-mol rxn node_mol1 = self.chemicals[self.get_node_index(mol1)] node_rxn = NodeRxn(rxn_id=rxn_id, rtype=1, @@ -536,8 +532,7 @@ def update(self, action, rxn_id, mol1, mol2, mol_product): self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 1 and mol2 is not None: - # Expand with bi-mol rxn + elif action == 1 and mol2 is not None: # Expand with bi-mol rxn node_mol1 = self.chemicals[self.get_node_index(mol1)] node_mol2 = NodeChemical(smiles=mol2, parent=None, @@ -570,8 +565,7 @@ def update(self, action, rxn_id, mol1, mol2, mol_product): self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 0 and mol2 is None: - # Add with uni-mol rxn + elif action == 0 and mol2 is None: # Add with uni-mol rxn node_mol1 = NodeChemical(smiles=mol1, parent=None, child=None, @@ -600,8 +594,7 @@ def update(self, action, rxn_id, mol1, mol2, mol_product): self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 0 and mol2 is not None: - # Add with bi-mol rxn + elif action == 0 and mol2 is not None: # Add with bi-mol rxn node_mol1 = NodeChemical(smiles=mol1, parent=None, child=None, From 041b9e85da370620a9b81424a5271fa7051a7f0b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 12 Sep 2022 18:51:44 -0400 Subject: [PATCH 105/302] wip: adds `SynTreeGenerator` --- src/syn_net/data_generation/syntrees.py | 243 ++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 src/syn_net/data_generation/syntrees.py diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py new file mode 100644 index 00000000..2330a3a6 --- /dev/null +++ b/src/syn_net/data_generation/syntrees.py @@ -0,0 +1,243 @@ +"""syntrees +""" +from typing import Tuple +from tqdm import tqdm +import numpy as np +from rdkit import Chem + +import logging + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger() + +from syn_net.utils.data_utils import Reaction, SyntheticTree + +class NoReactantAvailableError(Exception): + def __init__(self, message): + # Call the base class constructor with the parameters it needs + super().__init__(message) + +class NoReactionAvailableError(Exception): + def __init__(self, message): + # Call the base class constructor with the parameters it needs + super().__init__(message) + +class NoReactionPossible(Exception): + def __init__(self, message): + # Call the base class constructor with the parameters it needs + super().__init__(message) + + +class SynTreeGenerator: + + building_blocks: list[str] + rxn_templates: list[Reaction] + rxns: dict[int, Reaction] + IDX_RXNS: list + ACTIONS: dict[int, str] = {i: action for i, action in enumerate("add expand merge end".split())} + verbose: bool + logger: logging.Logger + + def __init__( + self, + *, + building_blocks: list[str], + rxn_templates: list[str], + rng=np.random.default_rng(seed=42), + verbose:bool = False, + ) -> None: + self.building_blocks = building_blocks + self.rxn_templates = rxn_templates + self.rxns = [Reaction(template=tmplt) for tmplt in rxn_templates] + self.rng = rng + self.IDX_RXNS = np.arange(len(self.rxns)) + self.processes = 32 + self.verbose = verbose + self.logger = logging.getLogger(__class__.__name__) + + # Time intensive tasks + self._init_rxns_with_reactants() + + def __match_mp(self): + # TODO: refactor / merge with `BuildingBlockFilter` + # TODO: Rename `ReactionSet` -> `ReactionCollection` (same for `SyntheticTreeSet`) + # `Reaction` as "datacls", `*Collection` as cls that encompasses operations on "data"? + # Third class simpyl for file I/O or include somewhere? + from functools import partial + + from pathos import multiprocessing as mp + + def __match(bblocks: list[str], _rxn: Reaction): + return _rxn.set_available_reactants(bblocks) + + func = partial(__match, self.building_blocks) + with mp.Pool(processes=self.processes) as pool: + rxns = pool.map(func, self.rxns) + + self.rxns = rxns + return self + + def _init_rxns_with_reactants(self): + """Initializes a `Reaction` with a list of possible reactants. + + Info: This can take a while for lots of possible reactants.""" + self.rxns = tqdm(self.rxns) if self.verbose else self.rxns + if self.processes == 1: + self.rxns = [rxn.set_available_reactants(self.building_blocks) for rxn in self.rxns] + else: + self.__match_mp() + + self.rxns_initialised = True + return self + + def _sample_molecule(self) -> str: + """Sample a molecule.""" + idx = self.rng.choice(len(self.building_blocks)) + smiles = self.building_blocks[idx] + self.logger.debug(f" Sampled molecule: {smiles}") + return smiles + + def _base_case(self) -> str: + return self._sample_molecule() + + def _find_rxn_candidates(self, smiles: str): + """Find a reaction with `mol` as reactant.""" + mol = Chem.MolFromSmiles(smiles) + rxn_mask = [rxn.is_reactant(mol) for rxn in self.rxns] + if not any(rxn_mask): + raise NoReactionAvailableError(f"No reaction available for: {smiles}.") + return rxn_mask + + def _sample_rxn(self, mask: np.ndarray = None) -> Tuple[Reaction, int]: + """Sample a reaction by index.""" + if mask is None: + irxn_mask = self.IDX_RXNS # + else: + mask = np.asarray(mask) + irxn_mask = self.IDX_RXNS[mask] + idx = self.rng.choice(irxn_mask) + self.logger.debug(f" Sampled reaction with index: {idx} (nreactants: {self.rxns[idx].num_reactant})") + return self.rxns[idx], idx + + def _expand(self, reactant_1: str) -> Tuple[str, str, str, np.int64]: + """Expand a sub-tree from one molecule. + This can result in uni- or bimolecular reaction.""" + + # Identify applicable reactions + rxn_mask = self._find_rxn_candidates(reactant_1) + + # Sample reaction (by index) + rxn, idx_rxn = self._sample_rxn(mask=rxn_mask) + + # Sample 2nd reactant + if rxn.num_reactant == 1: + reactant_2 = None + else: + # Sample a molecule from the available reactants of this reaction + # That is, for a reaction A + B -> C, + # - determine if we have "A" or "B" + # - then sample "B" (or "A") + idx = 1 if rxn.is_reactant_first(reactant_1) else 0 + available_reactants = rxn.available_reactants[idx] + nPossible = len(available_reactants) + if nPossible==0: + raise NoReactantAvailableError("Unable to find two reactants for this bimolecular reaction.") + # TODO: 2 bi-molecular rxn templates have no matching bblock + # TODO: use numpy array to avoid type conversion or stick to sampling idx? + idx = self.rng.choice(nPossible) + reactant_2 = available_reactants[idx] + + + # Run reaction + reactants = (reactant_1, reactant_2) + product = rxn.run_reaction(reactants) + return *reactants, product, idx_rxn + + def _get_action_mask(self, syntree: SyntheticTree): + """Get a mask of possible action for a SyntheticTree""" + # Recall: (Add, Expand, Merge, and End) + canAdd = False + canMerge = False + canExpand = False + canEnd = False + + state = syntree.get_state() + nTrees = len(state) + if nTrees == 0: + canAdd = True + elif nTrees == 1: + canAdd = True + canExpand = True + canEnd = True + elif nTrees == 2: + canExpand = True + canMerge = True # TODO: only if rxn is possible + else: + raise ValueError + + return np.array((canAdd, canExpand, canMerge, canEnd), dtype=bool) + + def generate(self, max_depth: int = 15, retries: int = 3): + """Generate a syntree by random sampling.""" + + # Init + self.logger.debug(f"Starting synthetic tree generation with {max_depth=} ") + syntree = SyntheticTree() + recent_mol = self._sample_molecule() # root of the current tree + + for i in range(max_depth): + self.logger.debug(f"Iteration {i}") + + # State of syntree + state = syntree.get_state() + + # Sample action + p_action = self.rng.random((1, 4)) # (1,4) + action_mask = self._get_action_mask(syntree) # (1,4) + act = np.argmax(p_action * action_mask) # (1,) + action = self.ACTIONS[act] + self.logger.debug(f" Sampled action: {action}") + + if action == "end": + break + elif action == "expand": + for j in range(retries): + self.logger.debug(f" Try {j}") + r1, r2, p, idx_rxn= self._expand(recent_mol) + if p is not None: break + if p is None: + # TODO: move to rxn.run_reaction? + raise NoReactionPossible("No reaction possible.") + + elif action == "add": + mol = self._sample_molecule() + r1, r2, p, idx_rxn = self._expand(mol) + # Expand this subtree: reactant, reaction, reactant2 + + elif action == "merge": + # merge two subtrees: sample reaction, run it. + r1, r2 = [node.smiles for node in state] + # Identify suitable rxn + # TODO: naive implementation + rxn_mask1 = self._find_rxn_candidates(r1) + rxn_mask2 = self._find_rxn_candidates(r2) + rxn_mask = rxn_mask1 and rxn_mask2 + rxn, idx_rxn = self._sample_rxn(mask=rxn_mask) + # Run reaction + p = rxn.run_reaction((r1, r2)) + if p is None: + # TODO: move to rxn.run_reaction? + raise NoReactionPossible("No reaction possible.") + + # Prepare next iteration + self.logger.debug(f" Ran reaction {r1} + {r2} -> {p}") + + recent_mol = p + + # Update tree + assert isinstance(act,(int,np.int64)), type(act) + assert isinstance(r1,str), type(r1) + assert isinstance(r2,(str,type(None))), type(r2) + assert isinstance(p,(str)), type(p) + syntree.update(act, rxn_id=idx_rxn, mol1=r1, mol2=r2, mol_product=p) + return syntree From 0806ad0419a235bc6c9c4e560fd5e00da96ee085 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 13 Sep 2022 14:39:19 -0400 Subject: [PATCH 106/302] finalize utils to plot `SyntheticTree` --- src/syn_net/visualize/drawers.py | 24 +++++++++----- src/syn_net/visualize/visualizer.py | 51 ++++++++++++++++++++++++----- src/syn_net/visualize/writers.py | 2 +- 3 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/syn_net/visualize/drawers.py b/src/syn_net/visualize/drawers.py index c8e04acb..8b6c0511 100644 --- a/src/syn_net/visualize/drawers.py +++ b/src/syn_net/visualize/drawers.py @@ -1,15 +1,24 @@ import uuid from pathlib import Path -from typing import Union +from typing import Optional, Union import rdkit.Chem as Chem from rdkit.Chem import Draw class MolDrawer: - def __init__(self): - self.lookup: dict = None - self.path: Union[None, str] = None + """Draws molecules as images.""" + + def __init__(self, path: Optional[str], subfolder: str = "assets"): + + # Init outfolder + if not (path is not None and Path(path).exists()): + raise NotADirectoryError(path) + self.outfolder = Path(path) / subfolder + self.outfolder.mkdir(exist_ok=1) + + # Placeholder + self.lookup: dict[str, str] = None def _hash(self, smiles: list[str]) -> dict[str, str]: """Hashing for amateurs. @@ -23,13 +32,12 @@ def get_path(self) -> str: def get_molecule_filesnames(self): return self.lookup - def plot(self, smiles: Union[list[str], str], path: str = "./"): - """Plot smiles as 2d molecules and save to `path`.""" + def plot(self, smiles: Union[list[str], str]): + """Plot smiles as 2d molecules and save to `self.path/subfolder/*.svg`.""" self._hash(smiles) - self.path = path for k, v in self.lookup.items(): - fname = str((Path(path) / f"{v}.svg").resolve()) + fname = self.outfolder / f"{v}.svg" mol = Chem.MolFromSmiles(k) # Plot drawer = Draw.rdMolDraw2D.MolDraw2DSVG(300, 150) diff --git a/src/syn_net/visualize/visualizer.py b/src/syn_net/visualize/visualizer.py index e81fa656..ff17882a 100644 --- a/src/syn_net/visualize/visualizer.py +++ b/src/syn_net/visualize/visualizer.py @@ -1,12 +1,16 @@ +from pathlib import Path from typing import Union from syn_net.utils.data_utils import NodeChemical, NodeRxn, SyntheticTree +from syn_net.visualize.drawers import MolDrawer from syn_net.visualize.writers import subgraph class SynTreeVisualizer: actions_taken: dict[int, str] CHEMICALS: dict[str, NodeChemical] + outfolder: Union[str, Path] + version: int ACTIONS = { 0: "Add", @@ -15,7 +19,7 @@ class SynTreeVisualizer: 3: "End", } - def __init__(self, syntree: SyntheticTree): + def __init__(self, syntree: SyntheticTree, outfolder: str = "./syntree-viz/st"): self.syntree = syntree self.actions_taken = { depth: self.ACTIONS[action] for depth, action in enumerate(syntree.actions) @@ -23,15 +27,41 @@ def __init__(self, syntree: SyntheticTree): self.CHEMICALS = {node.smiles: node for node in syntree.chemicals} # Placeholder for images for molecues. - self.path: Union[None, str] = None + self.drawer: Union[MolDrawer, None] self.molecule_filesnames: Union[None, dict[str, str]] = None + + # Folders + outfolder = Path(outfolder) + self.version = self._get_next_version(outfolder) + self.path = outfolder.with_name(outfolder.name + f"_{self.version}") return None - def with_drawings(self, drawer): - """Plot images of the molecules in the nodes.""" - self.path = drawer.get_path() - self.molecule_filesnames = drawer.get_molecule_filesnames() + def _get_next_version(self, dir: str) -> int: + root_dir = Path(dir).parent + name = Path(dir).name + + existing_versions = [] + for d in Path(root_dir).glob(f"{name}_*"): + d = str(d.resolve()) + existing_versions.append(int(d.split("_")[1])) + + if len(existing_versions) == 0: + return 0 + + return max(existing_versions) + 1 + + def with_drawings(self, drawer: MolDrawer): + """Init `MolDrawer` to plot molecules in the nodes.""" + self.path.mkdir(parents=True) + self.drawer = drawer(self.path) + return self + def plot(self): + """Plots molecules via `self.drawer.plot()`.""" + if self.drawer is None: + raise ValueError("Must initialize drawer beforehand.") + self.drawer.plot(self.CHEMICALS) + self.molecule_filesnames = self.drawer.get_molecule_filesnames() return self def _define_chemicals( @@ -40,14 +70,14 @@ def _define_chemicals( ) -> list[str]: chemicals = self.CHEMICALS if chemicals is None else chemicals - if self.path is None or self.molecule_filesnames is None: + if self.drawer.outfolder is None or self.molecule_filesnames is None: raise NotImplementedError("Must provide drawer via `_with_drawings()` before plotting.") out: list[str] = [] for node in chemicals.values(): name = f'"node.smiles"' - name = f'' + name = f'' classdef = self._map_node_type_to_classdef(node) info = f"n{node.index}[{name}]:::{classdef}" out += [info] @@ -81,7 +111,10 @@ def _write_reaction_connectivity( return out def write(self) -> list[str]: - """Write.""" + """Write markdown with mermaid block.""" + # 1. Plot images + self.plot() + # 2. Write markdown (with reference to image files.) rxns: list[NodeRxn] = self.syntree.reactions text = [] diff --git a/src/syn_net/visualize/writers.py b/src/syn_net/visualize/writers.py index 260cec59..550ba813 100644 --- a/src/syn_net/visualize/writers.py +++ b/src/syn_net/visualize/writers.py @@ -56,7 +56,7 @@ def write(self) -> list[str]: class SynTreeWriter: - def __init__(self, prefixer=None, postfixer=None): + def __init__(self, prefixer=PrefixWriter(), postfixer=PostfixWriter()): self.prefixer = prefixer self.postfixer = postfixer self._text: list[str] = None From 402a8328e3a5df940362bd68e12da1851ed930e8 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 13 Sep 2022 16:16:30 -0400 Subject: [PATCH 107/302] refactor logging --- src/syn_net/data_generation/syntrees.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 2330a3a6..15f59435 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -7,8 +7,8 @@ import logging -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger() +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) from syn_net.utils.data_utils import Reaction, SyntheticTree @@ -36,7 +36,6 @@ class SynTreeGenerator: IDX_RXNS: list ACTIONS: dict[int, str] = {i: action for i, action in enumerate("add expand merge end".split())} verbose: bool - logger: logging.Logger def __init__( self, @@ -53,7 +52,8 @@ def __init__( self.IDX_RXNS = np.arange(len(self.rxns)) self.processes = 32 self.verbose = verbose - self.logger = logging.getLogger(__class__.__name__) + if verbose: + logger.setLevel(logging.DEBUG) # Time intensive tasks self._init_rxns_with_reactants() @@ -94,7 +94,7 @@ def _sample_molecule(self) -> str: """Sample a molecule.""" idx = self.rng.choice(len(self.building_blocks)) smiles = self.building_blocks[idx] - self.logger.debug(f" Sampled molecule: {smiles}") + logger.debug(f" Sampled molecule: {smiles}") return smiles def _base_case(self) -> str: @@ -116,7 +116,7 @@ def _sample_rxn(self, mask: np.ndarray = None) -> Tuple[Reaction, int]: mask = np.asarray(mask) irxn_mask = self.IDX_RXNS[mask] idx = self.rng.choice(irxn_mask) - self.logger.debug(f" Sampled reaction with index: {idx} (nreactants: {self.rxns[idx].num_reactant})") + logger.debug(f" Sampled reaction with index: {idx} (nreactants: {self.rxns[idx].num_reactant})") return self.rxns[idx], idx def _expand(self, reactant_1: str) -> Tuple[str, str, str, np.int64]: @@ -181,12 +181,12 @@ def generate(self, max_depth: int = 15, retries: int = 3): """Generate a syntree by random sampling.""" # Init - self.logger.debug(f"Starting synthetic tree generation with {max_depth=} ") + logger.debug(f"Starting synthetic tree generation with {max_depth=} ") syntree = SyntheticTree() recent_mol = self._sample_molecule() # root of the current tree for i in range(max_depth): - self.logger.debug(f"Iteration {i}") + logger.debug(f"Iteration {i}") # State of syntree state = syntree.get_state() @@ -196,13 +196,13 @@ def generate(self, max_depth: int = 15, retries: int = 3): action_mask = self._get_action_mask(syntree) # (1,4) act = np.argmax(p_action * action_mask) # (1,) action = self.ACTIONS[act] - self.logger.debug(f" Sampled action: {action}") + logger.debug(f" Sampled action: {action}") if action == "end": break elif action == "expand": for j in range(retries): - self.logger.debug(f" Try {j}") + logger.debug(f" Try {j}") r1, r2, p, idx_rxn= self._expand(recent_mol) if p is not None: break if p is None: @@ -230,7 +230,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): raise NoReactionPossible("No reaction possible.") # Prepare next iteration - self.logger.debug(f" Ran reaction {r1} + {r2} -> {p}") + logger.debug(f" Ran reaction {r1} + {r2} -> {p}") recent_mol = p @@ -240,4 +240,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): assert isinstance(r2,(str,type(None))), type(r2) assert isinstance(p,(str)), type(p) syntree.update(act, rxn_id=idx_rxn, mol1=r1, mol2=r2, mol_product=p) + logger.debug(f"SynTree updated.") + + logger.debug(f"🙌 SynTree completed.") return syntree From f57f80c048cedaa663e9930821cef70e2aab85f6 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 13 Sep 2022 16:17:31 -0400 Subject: [PATCH 108/302] demo usage of `SynTreeVisualizer` --- src/syn_net/visualize/visualizer.py | 27 +++++- src/syn_net/visualize/writers.py | 1 + tests/assets/syntree-small.json | 139 ++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 1 deletion(-) create mode 100644 tests/assets/syntree-small.json diff --git a/src/syn_net/visualize/visualizer.py b/src/syn_net/visualize/visualizer.py index ff17882a..c629c741 100644 --- a/src/syn_net/visualize/visualizer.py +++ b/src/syn_net/visualize/visualizer.py @@ -142,5 +142,30 @@ def __printer(): return text +def main(): + """Demo syntree visualisation""" + # 1. Load syntree + import json + with open("tests/assets/syntree-small.json","rt") as f: + data = json.load(f) + + st = SyntheticTree() + st.read(data) + + from syn_net.visualize.drawers import MolDrawer + from syn_net.visualize.visualizer import SynTreeVisualizer + from syn_net.visualize.writers import SynTreeWriter + + outpath = Path("./0-figures/syntrees/generation/st") + outpath.mkdir(parents=True, exist_ok=True) + + # 2. Plot & Write mermaid markup diagram + stviz = SynTreeVisualizer(syntree=st, outfolder=outpath).with_drawings(drawer=MolDrawer) + mermaid_txt = stviz.write() + # 3. Write everything to a markdown doc + outfile = stviz.path / "syntree.md" + SynTreeWriter().write(mermaid_txt).to_file(outfile) + return None + if __name__ == "__main__": - pass + main() diff --git a/src/syn_net/visualize/writers.py b/src/syn_net/visualize/writers.py index 550ba813..1832ef54 100644 --- a/src/syn_net/visualize/writers.py +++ b/src/syn_net/visualize/writers.py @@ -72,6 +72,7 @@ def to_file(self, file: str, text: list[str] = None): with open(file, "wt") as f: f.writelines((l.rstrip() + "\n" for l in text)) + return None @property def text(self) -> list[str]: diff --git a/tests/assets/syntree-small.json b/tests/assets/syntree-small.json new file mode 100644 index 00000000..8b865180 --- /dev/null +++ b/tests/assets/syntree-small.json @@ -0,0 +1,139 @@ +{ + "reactions": [ + { + "rxn_id": 12, + "rtype": 2, + "parent": "CCOc1ccc(CCNC(=O)CN2N=NC=C2CN2CCC(C(=O)O)CC2)cc1OCC", + "child": [ + "CCOc1ccc(CCNC(=O)CCl)cc1OCC", + "C#CCN1CCC(C(=O)O)CC1.Cl" + ], + "depth": 0.5, + "index": 0 + }, + { + "rxn_id": 47, + "rtype": 2, + "parent": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C#N)cc1C)C1CC1", + "child": [ + "C=C(C)C(=O)OCCN=C=O", + "Cc1cc(C#N)ccc1NC1CC1" + ], + "depth": 0.5, + "index": 1 + }, + { + "rxn_id": 15, + "rtype": 2, + "parent": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C2=NNC(C3CCN(Cc4cnnn4CC(=O)NCCc4ccc(OCC)c(OCC)c4)CC3)=N2)cc1C)C1CC1", + "child": [ + "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C#N)cc1C)C1CC1", + "CCOc1ccc(CCNC(=O)CN2N=NC=C2CN2CCC(C(=O)O)CC2)cc1OCC" + ], + "depth": 1.5, + "index": 2 + }, + { + "rxn_id": 49, + "rtype": 1, + "parent": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(-c2n[nH]c(C3CCN(Cc4cnnn4CC4=NCCc5cc(OCC)c(OCC)cc54)CC3)n2)cc1C)C1CC1", + "child": [ + "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C2=NNC(C3CCN(Cc4cnnn4CC(=O)NCCc4ccc(OCC)c(OCC)c4)CC3)=N2)cc1C)C1CC1" + ], + "depth": 2.5, + "index": 3 + } + ], + "chemicals": [ + { + "smiles": "CCOc1ccc(CCNC(=O)CCl)cc1OCC", + "parent": 12, + "child": null, + "is_leaf": true, + "is_root": false, + "depth": 0, + "index": 0 + }, + { + "smiles": "C#CCN1CCC(C(=O)O)CC1.Cl", + "parent": 12, + "child": null, + "is_leaf": true, + "is_root": false, + "depth": 0, + "index": 1 + }, + { + "smiles": "CCOc1ccc(CCNC(=O)CN2N=NC=C2CN2CCC(C(=O)O)CC2)cc1OCC", + "parent": 15, + "child": 12, + "is_leaf": false, + "is_root": false, + "depth": 1, + "index": 2 + }, + { + "smiles": "C=C(C)C(=O)OCCN=C=O", + "parent": 47, + "child": null, + "is_leaf": true, + "is_root": false, + "depth": 0, + "index": 3 + }, + { + "smiles": "Cc1cc(C#N)ccc1NC1CC1", + "parent": 47, + "child": null, + "is_leaf": true, + "is_root": false, + "depth": 0, + "index": 4 + }, + { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C#N)cc1C)C1CC1", + "parent": 15, + "child": 47, + "is_leaf": false, + "is_root": false, + "depth": 1, + "index": 5 + }, + { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C2=NNC(C3CCN(Cc4cnnn4CC(=O)NCCc4ccc(OCC)c(OCC)c4)CC3)=N2)cc1C)C1CC1", + "parent": 49, + "child": 15, + "is_leaf": false, + "is_root": false, + "depth": 2.0, + "index": 6 + }, + { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(-c2n[nH]c(C3CCN(Cc4cnnn4CC4=NCCc5cc(OCC)c(OCC)cc54)CC3)n2)cc1C)C1CC1", + "parent": null, + "child": 49, + "is_leaf": false, + "is_root": true, + "depth": 3.0, + "index": 7 + } + ], + "root": { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(-c2n[nH]c(C3CCN(Cc4cnnn4CC4=NCCc5cc(OCC)c(OCC)cc54)CC3)n2)cc1C)C1CC1", + "parent": null, + "child": 49, + "is_leaf": false, + "is_root": true, + "depth": 3.0, + "index": 7 + }, + "depth": 3.0, + "actions": [ + 0, + 0, + 2, + 1, + 3 + ], + "rxn_id2type": null +} \ No newline at end of file From fe89058a6464e0f9e7f060de1d99ba6d1d997fd5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 13 Sep 2022 16:17:56 -0400 Subject: [PATCH 109/302] fix typos --- INSTRUCTIONS.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 67d6b735..4007849b 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -32,10 +32,11 @@ Let's start. ```bash # Match - python scripts/01-filter-buildingblocks.py \ + python scripts/01-filter-building-blocks.py \ --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ - --rxn-template-file "data/assets/reaction-templates/hb.txt" \ - --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ + --verbose ``` > :bulb: All following steps use this matched building blocks <-> reaction template data. You have to specify the correct files for every script to that it can load the right data. It can save some time to store these as environment variables. @@ -49,7 +50,7 @@ Let's start. python scripts/02-compute-embeddings.py \ --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ --rxn-templates-file "data/assets/reaction-templates/hb.txt" - --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" \ + --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" ``` 3. Generate *synthetic trees* From f751b592290ebb40c6a196416c72f3a001bfab26 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 10:20:59 -0400 Subject: [PATCH 110/302] includes reactant info in error msg --- src/syn_net/data_generation/syntrees.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 15f59435..0293cecd 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -207,7 +207,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): if p is not None: break if p is None: # TODO: move to rxn.run_reaction? - raise NoReactionPossible("No reaction possible.") + raise NoReactionPossible(f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}.") elif action == "add": mol = self._sample_molecule() @@ -227,7 +227,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): p = rxn.run_reaction((r1, r2)) if p is None: # TODO: move to rxn.run_reaction? - raise NoReactionPossible("No reaction possible.") + raise NoReactionPossible(f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}.") # Prepare next iteration logger.debug(f" Ran reaction {r1} + {r2} -> {p}") From 12287306881276dc3b4178e8d83bbe14d68ea2ce Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 10:21:37 -0400 Subject: [PATCH 111/302] bug fix: syntree must be updated for "end" action --- src/syn_net/data_generation/syntrees.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 0293cecd..c3fb5e5f 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -199,7 +199,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): logger.debug(f" Sampled action: {action}") if action == "end": - break + r1, r2, p, idx_rxn = None, None, None, None elif action == "expand": for j in range(retries): logger.debug(f" Try {j}") @@ -235,12 +235,14 @@ def generate(self, max_depth: int = 15, retries: int = 3): recent_mol = p # Update tree - assert isinstance(act,(int,np.int64)), type(act) - assert isinstance(r1,str), type(r1) - assert isinstance(r2,(str,type(None))), type(r2) - assert isinstance(p,(str)), type(p) + # assert isinstance(act,(int,np.int64)), type(act) + # assert isinstance(r1,str), type(r1) + # assert isinstance(r2,(str,type(None))), type(r2) + # assert isinstance(p,(str)), type(p) syntree.update(act, rxn_id=idx_rxn, mol1=r1, mol2=r2, mol_product=p) logger.debug(f"SynTree updated.") + if action == "end": + break logger.debug(f"🙌 SynTree completed.") return syntree From 1d63d900c316c61e71e14aae131241960f2da768 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 10:22:06 -0400 Subject: [PATCH 112/302] wip: syntree generation script --- scripts/03-generate-syntrees.py | 106 ++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 scripts/03-generate-syntrees.py diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py new file mode 100644 index 00000000..7a667432 --- /dev/null +++ b/scripts/03-generate-syntrees.py @@ -0,0 +1,106 @@ +import logging +from collections import Counter +from pathlib import Path +from pathos import multiprocessing as mp +import numpy as np +from rdkit import Chem, RDLogger + +from syn_net.data_generation.preprocessing import (BuildingBlockFileHandler, + ReactionTemplateFileHandler) +from syn_net.data_generation.syntrees import (NoReactantAvailableError, NoReactionAvailableError, + NoReactionPossible, SynTreeGenerator) +from syn_net.utils.data_utils import Reaction, SyntheticTree, SyntheticTreeSet + +logger = logging.getLogger(__name__) +from typing import Tuple, Union + +RDLogger.DisableLog("rdApp.*") + + +def __sanity_checks(): + """Sanity check some methods. Poor mans tests""" + out = stgen._sample_molecule() + assert isinstance(out, str) + assert Chem.MolFromSmiles(out) + + rxn_mask = stgen._find_rxn_candidates(out) + assert isinstance(rxn_mask, list) + assert isinstance(rxn_mask[0], bool) + + rxn, rxn_idx = stgen._sample_rxn() + assert isinstance(rxn, Reaction) + assert isinstance(rxn_idx, np.int64), print(f"{type(rxn_idx)=}") + + out = stgen._base_case() + assert isinstance(out, str) + assert Chem.MolFromSmiles(out) + + st = SyntheticTree() + mask = stgen._get_action_mask(st) + assert isinstance(mask, np.ndarray) + np.testing.assert_array_equal(mask, np.array([True, False, False, False])) + + +def wraps_syntreegenerator_generate() -> Tuple[Union[SyntheticTree, None], Union[Exception, None]]: + try: + st = stgen.generate() + except NoReactantAvailableError as e: + logger.error(e) + return None, e + except NoReactionAvailableError as e: + logger.error(e) + return None, e + except NoReactionPossible as e: + logger.error(e) + return None, e + except TypeError as e: + logger.error(e) + return None, e + except Exception as e: + logger.error(e, exc_info=e, stack_info=False) + return None, e + else: + return st, None + + + + + +if __name__ == "__main__": + logger.info("Start.") + # Load assets + bblocks = BuildingBlockFileHandler().load( + "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" + ) + rxn_templates = ReactionTemplateFileHandler().load("data/assets/reaction-templates/hb.txt") + + # Init SynTree Generator + import pickle + file = "stgen.pickle" + with open(file,"rb") as f: + stgen = pickle.load(f) + # stgen = SynTreeGenerator(building_blocks=bblocks, rxn_templates=rxn_templates, verbose=True) + + # Run some sanity tests + __sanity_checks() + + outcomes: dict[int, any] = dict() + syntrees = [] + for i in range(1_000): + st, e = wraps_syntreegenerator_generate() + outcomes[i] = e.__class__.__name__ if e is not None else "success" + syntrees.append(st) + + logger.info(Counter(outcomes.values())) + + # Store syntrees on disk + syntrees = [st for st in syntrees if st is not None] + syntree_collection = SyntheticTreeSet(syntrees) + import datetime + now = datetime.datetime.now().strftime("%Y%m%d_%H_%M") + file = f"data/{now}-syntrees.json.gz" + + syntree_collection.save(file) + + print("completed at", now) + logger.info(f"Completed.") From 1e22a95ecfc7a33573d08296a544c86dfb4b10eb Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 11:15:59 -0400 Subject: [PATCH 113/302] refactor: explicit calculation of input space embedding --- src/syn_net/utils/predict_utils.py | 35 +++++++++++++++--------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 03dd5844..0a821952 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -144,32 +144,34 @@ def nn_search_rt1(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[np.ndar return dist[0], ind[0] -def set_embedding(z_target: np.ndarray, state: list[str], nbits: int, _mol_embedding: Callable): +def set_embedding(z_target: np.ndarray, state: list[str], nbits: int, _mol_embedding: Callable) -> np.ndarray: """ Computes embeddings for all molecules in the input space. Embedding = [z_mol1, z_mol2, z_target] Args: - z_target (np.ndarray): Embedding for the target molecule. - state (list): Contains molecules in the current state, if not the initial state. + z_target (np.ndarray): Molecular embedding of the target molecule. + state (list): State of the synthetic tree, i.e. list of root molecules. nbits (int): Length of fingerprint. - _mol_embedding (Callable): Function to use for computing the - embeddings of the first and second molecules in the state. + _mol_embedding (Callable): Computes the embeddings of molecules in the state. Returns: - np.ndarray: Embedding consisting of the concatenation of the target - molecule with the current molecules (if available) in the input state. + embedding (np.ndarray): shape (1,d+2*nbits) """ + z_target = np.atleast_2d(z_target) # (1,d) if len(state) == 0: - embedding = np.concatenate([np.zeros((1, 2 * nbits)), z_target], axis=1) + z_mol1 = np.zeros((1, nbits)) + z_mol2 = np.zeros((1, nbits)) + elif len(state) == 1: + z_mol1 = np.atleast_2d(_mol_embedding(state[0])) + z_mol2 = np.zeros((1, nbits)) + elif len(state) == 2: + z_mol1 = np.atleast_2d(_mol_embedding(state[0])) + z_mol2 = np.atleast_2d(_mol_embedding(state[1])) else: - e1 = np.expand_dims(_mol_embedding(state[0]), axis=0) - if len(state) == 1: - e2 = np.zeros((1, nbits)) - else: - e2 = _mol_embedding(state[1]) - embedding = np.concatenate([e1, e2, z_target], axis=1) - return embedding + raise ValueError + embedding = np.concatenate([z_mol1, z_mol2, z_target], axis=1) + return embedding # (1,d+2*nbits) def synthetic_tree_decoder( z_target: np.ndarray, @@ -216,7 +218,6 @@ def synthetic_tree_decoder( tree = SyntheticTree() mol_recent = None kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class - z_target = np.atleast_2d(z_target) # Start iteration for i in range(max_step): @@ -373,7 +374,7 @@ def synthetic_tree_decoder_rt1( tree = SyntheticTree() mol_recent = None kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class - z_target = np.atleast_2d(z_target) + # Start iteration for i in range(max_step): # Encode current state From 12060abe27a36eaf5f0fafc2c1e8a7386bc8fca8 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 14:17:40 -0400 Subject: [PATCH 114/302] fix: use tuple, not list, for reactants --- src/syn_net/utils/predict_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 0a821952..f6db3210 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -304,7 +304,7 @@ def synthetic_tree_decoder( mol2 = None # Run reaction - mol_product = rxn.run_reaction([mol1, mol2]) + mol_product = rxn.run_reaction((mol1, mol2)) if mol_product is None or Chem.MolFromSmiles(mol_product) is None: if len(tree.get_state()) == 1: act = 3 @@ -467,7 +467,7 @@ def synthetic_tree_decoder_rt1( mol2 = None # Run reaction - mol_product = rxn.run_reaction([mol1, mol2]) + mol_product = rxn.run_reaction((mol1, mol2)) if mol_product is None or Chem.MolFromSmiles(mol_product) is None: act = 3 break From 5f6ddd37c0988d8f7c38fed93be0a03f20071066 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 14:22:11 -0400 Subject: [PATCH 115/302] bug fix --- src/syn_net/utils/data_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 3c1da712..9ec7b0b7 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -96,6 +96,7 @@ def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, self.rxnname = rxnname self.smiles = smiles self.reference = reference + self.rxn = self.__init_reaction(self.smirks) @functools.lru_cache(maxsize=20) def get_mol(self, smi: Union[str,Chem.Mol]) -> Chem.Mol: @@ -462,7 +463,7 @@ def get_state(self) -> list[NodeChemical]: Returns: state (list): A list contains all root node molecules. """ - state = [mol for mol in self.chemicals if mol.is_root] + state = [node.smiles for node in self.chemicals if node.is_root] return state[::-1] def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str): From 41b6e9cabdec913cf54b82bef71fa14554750f47 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 16:24:24 -0400 Subject: [PATCH 116/302] fix: find paths to best model checkpoints --- scripts/predict_multireactant_mp.py | 32 +++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index ff9ebc53..7bf6d76b 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -3,7 +3,7 @@ """ import multiprocessing as mp from pathlib import Path - +from typing import Union import numpy as np import pandas as pd from syn_net.config import (CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, @@ -53,13 +53,28 @@ def _fetch_building_blocks(file: str): """Load the building blocks""" return pd.read_csv(file, compression='gzip')['SMILES'].tolist() -def _load_pretrained_model(path_to_checkpoints: str): +def find_best_model_ckpt(path: str) -> Union[Path,None]: # TODO: move to utils.py + """Find checkpoint with lowest val_loss. + + Poor man's regex: + somepath/act/ckpts.epoch=70-val_loss=0.03.ckpt + ^^^^--extract this as float + """ + ckpts = Path(path).rglob("*.ckpt") + best_model_ckpt = None + lowest_loss = 10_000 + for file in (ckpts): + stem = file.stem + val_loss = float(stem.split("val_loss=")[-1]) + if val_loss < lowest_loss: + best_model_ckpt = file + lowest_loss = val_loss + return best_model_ckpt + +def _load_pretrained_model(path_to_checkpoints: list[Path]): """Wrapper to load modules from checkpoint.""" # Define paths to pretrained models. - path_to_act = Path(path_to_checkpoints) / "act.ckpt" - path_to_rt1 = Path(path_to_checkpoints) / "rt1.ckpt" - path_to_rxn = Path(path_to_checkpoints) / "rxn.ckpt" - path_to_rt2 = Path(path_to_checkpoints) / "rt2.ckpt" + path_to_act, path_to_rt1, path_to_rxn, path_to_rt2 = path_to_checkpoints # Load the pre-trained models. act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( @@ -165,13 +180,14 @@ def func(smiles: str): file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{rxn_template}_{building_blocks_id}.json.gz" rxns = _fetch_reaction_templates(file) - # ... building blocks + # ... building block embedding file = Path(DATA_EMBEDDINGS_DIR) / f"{rxn_template}-{building_blocks_id}-embeddings.npy" bb_emb = _fetch_building_blocks_embeddings(file) # ... models path = Path(CHECKPOINTS_DIR) / f"{param_dir}" - act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(path) + paths = [find_best_model_ckpt("results/logs/hb_fp_2_4096/" + mdl) for mdl in "act rt1 rxn rt2".split()] + act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) # Decode queries, i.e. the target molecules. From c1ecd7245da0c0345c3b35250c4b3438f8fef602 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 16:26:45 -0400 Subject: [PATCH 117/302] format --- scripts/predict_multireactant_mp.py | 144 ++++++++++++++++------------ 1 file changed, 85 insertions(+), 59 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 7bf6d76b..d26da233 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -4,56 +4,68 @@ import multiprocessing as mp from pathlib import Path from typing import Union + import numpy as np import pandas as pd -from syn_net.config import (CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, - DATA_PREPARED_DIR, DATA_PREPROCESS_DIR, - DATA_RESULT_DIR) + +from syn_net.config import ( + CHECKPOINTS_DIR, + DATA_EMBEDDINGS_DIR, + DATA_PREPARED_DIR, + DATA_PREPROCESS_DIR, + DATA_RESULT_DIR, +) from syn_net.models.chkpt_loader import load_modules_from_checkpoint from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from syn_net.utils.predict_utils import (mol_fp, - synthetic_tree_decoder_multireactant) +from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_multireactant Path(DATA_RESULT_DIR).mkdir(exist_ok=True) + def _fetch_data_chembl(name: str) -> list[str]: raise NotImplementedError - df = pd.read_csv(f'{DATA_DIR}/chembl_20k.csv') + df = pd.read_csv(f"{DATA_DIR}/chembl_20k.csv") smis_query = df.smiles.to_list() return smis_query + def _fetch_data_from_file(name: str) -> list[str]: - with open(name,"rt") as f: + with open(name, "rt") as f: smis_query = [line.strip() for line in f] return smis_query + def _fetch_data(name: str) -> list[str]: if args.data in ["train", "valid", "test"]: file = Path(DATA_PREPARED_DIR) / f"synthetic-trees-{args.data}.json.gz" - print(f'Reading data from {file}') + print(f"Reading data from {file}") sts = SyntheticTreeSet() sts.load(file) smis_query = [st.root.smiles for st in sts.sts] elif args.data in ["chembl"]: smis_query = _fetch_data_chembl(name) - else: # Hopefully got a filename instead + else: # Hopefully got a filename instead smis_query = _fetch_data_from_file(name) return smis_query + def _fetch_reaction_templates(file: str): # Load reaction templates rxn_set = ReactionSet().load(file) return rxn_set.rxns + def _fetch_building_blocks_embeddings(file: str): """Load the purchasable building block embeddings.""" return np.load(file) + def _fetch_building_blocks(file: str): """Load the building blocks""" - return pd.read_csv(file, compression='gzip')['SMILES'].tolist() + return pd.read_csv(file, compression="gzip")["SMILES"].tolist() -def find_best_model_ckpt(path: str) -> Union[Path,None]: # TODO: move to utils.py + +def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils.py """Find checkpoint with lowest val_loss. Poor man's regex: @@ -63,7 +75,7 @@ def find_best_model_ckpt(path: str) -> Union[Path,None]: # TODO: move to utils.p ckpts = Path(path).rglob("*.ckpt") best_model_ckpt = None lowest_loss = 10_000 - for file in (ckpts): + for file in ckpts: stem = file.stem val_loss = float(stem.split("val_loss=")[-1]) if val_loss < lowest_loss: @@ -71,6 +83,7 @@ def find_best_model_ckpt(path: str) -> Union[Path,None]: # TODO: move to utils.p lowest_loss = val_loss return best_model_ckpt + def _load_pretrained_model(path_to_checkpoints: list[Path]): """Wrapper to load modules from checkpoint.""" # Define paths to pretrained models. @@ -90,6 +103,7 @@ def _load_pretrained_model(path_to_checkpoints: list[Path]): ) return act_net, rt1_net, rxn_net, rt2_net + def func(smiles: str): """ Generates the synthetic tree for the input molecular embedding. @@ -119,62 +133,75 @@ def func(smiles: str): rxn_template=rxn_template, n_bits=nbits, beam_width=3, - max_step=15) + max_step=15, + ) except Exception as e: print(e) action = -1 - if action != 3: # aka tree has not been properly ended + if action != 3: # aka tree has not been properly ended smi = None - similarity = 0 + similarity = 0 tree = None return smi, similarity, tree -if __name__ == '__main__': +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan Fingerprint") - parser.add_argument("-b", "--nbits", type=int, default=4096, - help="Number of Bits for Morgan Fingerprint") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--ncpu", type=int, default=1, - help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=1, - help="Number of molecules to predict.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test', 'chembl'] or provide a file with one SMILES per line.") - parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") - parser.add_argument("--output-dir", type=str, default=None, - help="Directory to save output.") + parser.add_argument( + "-f", "--featurize", type=str, default="fp", help="Choose from ['fp', 'gin']" + ) + parser.add_argument("--radius", type=int, default=2, help="Radius for Morgan Fingerprint") + parser.add_argument( + "-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint" + ) + parser.add_argument( + "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']" + ) + parser.add_argument("--ncpu", type=int, default=1, help="Number of cpus") + parser.add_argument("-n", "--num", type=int, default=1, help="Number of molecules to predict.") + parser.add_argument( + "-d", + "--data", + type=str, + default="test", + help="Choose from ['train', 'valid', 'test', 'chembl'] or provide a file with one SMILES per line.", + ) + parser.add_argument( + "-o", + "--outputembedding", + type=str, + default="fp_256", + help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", + ) + parser.add_argument("--output-dir", type=str, default=None, help="Directory to save output.") args = parser.parse_args() - nbits = args.nbits - out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits + nbits = args.nbits + out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits rxn_template = args.rxn_template building_blocks_id = "enamine_us-2021-smiles" - featurize = args.featurize - radius = args.radius - ncpu = args.ncpu - param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" + featurize = args.featurize + radius = args.radius + ncpu = args.ncpu + param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" # Load data ... # ... query molecules (i.e. molecules to decode) smiles_queries = _fetch_data(args.data) - if args.num > 0: # Select only n queries - smiles_queries = smiles_queries[:args.num] + if args.num > 0: # Select only n queries + smiles_queries = smiles_queries[: args.num] # ... building blocks file = Path(DATA_PREPROCESS_DIR) / f"{rxn_template}-{building_blocks_id}-matched.csv.gz" building_blocks = _fetch_building_blocks(file) - building_blocks_dict = {block: i for i,block in enumerate(building_blocks)} # dict is used as lookup table for 2nd reactant during inference + building_blocks_dict = { + block: i for i, block in enumerate(building_blocks) + } # dict is used as lookup table for 2nd reactant during inference # ... reaction templates file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{rxn_template}_{building_blocks_id}.json.gz" @@ -186,23 +213,24 @@ def func(smiles: str): # ... models path = Path(CHECKPOINTS_DIR) / f"{param_dir}" - paths = [find_best_model_ckpt("results/logs/hb_fp_2_4096/" + mdl) for mdl in "act rt1 rxn rt2".split()] + paths = [ + find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) + for model in "act rt1 rxn rt2".split() + ] act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) - # Decode queries, i.e. the target molecules. - print(f'Start to decode {len(smiles_queries)} target molecules.') + print(f"Start to decode {len(smiles_queries)} target molecules.") with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, smiles_queries) - print('Finished decoding.') - + print("Finished decoding.") # Print some results from the prediction smis_decoded = [r[0] for r in results] similarities = [r[1] for r in results] - trees = [r[2] for r in results] + trees = [r[2] for r in results] - recovery_rate = (np.asfarray(similarities)==1.0).sum()/len(similarities) + recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) avg_similarity = np.mean(similarities) print(f"For {args.data}:") print(f" {recovery_rate=}") @@ -210,15 +238,13 @@ def func(smiles: str): # Save to local dir output_dir = DATA_RESULT_DIR if args.output_dir is None else args.output_dir - print('Saving results to {output_dir} ...') - df = pd.DataFrame({'query SMILES' : smiles_queries, - 'decode SMILES': smis_decoded, - 'similarity' : similarities}) - df.to_csv(f'{output_dir}/decode_result_{args.data}.csv.gz', - compression='gzip', - index=False,) + print("Saving results to {output_dir} ...") + df = pd.DataFrame( + {"query SMILES": smiles_queries, "decode SMILES": smis_decoded, "similarity": similarities} + ) + df.to_csv(f"{output_dir}/decode_result_{args.data}.csv.gz", compression="gzip", index=False) synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{output_dir}/decoded_st_{args.data}.json.gz') + synthetic_tree_set.save(f"{output_dir}/decoded_st_{args.data}.json.gz") - print('Finish!') + print("Finish!") From 40c385696b43aff6f7b768d564b2489091db39c0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 14 Sep 2022 18:11:03 -0400 Subject: [PATCH 118/302] bugfix (c.f. https://github.com/wenhao-gao/SynNet/issues/15) --- src/syn_net/utils/predict_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index f6db3210..67aeae93 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -398,7 +398,7 @@ def synthetic_tree_decoder_rt1( # Select first molecule if act == 0: # Add if mol_recent is not None: - dist, ind = nn_search(z_mol1) + dist, ind = nn_search(z_mol1,_tree=kdtree) mol1 = building_blocks[ind] else: # no recent mol dist, ind = nn_search_rt1( From befb97bd2d36590af666256238ca0fb1616e6f4e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 15 Sep 2022 19:10:19 -0400 Subject: [PATCH 119/302] remeoves debug `assert` statements --- src/syn_net/data_generation/syntrees.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index c3fb5e5f..3e59aaba 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -235,10 +235,6 @@ def generate(self, max_depth: int = 15, retries: int = 3): recent_mol = p # Update tree - # assert isinstance(act,(int,np.int64)), type(act) - # assert isinstance(r1,str), type(r1) - # assert isinstance(r2,(str,type(None))), type(r2) - # assert isinstance(p,(str)), type(p) syntree.update(act, rxn_id=idx_rxn, mol1=r1, mol2=r2, mol_product=p) logger.debug(f"SynTree updated.") if action == "end": From 98b0b7c187fb7cbad27aa825d414740a5bedc527 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 15 Sep 2022 19:13:42 -0400 Subject: [PATCH 120/302] fix: correctly compare two lists of booleans --- src/syn_net/data_generation/syntrees.py | 39 ++++++++++++++++++------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 3e59aaba..ed2ebda1 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -1,11 +1,11 @@ """syntrees """ +import logging from typing import Tuple -from tqdm import tqdm + import numpy as np from rdkit import Chem - -import logging +from tqdm import tqdm logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -22,6 +22,14 @@ def __init__(self, message): # Call the base class constructor with the parameters it needs super().__init__(message) + +class NoBiReactionAvailableError(Exception): + """Reactants do not match any reaction template.""" + + def __init__(self, message): + super().__init__(message) + + class NoReactionPossible(Exception): def __init__(self, message): # Call the base class constructor with the parameters it needs @@ -100,11 +108,11 @@ def _sample_molecule(self) -> str: def _base_case(self) -> str: return self._sample_molecule() - def _find_rxn_candidates(self, smiles: str): + def _find_rxn_candidates(self, smiles: str, raise_exc: bool = True) -> list[bool]: """Find a reaction with `mol` as reactant.""" mol = Chem.MolFromSmiles(smiles) rxn_mask = [rxn.is_reactant(mol) for rxn in self.rxns] - if not any(rxn_mask): + if raise_exc and not any(rxn_mask): # Do not raise exc when checking if two mols can react raise NoReactionAvailableError(f"No reaction available for: {smiles}.") return rxn_mask @@ -177,6 +185,18 @@ def _get_action_mask(self, syntree: SyntheticTree): return np.array((canAdd, canExpand, canMerge, canEnd), dtype=bool) + def _get_rxn_mask(self, reactants: tuple[str, str]) -> list[bool]: + """Get a mask of possible reactions for the two reactants.""" + masks = [self._find_rxn_candidates(r, raise_exc=False) for r in reactants] + # TODO: We do not check if the two reactants are 1st and 2nd reactants in a given reaction. + # It is possible that both are only applicable as 1st reactant, + # and then the reaction is not possible, although the mask returns true. + # Alternative: Run the reaction and check if the product is valid. + mask = [rxn1 and rxn2 for rxn1, rxn2 in zip(*masks)] + if not any(mask): + raise NoBiReactionAvailableError(f"No reaction available for {reactants}.") + return mask + def generate(self, max_depth: int = 15, retries: int = 3): """Generate a syntree by random sampling.""" @@ -216,12 +236,11 @@ def generate(self, max_depth: int = 15, retries: int = 3): elif action == "merge": # merge two subtrees: sample reaction, run it. - r1, r2 = [node.smiles for node in state] + # Identify suitable rxn - # TODO: naive implementation - rxn_mask1 = self._find_rxn_candidates(r1) - rxn_mask2 = self._find_rxn_candidates(r2) - rxn_mask = rxn_mask1 and rxn_mask2 + r1, r2 = syntree.get_state() + rxn_mask = self._get_rxn_mask(tuple((r1, r2))) + # Sample reaction rxn, idx_rxn = self._sample_rxn(mask=rxn_mask) # Run reaction p = rxn.run_reaction((r1, r2)) From 8e0074601045cdb096478839aa57e0ef65b873f2 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 15 Sep 2022 19:16:34 -0400 Subject: [PATCH 121/302] only allow merging if theres a common match in the reaction masks --- src/syn_net/data_generation/syntrees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index ed2ebda1..05cb6bde 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -179,7 +179,7 @@ def _get_action_mask(self, syntree: SyntheticTree): canEnd = True elif nTrees == 2: canExpand = True - canMerge = True # TODO: only if rxn is possible + canMerge = any(self._get_action_mask(state)) else: raise ValueError From e497621b008a17b5036832d61079c2ca5b65357b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 15 Sep 2022 19:17:08 -0400 Subject: [PATCH 122/302] formatting + error messages --- src/syn_net/data_generation/syntrees.py | 51 ++++++++++++++++--------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 05cb6bde..4f663e9a 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -12,14 +12,18 @@ from syn_net.utils.data_utils import Reaction, SyntheticTree + class NoReactantAvailableError(Exception): + """No second reactant available for the bimolecular reaction.""" + def __init__(self, message): - # Call the base class constructor with the parameters it needs super().__init__(message) + class NoReactionAvailableError(Exception): + """Reactant does not match any reaction template.""" + def __init__(self, message): - # Call the base class constructor with the parameters it needs super().__init__(message) @@ -31,8 +35,9 @@ def __init__(self, message): class NoReactionPossible(Exception): + """`rdkit` can not yield a valid reaction product.""" + def __init__(self, message): - # Call the base class constructor with the parameters it needs super().__init__(message) @@ -40,7 +45,7 @@ class SynTreeGenerator: building_blocks: list[str] rxn_templates: list[Reaction] - rxns: dict[int, Reaction] + rxns: list[Reaction] IDX_RXNS: list ACTIONS: dict[int, str] = {i: action for i, action in enumerate("add expand merge end".split())} verbose: bool @@ -51,7 +56,7 @@ def __init__( building_blocks: list[str], rxn_templates: list[str], rng=np.random.default_rng(seed=42), - verbose:bool = False, + verbose: bool = False, ) -> None: self.building_blocks = building_blocks self.rxn_templates = rxn_templates @@ -109,7 +114,7 @@ def _base_case(self) -> str: return self._sample_molecule() def _find_rxn_candidates(self, smiles: str, raise_exc: bool = True) -> list[bool]: - """Find a reaction with `mol` as reactant.""" + """Tests which reactions have `mol` as reactant.""" mol = Chem.MolFromSmiles(smiles) rxn_mask = [rxn.is_reactant(mol) for rxn in self.rxns] if raise_exc and not any(rxn_mask): # Do not raise exc when checking if two mols can react @@ -119,12 +124,14 @@ def _find_rxn_candidates(self, smiles: str, raise_exc: bool = True) -> list[bool def _sample_rxn(self, mask: np.ndarray = None) -> Tuple[Reaction, int]: """Sample a reaction by index.""" if mask is None: - irxn_mask = self.IDX_RXNS # + irxn_mask = self.IDX_RXNS # All reactions are possible else: - mask = np.asarray(mask) + mask = np.asarray(mask) irxn_mask = self.IDX_RXNS[mask] idx = self.rng.choice(irxn_mask) - logger.debug(f" Sampled reaction with index: {idx} (nreactants: {self.rxns[idx].num_reactant})") + logger.debug( + f"Sampled reaction with index: {idx} (nreactants: {self.rxns[idx].num_reactant})" + ) return self.rxns[idx], idx def _expand(self, reactant_1: str) -> Tuple[str, str, str, np.int64]: @@ -147,15 +154,16 @@ def _expand(self, reactant_1: str) -> Tuple[str, str, str, np.int64]: # - then sample "B" (or "A") idx = 1 if rxn.is_reactant_first(reactant_1) else 0 available_reactants = rxn.available_reactants[idx] - nPossible = len(available_reactants) - if nPossible==0: - raise NoReactantAvailableError("Unable to find two reactants for this bimolecular reaction.") - # TODO: 2 bi-molecular rxn templates have no matching bblock + nPossibleReactants = len(available_reactants) + if nPossibleReactants == 0: + raise NoReactantAvailableError( + f"Unable to find reactant {idx+1} for bimolecular reaction (ID: {idx_rxn}) and reactant {reactant_1}." + ) + # TODO: 2 bi-molecular rxn templates have no matching bblock # TODO: use numpy array to avoid type conversion or stick to sampling idx? - idx = self.rng.choice(nPossible) + idx = self.rng.choice(nPossibleReactants) reactant_2 = available_reactants[idx] - # Run reaction reactants = (reactant_1, reactant_2) product = rxn.run_reaction(reactants) @@ -223,11 +231,14 @@ def generate(self, max_depth: int = 15, retries: int = 3): elif action == "expand": for j in range(retries): logger.debug(f" Try {j}") - r1, r2, p, idx_rxn= self._expand(recent_mol) - if p is not None: break + r1, r2, p, idx_rxn = self._expand(recent_mol) + if p is not None: + break if p is None: # TODO: move to rxn.run_reaction? - raise NoReactionPossible(f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}.") + raise NoReactionPossible( + f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." + ) elif action == "add": mol = self._sample_molecule() @@ -246,7 +257,9 @@ def generate(self, max_depth: int = 15, retries: int = 3): p = rxn.run_reaction((r1, r2)) if p is None: # TODO: move to rxn.run_reaction? - raise NoReactionPossible(f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}.") + raise NoReactionPossible( + f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." + ) # Prepare next iteration logger.debug(f" Ran reaction {r1} + {r2} -> {p}") From be8ee92cfcd3550aae1cfbe8b06d23f4665c0772 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:20:52 -0400 Subject: [PATCH 123/302] move number of processes for mp into `config.py` --- scripts/01-filter-building-blocks.py | 4 ++-- scripts/02-compute-embeddings.py | 4 ++-- src/syn_net/MolEmbedder.py | 4 ++-- src/syn_net/config.py | 12 ++++++++---- src/syn_net/data_generation/make_dataset_mp.py | 4 ++-- src/syn_net/data_generation/preprocessing.py | 3 ++- src/syn_net/data_generation/syntrees.py | 6 ++++-- 7 files changed, 22 insertions(+), 15 deletions(-) diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py index ef32f0c9..b85ee1e6 100644 --- a/scripts/01-filter-building-blocks.py +++ b/scripts/01-filter-building-blocks.py @@ -4,7 +4,7 @@ from rdkit import RDLogger from syn_net.data_generation.preprocessing import BuildingBlockFileHandler, BuildingBlockFilter - +from syn_net.config import MAX_PROCESSES RDLogger.DisableLog("rdApp.*") logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ def get_args(): help="Output file for the filtered building-blocks file.", ) # Processing - parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") parser.add_argument("--verbose", default=False, action="store_true") return parser.parse_args() diff --git a/scripts/02-compute-embeddings.py b/scripts/02-compute-embeddings.py index 518675f2..53d8c796 100644 --- a/scripts/02-compute-embeddings.py +++ b/scripts/02-compute-embeddings.py @@ -10,7 +10,7 @@ from syn_net.data_generation.preprocessing import BuildingBlockFileHandler from syn_net.encoding.fingerprints import fp_256, fp_512, fp_1024, fp_2048, fp_4096 from syn_net.MolEmbedder import MolEmbedder - +from syn_net.config import MAX_PROCESSES # from syn_net.encoding.gins import mol_embedding # from syn_net.utils.prep_utils import rdkit2d_embedding @@ -57,7 +57,7 @@ def get_args(): help="Objective function to optimize", ) # Processing - parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") parser.add_argument("--verbose", default=False, action="store_true") return parser.parse_args() diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py index 3e4264cc..cd8159cc 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/syn_net/MolEmbedder.py @@ -4,12 +4,12 @@ import numpy as np from sklearn.neighbors import BallTree - +from syn_net.config import MAX_PROCESSES logger = logging.getLogger(__name__) class MolEmbedder: - def __init__(self, processes: int = 1) -> None: + def __init__(self, processes: int = MAX_PROCESSES) -> None: self.processes = processes self.func: Callable self.building_blocks: Union[list[str], np.ndarray] diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 7f563286..424af0bf 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -1,11 +1,15 @@ """Central place for all configuration, paths, and parameter.""" +import multiprocessing +# Multiprocessing +MAX_PROCESSES = min(32,multiprocessing.cpu_count()-1) +# Paths DATA_DIR = "data" ASSETS_DIR = "data/assets" -# -BUILDING_BLOCKS_RAW_DIR = f"{ASSETS_DIR}/building-blocks" -REACTION_TEMPLATE_DIR = f"{ASSETS_DIR}/reaction-templates" +# +BUILDING_BLOCKS_RAW_DIR = f"{ASSETS_DIR}/building-blocks" +REACTION_TEMPLATE_DIR = f"{ASSETS_DIR}/reaction-templates" # Pre-processed data DATA_PREPROCESS_DIR = "data/pre-process" @@ -21,4 +25,4 @@ DATA_RESULT_DIR = "results" # Checkpoints (& pre-trained weights) -CHECKPOINTS_DIR = "checkpoints" # \ No newline at end of file +CHECKPOINTS_DIR = "checkpoints" # \ No newline at end of file diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py index 03f354ac..3185ab75 100644 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ b/src/syn_net/data_generation/make_dataset_mp.py @@ -10,7 +10,7 @@ from pathlib import Path from syn_net.data_generation.make_dataset import synthetic_tree_generator from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR +from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler import logging @@ -40,7 +40,7 @@ def func(_x): rxns = r_set.rxns # Generate synthetic trees - with mp.Pool(processes=64) as pool: + with mp.Pool(processes=MAX_PROCESSES) as pool: results = pool.map(func, np.arange(NUM_TREES).tolist()) # Filter out trees that were completed with action="end" diff --git a/src/syn_net/data_generation/preprocessing.py b/src/syn_net/data_generation/preprocessing.py index d465e873..0ac55af4 100644 --- a/src/syn_net/data_generation/preprocessing.py +++ b/src/syn_net/data_generation/preprocessing.py @@ -1,4 +1,5 @@ from tqdm import tqdm +from syn_net.config import MAX_PROCESSES from syn_net.utils.data_utils import Reaction @@ -17,7 +18,7 @@ def __init__( *, building_blocks: list[str], rxn_templates: list[str], - processes: int = 1, + processes: int = MAX_PROCESSES, verbose: bool = False ) -> None: self.building_blocks = building_blocks diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 4f663e9a..2cb3d7c5 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -6,6 +6,7 @@ import numpy as np from rdkit import Chem from tqdm import tqdm +from syn_net.config import MAX_PROCESSES logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -46,7 +47,7 @@ class SynTreeGenerator: building_blocks: list[str] rxn_templates: list[Reaction] rxns: list[Reaction] - IDX_RXNS: list + IDX_RXNS: np.ndarray # (nReactions,) ACTIONS: dict[int, str] = {i: action for i, action in enumerate("add expand merge end".split())} verbose: bool @@ -56,6 +57,7 @@ def __init__( building_blocks: list[str], rxn_templates: list[str], rng=np.random.default_rng(seed=42), + processes: int = MAX_PROCESSES, verbose: bool = False, ) -> None: self.building_blocks = building_blocks @@ -63,7 +65,7 @@ def __init__( self.rxns = [Reaction(template=tmplt) for tmplt in rxn_templates] self.rng = rng self.IDX_RXNS = np.arange(len(self.rxns)) - self.processes = 32 + self.processes = processes self.verbose = verbose if verbose: logger.setLevel(logging.DEBUG) From 2f6b12f8319b42e1732c08a987927cb92f194e28 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:21:45 -0400 Subject: [PATCH 124/302] fix type error --- src/syn_net/utils/predict_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 67aeae93..1054e1ec 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -35,7 +35,7 @@ def can_react(state, rxns: list[Reaction]) -> Tuple[int, list[bool]]: """ mol1 = state.pop() mol2 = state.pop() - reaction_mask = [int(rxn.run_reaction([mol1, mol2]) is not None) for rxn in rxns] + reaction_mask = [int(rxn.run_reaction((mol1, mol2)) is not None) for rxn in rxns] return sum(reaction_mask), reaction_mask @@ -423,9 +423,7 @@ def synthetic_tree_decoder_rt1( reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [ - [] for rxn in reaction_templates - ] # TODO: if act=merge, this is not used at all + available_list = [[] for rxn in reaction_templates] # TODO: if act=merge, this is not used at all # If we ended up in a state where no reaction is possible, end this iteration. if reaction_mask is None: From bae6b5ad77b5081aeba4530aac4d9341c67ec5ef Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:22:11 -0400 Subject: [PATCH 125/302] adds `CSVLogger` --- src/syn_net/models/act.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 6236f287..c43afc89 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -83,6 +83,8 @@ save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(save_dir,name="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") checkpoint_callback = ModelCheckpoint( monitor="val_loss", @@ -99,7 +101,8 @@ max_epochs=max_epochs, progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], - logger=[tb_logger], + logger=[tb_logger,csv_logger], + fast_dev_run=args.fast_dev_run, ) logger.info(f"Start training") From 77a72a15947e5a36e5c12daf9bb80e575daf8e39 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:22:24 -0400 Subject: [PATCH 126/302] fix type hint --- src/syn_net/utils/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 9ec7b0b7..247f8979 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -456,7 +456,7 @@ def get_node_index(self, smi): return node.index return None - def get_state(self) -> list[NodeChemical]: + def get_state(self) -> list[str]: """Get the state of this synthetic tree. The most recent root node has 0 as its index. From 8a88e7bfe963da768413c5d09cb4fc6ea66f28e6 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:24:43 -0400 Subject: [PATCH 127/302] use logging instead of print --- scripts/predict_multireactant_mp.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index d26da233..6ef86330 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -4,7 +4,9 @@ import multiprocessing as mp from pathlib import Path from typing import Union +import logging +logger = logging.getLogger(__name__) import numpy as np import pandas as pd @@ -38,7 +40,7 @@ def _fetch_data_from_file(name: str) -> list[str]: def _fetch_data(name: str) -> list[str]: if args.data in ["train", "valid", "test"]: file = Path(DATA_PREPARED_DIR) / f"synthetic-trees-{args.data}.json.gz" - print(f"Reading data from {file}") + logger.info(f"Reading data from {file}") sts = SyntheticTreeSet() sts.load(file) smis_query = [st.root.smiles for st in sts.sts] @@ -136,7 +138,7 @@ def func(smiles: str): max_step=15, ) except Exception as e: - print(e) + logger.error(e,exc_info=e) action = -1 if action != 3: # aka tree has not been properly ended @@ -191,10 +193,11 @@ def func(smiles: str): param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" # Load data ... + logger.info("Stat loading data...") # ... query molecules (i.e. molecules to decode) smiles_queries = _fetch_data(args.data) if args.num > 0: # Select only n queries - smiles_queries = smiles_queries[: args.num] + smiles_queries = smiles_queries[:args.num] # ... building blocks file = Path(DATA_PREPROCESS_DIR) / f"{rxn_template}-{building_blocks_id}-matched.csv.gz" @@ -210,20 +213,23 @@ def func(smiles: str): # ... building block embedding file = Path(DATA_EMBEDDINGS_DIR) / f"{rxn_template}-{building_blocks_id}-embeddings.npy" bb_emb = _fetch_building_blocks_embeddings(file) + logger.info("...loading data completed.") # ... models + logger.info("Start loading models from checkpoints...") path = Path(CHECKPOINTS_DIR) / f"{param_dir}" paths = [ find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) for model in "act rt1 rxn rt2".split() ] act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) + logger.info("...loading models completed.") # Decode queries, i.e. the target molecules. - print(f"Start to decode {len(smiles_queries)} target molecules.") + logger.info(f"Start to decode {len(smiles_queries)} target molecules.") with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, smiles_queries) - print("Finished decoding.") + logger.info("Finished decoding.") # Print some results from the prediction smis_decoded = [r[0] for r in results] @@ -232,13 +238,14 @@ def func(smiles: str): recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) avg_similarity = np.mean(similarities) - print(f"For {args.data}:") - print(f" {recovery_rate=}") - print(f" {avg_similarity=}") + logger.info(f"For {args.data}:") + logger.info(f" {len(smiles_queries)=}") + logger.info(f" {recovery_rate=}") + logger.info(f" {avg_similarity=}") # Save to local dir output_dir = DATA_RESULT_DIR if args.output_dir is None else args.output_dir - print("Saving results to {output_dir} ...") + logger.info("Saving results to {output_dir} ...") df = pd.DataFrame( {"query SMILES": smiles_queries, "decode SMILES": smis_decoded, "similarity": similarities} ) @@ -247,4 +254,4 @@ def func(smiles: str): synthetic_tree_set = SyntheticTreeSet(sts=trees) synthetic_tree_set.save(f"{output_dir}/decoded_st_{args.data}.json.gz") - print("Finish!") + logger.info("Finish!") From 36f5041dfbf6b560ebd650981d6478266d77583d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:25:15 -0400 Subject: [PATCH 128/302] refactor argparse to `get_args()` --- scripts/predict_multireactant_mp.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 6ef86330..0bacfc7a 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -148,9 +148,7 @@ def func(smiles: str): return smi, similarity, tree - -if __name__ == "__main__": - +def get_args(): import argparse parser = argparse.ArgumentParser() @@ -164,7 +162,7 @@ def func(smiles: str): parser.add_argument( "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']" ) - parser.add_argument("--ncpu", type=int, default=1, help="Number of cpus") + parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") parser.add_argument("-n", "--num", type=int, default=1, help="Number of molecules to predict.") parser.add_argument( "-d", @@ -181,7 +179,12 @@ def func(smiles: str): help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", ) parser.add_argument("--output-dir", type=str, default=None, help="Directory to save output.") - args = parser.parse_args() + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + logger.info(f"Args: {vars(args)}") nbits = args.nbits out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits From e5ef9984f4b4e17fe1935223b2774462884f1eb0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 11:26:40 -0400 Subject: [PATCH 129/302] delete unused script `make_dataset.py` --- src/syn_net/data_generation/make_dataset.py | 50 --------------------- 1 file changed, 50 deletions(-) delete mode 100644 src/syn_net/data_generation/make_dataset.py diff --git a/src/syn_net/data_generation/make_dataset.py b/src/syn_net/data_generation/make_dataset.py deleted file mode 100644 index 05dc3a74..00000000 --- a/src/syn_net/data_generation/make_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -This file generates synthetic tree data in a sequential fashion. -""" -import dill as pickle -import gzip -import pandas as pd -import numpy as np -from tqdm import tqdm -from syn_net.utils.data_utils import SyntheticTreeSet -from syn_net.utils.prep_utils import synthetic_tree_generator - - - -if __name__ == '__main__': - path_reaction_file = '/home/whgao/shared/Data/scGen/reactions_pis.pickle.gz' - path_to_building_blocks = '/home/whgao/shared/Data/scGen/enamine_building_blocks_nochiral_matched.csv.gz' - - np.random.seed(6) - - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - with gzip.open(path_reaction_file, 'rb') as f: - rxns = pickle.load(f) - - Trial = 5 - num_finish = 0 - num_error = 0 - num_unfinish = 0 - - trees = [] - for _ in tqdm(range(Trial)): - tree, action = synthetic_tree_generator(building_blocks, rxns, max_step=15) - if action == 3: - trees.append(tree) - num_finish += 1 - elif action == -1: - num_error += 1 - else: - num_unfinish += 1 - - print('Total trial: ', Trial) - print('num of finished trees: ', num_finish) - print('num of unfinished tree: ', num_unfinish) - print('num of error processes: ', num_error) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save('st_data.json.gz') - - # data_file = gzip.open('st_data.pickle.gz', 'wb') - # pickle.dump(trees, data_file) - # data_file.close() From db8d08f93687b91b50a7ca06367bdbc1f4b93b54 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:23:04 -0400 Subject: [PATCH 130/302] fix: call correct fct --- scripts/03-generate-syntrees.py | 9 ++++++--- src/syn_net/data_generation/syntrees.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index 7a667432..6707ec30 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -7,8 +7,8 @@ from syn_net.data_generation.preprocessing import (BuildingBlockFileHandler, ReactionTemplateFileHandler) -from syn_net.data_generation.syntrees import (NoReactantAvailableError, NoReactionAvailableError, - NoReactionPossible, SynTreeGenerator) +from syn_net.data_generation.syntrees import (NoReactantAvailableError, NoReactionAvailableError, NoBiReactionAvailableError, + NoReactionPossibleError, SynTreeGenerator) from syn_net.utils.data_utils import Reaction, SyntheticTree, SyntheticTreeSet logger = logging.getLogger(__name__) @@ -50,7 +50,10 @@ def wraps_syntreegenerator_generate() -> Tuple[Union[SyntheticTree, None], Union except NoReactionAvailableError as e: logger.error(e) return None, e - except NoReactionPossible as e: + except NoBiReactionAvailableError as e: + logger.error(e) + return None, e + except NoReactionPossibleError as e: logger.error(e) return None, e except TypeError as e: diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 2cb3d7c5..e0983544 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -189,7 +189,7 @@ def _get_action_mask(self, syntree: SyntheticTree): canEnd = True elif nTrees == 2: canExpand = True - canMerge = any(self._get_action_mask(state)) + canMerge = any(self._get_rxn_mask(tuple(state))) else: raise ValueError From 239ce793051864d1fa89afb060e5b4021380eaaf Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:23:37 -0400 Subject: [PATCH 131/302] wip: I/O for pickled `SynTreeGenerator` --- src/syn_net/data_generation/syntrees.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index e0983544..62ca758d 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -276,3 +276,15 @@ def generate(self, max_depth: int = 15, retries: int = 3): logger.debug(f"🙌 SynTree completed.") return syntree + + +def load_syntreegenerator(file: str) -> SynTreeGenerator: + import pickle + with open(file,"rb") as f: + syntreegenerator = pickle.load(f) + return syntreegenerator + +def save_syntreegenerator(syntreegenerator: SynTreeGenerator,file: str) -> None: + import pickle + with open(file,"wb") as f: + pickle.dump(syntreegenerator,f) From 6c1faa0e55a77a9ec92ec94e6f7bff2641081c43 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:25:24 -0400 Subject: [PATCH 132/302] fix: rename exception cls --- src/syn_net/data_generation/syntrees.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 62ca758d..d94b2f95 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -35,7 +35,7 @@ def __init__(self, message): super().__init__(message) -class NoReactionPossible(Exception): +class NoReactionPossibleError(Exception): """`rdkit` can not yield a valid reaction product.""" def __init__(self, message): @@ -238,7 +238,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): break if p is None: # TODO: move to rxn.run_reaction? - raise NoReactionPossible( + raise NoReactionPossibleError( f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." ) @@ -259,7 +259,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): p = rxn.run_reaction((r1, r2)) if p is None: # TODO: move to rxn.run_reaction? - raise NoReactionPossible( + raise NoReactionPossibleError( f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." ) From c9c6656d4d84ac60d63dab02a1220af1ef25b731 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 15:15:18 -0400 Subject: [PATCH 133/302] fix: serialisation by using ints, not np.ndarray --- src/syn_net/data_generation/syntrees.py | 4 ++-- src/syn_net/utils/data_utils.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index d94b2f95..d89e4d15 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -229,7 +229,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): logger.debug(f" Sampled action: {action}") if action == "end": - r1, r2, p, idx_rxn = None, None, None, None + r1, r2, p, idx_rxn = None, None, None, -1 elif action == "expand": for j in range(retries): logger.debug(f" Try {j}") @@ -269,7 +269,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): recent_mol = p # Update tree - syntree.update(act, rxn_id=idx_rxn, mol1=r1, mol2=r2, mol_product=p) + syntree.update(act, rxn_id=int(idx_rxn), mol1=r1, mol2=r2, mol_product=p) logger.debug(f"SynTree updated.") if action == "end": break diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 247f8979..f609f9a9 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -668,8 +668,8 @@ def load(self, json_file): Args: json_file (str): The path to the stored synthetic tree file. """ - with gzip.open(json_file, 'r') as f: - data = json.loads(f.read().decode('utf-8')) + with gzip.open(json_file, 'rt') as f: + data = json.loads(f.read()) for st_dict in data['trees']: if st_dict is None: @@ -689,8 +689,8 @@ def save(self, json_file): st_list = { 'trees': [st.output_dict() if st is not None else None for st in self.sts] } - with gzip.open(json_file, 'w') as f: - f.write(json.dumps(st_list).encode('utf-8')) + with gzip.open(json_file, 'wt') as f: + f.write(json.dumps(st_list)) def _print(self, x=3): # For debugging From a78370a92b313a6ee70afb022ee92e7edb996f84 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 16 Sep 2022 15:31:33 -0400 Subject: [PATCH 134/302] fix import (see e5ef99) --- src/syn_net/data_generation/make_dataset_mp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py index 3185ab75..2cd6714a 100644 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ b/src/syn_net/data_generation/make_dataset_mp.py @@ -8,7 +8,7 @@ import numpy as np from pathlib import Path -from syn_net.data_generation.make_dataset import synthetic_tree_generator +from syn_net.utils.prep_utils import synthetic_tree_generator from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler From c294c0d72020b57e2fc0bfa5be7181761f37f6be Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 09:50:35 -0400 Subject: [PATCH 135/302] refactor + format --- scripts/03-generate-syntrees.py | 139 +++++++----------- .../data_generation/make_dataset_mp.py | 15 +- src/syn_net/data_generation/syntrees.py | 44 +++++- 3 files changed, 103 insertions(+), 95 deletions(-) diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index 6707ec30..a06c4a2d 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -1,109 +1,84 @@ import logging from collections import Counter from pathlib import Path -from pathos import multiprocessing as mp -import numpy as np -from rdkit import Chem, RDLogger -from syn_net.data_generation.preprocessing import (BuildingBlockFileHandler, - ReactionTemplateFileHandler) -from syn_net.data_generation.syntrees import (NoReactantAvailableError, NoReactionAvailableError, NoBiReactionAvailableError, - NoReactionPossibleError, SynTreeGenerator) -from syn_net.utils.data_utils import Reaction, SyntheticTree, SyntheticTreeSet +from rdkit import RDLogger + +from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES +from syn_net.data_generation.preprocessing import ( + BuildingBlockFileHandler, + ReactionTemplateFileHandler, +) +from syn_net.data_generation.syntrees import SynTreeGenerator, wraps_syntreegenerator_generate +from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet logger = logging.getLogger(__name__) -from typing import Tuple, Union +from typing import Union RDLogger.DisableLog("rdApp.*") +building_blocks_file = "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" +rxn_templates_file = "data/assets/reaction-templates/hb.txt" +output_file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees.json.gz" -def __sanity_checks(): - """Sanity check some methods. Poor mans tests""" - out = stgen._sample_molecule() - assert isinstance(out, str) - assert Chem.MolFromSmiles(out) - - rxn_mask = stgen._find_rxn_candidates(out) - assert isinstance(rxn_mask, list) - assert isinstance(rxn_mask[0], bool) - - rxn, rxn_idx = stgen._sample_rxn() - assert isinstance(rxn, Reaction) - assert isinstance(rxn_idx, np.int64), print(f"{type(rxn_idx)=}") - - out = stgen._base_case() - assert isinstance(out, str) - assert Chem.MolFromSmiles(out) - - st = SyntheticTree() - mask = stgen._get_action_mask(st) - assert isinstance(mask, np.ndarray) - np.testing.assert_array_equal(mask, np.array([True, False, False, False])) - - -def wraps_syntreegenerator_generate() -> Tuple[Union[SyntheticTree, None], Union[Exception, None]]: - try: - st = stgen.generate() - except NoReactantAvailableError as e: - logger.error(e) - return None, e - except NoReactionAvailableError as e: - logger.error(e) - return None, e - except NoBiReactionAvailableError as e: - logger.error(e) - return None, e - except NoReactionPossibleError as e: - logger.error(e) - return None, e - except TypeError as e: - logger.error(e) - return None, e - except Exception as e: - logger.error(e, exc_info=e, stack_info=False) - return None, e - else: - return st, None +def get_args(): + import argparse + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxn-templates-file", + type=str, + help="Input file with reaction templates as SMARTS(No header, one per line).", + ) + parser.add_argument( + "--output-file", + type=str, + help="Output file for the generated synthetic trees (*.json.gz)", + ) + # Parameters + parser.add_argument("--number-syntrees", type=int, help="Number of SynTrees to generate.") + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() if __name__ == "__main__": logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {vars(args)}") + # Load assets - bblocks = BuildingBlockFileHandler().load( - "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" - ) - rxn_templates = ReactionTemplateFileHandler().load("data/assets/reaction-templates/hb.txt") + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + rxn_templates = ReactionTemplateFileHandler().load(args.rxn_templates_file) # Init SynTree Generator - import pickle - file = "stgen.pickle" - with open(file,"rb") as f: - stgen = pickle.load(f) - # stgen = SynTreeGenerator(building_blocks=bblocks, rxn_templates=rxn_templates, verbose=True) - - # Run some sanity tests - __sanity_checks() - - outcomes: dict[int, any] = dict() - syntrees = [] - for i in range(1_000): + stgen = SynTreeGenerator( + building_blocks=bblocks, rxn_templates=rxn_templates, verbose=args.verbose + ) + + # Generate synthetic trees + logger.info(f"Start generation of {args.number_syntrees} SynTrees...") + outcomes: dict[int, str] = dict() + syntrees: list[Union[SyntheticTree, None]] = [] + for i in range(args.number_syntrees): st, e = wraps_syntreegenerator_generate() outcomes[i] = e.__class__.__name__ if e is not None else "success" syntrees.append(st) + logger.info(f"SynTree generation completed. Results: {Counter(outcomes.values())}") - logger.info(Counter(outcomes.values())) - - # Store syntrees on disk - syntrees = [st for st in syntrees if st is not None] + # Save synthetic trees on disk syntree_collection = SyntheticTreeSet(syntrees) - import datetime - now = datetime.datetime.now().strftime("%Y%m%d_%H_%M") - file = f"data/{now}-syntrees.json.gz" - - syntree_collection.save(file) + syntree_collection.save(args.output_file) - print("completed at", now) logger.info(f"Completed.") diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py index 2cd6714a..8ce861ba 100644 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ b/src/syn_net/data_generation/make_dataset_mp.py @@ -4,26 +4,27 @@ Usage: python make_dataset_mp.py """ +import logging import multiprocessing as mp +from pathlib import Path import numpy as np -from pathlib import Path -from syn_net.utils.prep_utils import synthetic_tree_generator -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet + from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -import logging +from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet +from syn_net.utils.prep_utils import synthetic_tree_generator logger = logging.getLogger(__name__) def func(_x): - np.random.seed(_x) # dummy input to generate "unique" seed + np.random.seed(_x) # dummy input to generate "unique" seed tree, action = synthetic_tree_generator(building_blocks, rxns) return tree, action -if __name__ == '__main__': +if __name__ == "__main__": reaction_template_id = "hb" # "pis" or "hb" building_blocks_id = "enamine_us-2021-smiles" @@ -35,7 +36,7 @@ def func(_x): # Load genearted reactions (matched reactions <=> building blocks) reactions_dir = Path(DATA_PREPROCESS_DIR) - reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" + reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" r_set = ReactionSet().load(reactions_dir / reactions_file) rxns = r_set.rxns diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index d89e4d15..5c985c75 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -1,11 +1,12 @@ """syntrees """ import logging -from typing import Tuple +from typing import Tuple, Union import numpy as np from rdkit import Chem from tqdm import tqdm + from syn_net.config import MAX_PROCESSES logger = logging.getLogger(__name__) @@ -47,7 +48,7 @@ class SynTreeGenerator: building_blocks: list[str] rxn_templates: list[Reaction] rxns: list[Reaction] - IDX_RXNS: np.ndarray # (nReactions,) + IDX_RXNS: np.ndarray # (nReactions,) ACTIONS: dict[int, str] = {i: action for i, action in enumerate("add expand merge end".split())} verbose: bool @@ -278,13 +279,44 @@ def generate(self, max_depth: int = 15, retries: int = 3): return syntree +def wraps_syntreegenerator_generate( + stgen: SynTreeGenerator, +) -> Tuple[Union[SyntheticTree, None], Union[Exception, None]]: + """Wrapper for `SynTreeGenerator().generate` that catches all Exceptions.""" + try: + st = stgen.generate() + except NoReactantAvailableError as e: + logger.error(e) + return None, e + except NoReactionAvailableError as e: + logger.error(e) + return None, e + except NoBiReactionAvailableError as e: + logger.error(e) + return None, e + except NoReactionPossibleError as e: + logger.error(e) + return None, e + except TypeError as e: + logger.error(e) + return None, e + except Exception as e: + logger.error(e, exc_info=e, stack_info=False) + return None, e + else: + return st, None + + def load_syntreegenerator(file: str) -> SynTreeGenerator: import pickle - with open(file,"rb") as f: + + with open(file, "rb") as f: syntreegenerator = pickle.load(f) return syntreegenerator -def save_syntreegenerator(syntreegenerator: SynTreeGenerator,file: str) -> None: + +def save_syntreegenerator(syntreegenerator: SynTreeGenerator, file: str) -> None: import pickle - with open(file,"wb") as f: - pickle.dump(syntreegenerator,f) + + with open(file, "wb") as f: + pickle.dump(syntreegenerator, f) From fd840093935d2c06a620487863862a0cd4cbcc95 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 10:11:39 -0400 Subject: [PATCH 136/302] fix typo --- src/syn_net/utils/data_utils.py | 53 +++++++++++---------------------- 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index f609f9a9..f81061dd 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -37,7 +37,7 @@ class Reaction: num_product: int reactant_template: Tuple[str,str] product_template: str - agent_templat: str + agent_template: str available_reactants: Tuple[list[str],Optional[list[str]]] rxnname: str smiles: Any @@ -97,6 +97,7 @@ def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, self.smiles = smiles self.reference = reference self.rxn = self.__init_reaction(self.smirks) + return self @functools.lru_cache(maxsize=20) def get_mol(self, smi: Union[str,Chem.Mol]) -> Chem.Mol: @@ -267,46 +268,28 @@ def get_available_reactants(self) -> Set[str]: class ReactionSet: - """ - A class representing a set of reactions, for saving and loading purposes. - - Arritbutes: - rxns (list or None): Contains `Reaction` objects. One can initialize the - class with a list or None object, the latter of which is used to - define an empty list. - """ - def __init__(self, rxns=None): - if rxns is None: - self.rxns = [] - else: - self.rxns = rxns - - def load(self, json_file): - """ - A function that loads reactions from a JSON-formatted file. - - Args: - json_file (str): The path to the stored reaction file. - """ - - with gzip.open(json_file, 'r') as f: + """Represents a collection of reactions, for saving and loading purposes.""" + def __init__(self, rxns: Optional[list[Reaction]]=None): + self.rxns = rxns if rxns is not None else [] + + def load(self, file: str): + """Load a collection of reactions from a `*.json.gz` file.""" + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" + with gzip.open(file, 'r') as f: data = json.loads(f.read().decode('utf-8')) - for r_dict in data['reactions']: - r = Reaction() - r.load(**r_dict) - self.rxns.append(r) + for r in data['reactions']: + rxn = Reaction().load(**r) + self.rxns.append(rxn) return self - def save(self, json_file): - """ - A function that saves the reaction set to a JSON-formatted file. + def save(self, file: str): + """Save a collection of reactions to a `*.json.gz` file.""" + + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - Args: - json_file (str): The path to the stored reaction file. - """ r_list = {'reactions': [r.__dict__ for r in self.rxns]} - with gzip.open(json_file, 'w') as f: + with gzip.open(file, 'w') as f: f.write(json.dumps(r_list).encode('utf-8')) def __len__(self): From 759f7f1fc8684b9453a335dc85221162b0c44bd4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:00:50 -0400 Subject: [PATCH 137/302] add: type hints --- src/syn_net/utils/data_utils.py | 63 ++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index f81061dd..cba18d01 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -305,32 +305,39 @@ def _print(self, x=3): # the definition of classes for defining synthetic trees below class NodeChemical: - """ - A class representing a chemical node in a synthetic tree. + """Represents a chemical node in a synthetic tree. Args: - smiles (None or str): SMILES string representing molecule. - parent (None or int): - child (None or int): Indicates reaction which molecule participates in. - is_leaf (bool): Indicates if this is a leaf node. - is_root (bool): Indicates if this is a root node. - depth (float): - index (int): Indicates the order of this chemical node in the tree. + smiles: Molecule represented as SMILES string. + parent: Parent molecule represented as SMILES string (i.e. the result of a reaction) + child: Index of the reaction this object participates in. + is_leaf: Is this a leaf node in a synthetic tree? + is_root: Is this a root node in a synthetic tree? + depth: Depth this node is in tree (+1 for an action, +.5 for a reaction) + index: Incremental index for all chemical nodes in the tree. """ - def __init__(self, smiles=None, parent=None, child=None, is_leaf=False, - is_root=False, depth=0, index=0): - self.smiles = smiles - self.parent = parent - self.child = child + def __init__( + self, + smiles: Union[str, None] = None, + parent: Union[int, None] = None, + child: Union[int, None] = None, + is_leaf: bool = False, + is_root: bool = False, + depth: float = 0, + index: int = 0, + ): + self.smiles = smiles + self.parent = parent + self.child = child self.is_leaf = is_leaf self.is_root = is_root - self.depth = depth - self.index = index + self.depth = depth + self.index = index class NodeRxn: - """ - A class representing a reaction node in a synthetic tree. + """Represents a chemical reaction in a synthetic tree. + Args: rxn_id (None or int): Index corresponding to reaction in a one-hot vector @@ -342,14 +349,22 @@ class NodeRxn: depth (float): index (int): Indicates the order of this reaction node in the tree. """ - def __init__(self, rxn_id=None, rtype=None, parent=[], - child=None, depth=0, index=0): + + def __init__( + self, + rxn_id: Union[int, None] = None, + rtype: Union[int, None] = None, + parent: Union[list, None] = [], + child: Union[list, None] = None, + depth: float = 0, + index: int = 0, + ): self.rxn_id = rxn_id - self.rtype = rtype + self.rtype = rtype self.parent = parent - self.child = child - self.depth = depth - self.index = index + self.child = child + self.depth = depth + self.index = index class SyntheticTree: From e466c7366aedf31905aef152308c68ab9792304b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:01:28 -0400 Subject: [PATCH 138/302] strip whitespace from reaction template --- src/syn_net/utils/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index cba18d01..f3a1736f 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -47,7 +47,7 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): if template is not None: # define a few attributes based on the input - self.smirks = template + self.smirks = template.strip() self.rxnname = rxnname self.smiles = smiles self.reference = reference From 4d313b83f611fb2d4a514d88614c423aeb06682c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:02:54 -0400 Subject: [PATCH 139/302] fix: serialize `Reaction` --- src/syn_net/utils/data_utils.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index f3a1736f..6fd9286c 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -265,7 +265,13 @@ def set_available_reactants(self, building_blocks: list[str],verbose: bool=False def get_available_reactants(self) -> Set[str]: return {x for reactants in self.available_reactants for x in reactants} - + def asdict(self) -> dict(): + """Returns serializable fields as new dictionary mapping. + *Excludes* Not-easily-serializable `self.rxn: rdkit.Chem.ChemicalReaction`.""" + import copy + out = copy.deepcopy(self.__dict__) # TODO: + _ = out.pop("rxn") + return out class ReactionSet: """Represents a collection of reactions, for saving and loading purposes.""" @@ -279,7 +285,7 @@ def load(self, file: str): data = json.loads(f.read().decode('utf-8')) for r in data['reactions']: - rxn = Reaction().load(**r) + rxn = Reaction().load(**r) # TODO: `load()` relies on postional args, hence we cannot load a reaction that has no `available_reactants` for extample (or no template) self.rxns.append(rxn) return self @@ -288,7 +294,7 @@ def save(self, file: str): assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - r_list = {'reactions': [r.__dict__ for r in self.rxns]} + r_list = {'reactions': [r.asdict() for r in self.rxns]} with gzip.open(file, 'w') as f: f.write(json.dumps(r_list).encode('utf-8')) @@ -300,7 +306,7 @@ def _print(self, x=3): for i, r in enumerate(self.rxns): if i >= x: break - print(r.__dict__) + print(json.dumps(r.asdict(),indent=2)) # the definition of classes for defining synthetic trees below @@ -382,7 +388,7 @@ class SyntheticTree: """ def __init__(self, tree=None): self.chemicals: list[NodeChemical] = [] - self.reactions:list [Reaction] = [] + self.reactions: list[Reaction] = [] self.root = None self.depth: float= 0 self.actions = [] @@ -418,7 +424,7 @@ def output_dict(self): Returns: data (dict): A dictionary representing a synthetic tree. """ - return {'reactions': [r.__dict__ for r in self.reactions], + return {'reactions': [r.asdict() for r in self.reactions], 'chemicals': [m.__dict__ for m in self.chemicals], 'root': self.root.__dict__, 'depth': self.depth, From 587165baa3283411a7460e21b444b97c20b84758 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:04:19 -0400 Subject: [PATCH 140/302] delete unused code --- src/syn_net/utils/data_utils.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 6fd9286c..fcf4abc7 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -705,19 +705,4 @@ def _print(self, x=3): if __name__ == '__main__': - """ - A test run to find available reactants for a set of reaction templates. - """ - path_to_building_blocks = '/home/whgao/shared/Data/scGen/enamine_5k.csv.gz' - # path_to_rxn_templates = '/home/whgao/shared/Data/scGen/rxn_set_hartenfeller.txt' - path_to_rxn_templates = '/home/whgao/shared/Data/scGen/rxn_set_pis_test.txt' - - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - rxns = [] - for line in open(path_to_rxn_templates, 'rt'): - rxn = Reaction(line.split('|')[1].strip()) - rxn.set_available_reactants(building_blocks) - rxns.append(rxn) - - r = ReactionSet(rxns) - r.save('reactions_pis_test.json.gz') + pass From 93cff2b9a5292bafaebbb10393b6a8fe16fcb628 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:12:30 -0400 Subject: [PATCH 141/302] refactor `SyntheticTreeSet` --- src/syn_net/utils/data_utils.py | 53 +++++++++++---------------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index fcf4abc7..9b0bef9b 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -289,7 +289,7 @@ def load(self, file: str): self.rxns.append(rxn) return self - def save(self, file: str): + def save(self, file: str) -> None: """Save a collection of reactions to a `*.json.gz` file.""" assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" @@ -644,19 +644,9 @@ def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str) class SyntheticTreeSet: - """ - A class representing a set of synthetic trees, for saving and loading purposes. - - Arritbute: - sts (list): Contains `SyntheticTree`s. One can initialize the class with - either a list of synthetic trees or None, in which case an empty - list is created. - """ - def __init__(self, sts=None): - if sts is None: - self.sts = [] - else: - self.sts = sts + """Represents a collection of synthetic trees, for saving and loading purposes.""" + def __init__(self, sts: Optional[list[SyntheticTree]]=None): + self.sts = sts if sts is not None else [] def __len__(self): return len(self.sts) @@ -665,42 +655,33 @@ def __getitem__(self,index): if self.sts is None: raise IndexError("No Synthetic Trees.") return self.sts[index] - def load(self, json_file): - """ - A function that loads a JSON-formatted synthetic tree file. + def load(self, file:str): + """Load a collection of synthetic trees from a `*.json.gz` file.""" + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - Args: - json_file (str): The path to the stored synthetic tree file. - """ - with gzip.open(json_file, 'rt') as f: + with gzip.open(file, 'rt') as f: data = json.loads(f.read()) for st_dict in data['trees']: - if st_dict is None: - self.sts.append(None) - else: - st = SyntheticTree(st_dict) - self.sts.append(st) + st = SyntheticTree(st_dict) if st is not None else None + self.sts.append(st) + return self - def save(self, json_file): - """ - A function that saves the synthetic tree set to a JSON-formatted file. + def save(self, file:str) -> None: + """Save a collection of synthetic trees to a `*.json.gz` file.""" + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - Args: - json_file (str): The path to the stored synthetic tree file. - """ st_list = { 'trees': [st.output_dict() if st is not None else None for st in self.sts] } - with gzip.open(json_file, 'wt') as f: + with gzip.open(file, 'wt') as f: f.write(json.dumps(st_list)) def _print(self, x=3): - # For debugging + """Helper function for debugging.""" for i, r in enumerate(self.sts): - if i >= x: - break + if i >= x: break print(r.output_dict()) From bca9d762c71e871ace82b22ca27f74d1082479d5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:13:06 -0400 Subject: [PATCH 142/302] black --- src/syn_net/utils/data_utils.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 9b0bef9b..843f9c85 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -645,45 +645,45 @@ def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str) class SyntheticTreeSet: """Represents a collection of synthetic trees, for saving and loading purposes.""" - def __init__(self, sts: Optional[list[SyntheticTree]]=None): + + def __init__(self, sts: Optional[list[SyntheticTree]] = None): self.sts = sts if sts is not None else [] def __len__(self): return len(self.sts) - def __getitem__(self,index): - if self.sts is None: raise IndexError("No Synthetic Trees.") + def __getitem__(self, index): + if self.sts is None: + raise IndexError("No Synthetic Trees.") return self.sts[index] - def load(self, file:str): + def load(self, file: str): """Load a collection of synthetic trees from a `*.json.gz` file.""" assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - with gzip.open(file, 'rt') as f: + with gzip.open(file, "rt") as f: data = json.loads(f.read()) - for st_dict in data['trees']: + for st_dict in data["trees"]: st = SyntheticTree(st_dict) if st is not None else None self.sts.append(st) return self - def save(self, file:str) -> None: + def save(self, file: str) -> None: """Save a collection of synthetic trees to a `*.json.gz` file.""" assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - st_list = { - 'trees': [st.output_dict() if st is not None else None for st in self.sts] - } - with gzip.open(file, 'wt') as f: + st_list = {"trees": [st.output_dict() if st is not None else None for st in self.sts]} + with gzip.open(file, "wt") as f: f.write(json.dumps(st_list)) def _print(self, x=3): """Helper function for debugging.""" for i, r in enumerate(self.sts): - if i >= x: break + if i >= x: + break print(r.output_dict()) - if __name__ == '__main__': pass From b69d648ab446f2e9cd913683d0935349c07ab307 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:13:51 -0400 Subject: [PATCH 143/302] isort --- src/syn_net/utils/data_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 843f9c85..b561eaa7 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -8,12 +8,11 @@ * `SyntheticTreeSet` """ import functools -import itertools import gzip +import itertools import json -from typing import Any, Optional, Tuple, Union, Set +from typing import Any, Optional, Set, Tuple, Union -import pandas as pd from rdkit import Chem from rdkit.Chem import AllChem, Draw, rdChemReactions from tqdm import tqdm From dd65464cbcbf67e82705744cb227f019ffca7344 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:24:55 -0400 Subject: [PATCH 144/302] fix --- scripts/03-generate-syntrees.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index a06c4a2d..7a73c75c 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -17,11 +17,6 @@ RDLogger.DisableLog("rdApp.*") -building_blocks_file = "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" -rxn_templates_file = "data/assets/reaction-templates/hb.txt" -output_file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees.json.gz" - - def get_args(): import argparse @@ -30,16 +25,19 @@ def get_args(): parser.add_argument( "--building-blocks-file", type=str, + default="data/pre-process/building-blocks/enamine-us-smiles.csv.gz", # TODO: change help="Input file with SMILES strings (First row `SMILES`, then one per line).", ) parser.add_argument( "--rxn-templates-file", type=str, + default="data/assets/reaction-templates/hb.txt", # TODO: change help="Input file with reaction templates as SMARTS(No header, one per line).", ) parser.add_argument( "--output-file", type=str, + default=Path(DATA_PREPROCESS_DIR) / f"synthetic-trees.json.gz", help="Output file for the generated synthetic trees (*.json.gz)", ) # Parameters @@ -72,7 +70,7 @@ def get_args(): outcomes: dict[int, str] = dict() syntrees: list[Union[SyntheticTree, None]] = [] for i in range(args.number_syntrees): - st, e = wraps_syntreegenerator_generate() + st, e = wraps_syntreegenerator_generate(stgen) outcomes[i] = e.__class__.__name__ if e is not None else "success" syntrees.append(st) logger.info(f"SynTree generation completed. Results: {Counter(outcomes.values())}") From 5e20da6d674ccbebe2bc7fbb30e3e9be8f1b4c6f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:25:10 -0400 Subject: [PATCH 145/302] format --- scripts/03-generate-syntrees.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index 7a73c75c..ac6c8005 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -17,6 +17,7 @@ RDLogger.DisableLog("rdApp.*") + def get_args(): import argparse @@ -25,13 +26,13 @@ def get_args(): parser.add_argument( "--building-blocks-file", type=str, - default="data/pre-process/building-blocks/enamine-us-smiles.csv.gz", # TODO: change + default="data/pre-process/building-blocks/enamine-us-smiles.csv.gz", # TODO: change help="Input file with SMILES strings (First row `SMILES`, then one per line).", ) parser.add_argument( "--rxn-templates-file", type=str, - default="data/assets/reaction-templates/hb.txt", # TODO: change + default="data/assets/reaction-templates/hb.txt", # TODO: change help="Input file with reaction templates as SMARTS(No header, one per line).", ) parser.add_argument( From d63e317dbc3235b98cf273917545db5405e213ac Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:38:13 -0400 Subject: [PATCH 146/302] bug fix --- src/syn_net/utils/data_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index b561eaa7..94372f0b 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -386,8 +386,8 @@ class SyntheticTree: type (uni- or bi-molecular). """ def __init__(self, tree=None): - self.chemicals: list[NodeChemical] = [] - self.reactions: list[Reaction] = [] + self.chemicals: list[NodeChemical] = [] + self.reactions: list[NodeRxn] = [] self.root = None self.depth: float= 0 self.actions = [] @@ -423,7 +423,7 @@ def output_dict(self): Returns: data (dict): A dictionary representing a synthetic tree. """ - return {'reactions': [r.asdict() for r in self.reactions], + return {'reactions': [r.__dict__ for r in self.reactions], 'chemicals': [m.__dict__ for m in self.chemicals], 'root': self.root.__dict__, 'depth': self.depth, From 0f6c7e91fe3008bd7b6dad7cd27b35cd52e401dd Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:42:24 -0400 Subject: [PATCH 147/302] update with instruction for syntree generation --- INSTRUCTIONS.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 4007849b..f0dfc8a9 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -57,12 +57,20 @@ Let's start. Herein we generate the data used for training the networks. The data is generated by randomly selecting building blocks, reaction templates and directives to grow a synthetic tree. + + ```bash + # Generate synthetic trees + python scripts/03-generate-syntrees.py \ + --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --output-file "data/pre-process/synthetic-trees.json.gz" \ + --number-syntrees 600000 + ``` + In a second step, we filter out some synthetic trees to make the data pharmaceutically more interesting. That is, we filter out trees, whose root node molecule has a QED < 0.5, or randomly with a probability less than 1 - QED/0.5. ```bash - # Generate synthetic trees - python scripts/03-make_dataset_mp.py # Filter python scripts/04-sample_from_original.py ``` From 0eb71388d01f128937a556134210fce7f8a72d6b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 11:48:12 -0400 Subject: [PATCH 148/302] add documentation --- src/syn_net/data_generation/syntrees.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 5c985c75..823d60bf 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -298,6 +298,8 @@ def wraps_syntreegenerator_generate( logger.error(e) return None, e except TypeError as e: + # When converting an invalid molecule from SMILES to rdkit Molecule. + # This happens if the reaction template/rdkit produces an invalid product. logger.error(e) return None, e except Exception as e: From f38a5f1269518bccc13f0688584f7901472ad77c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 13:15:59 -0400 Subject: [PATCH 149/302] bug fix --- src/syn_net/utils/data_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 94372f0b..4f4f8ee4 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -663,8 +663,8 @@ def load(self, file: str): with gzip.open(file, "rt") as f: data = json.loads(f.read()) - for st_dict in data["trees"]: - st = SyntheticTree(st_dict) if st is not None else None + for st in data["trees"]: + st = SyntheticTree(st) if st is not None else None self.sts.append(st) return self From 2ea6c95e13285499f4da61c09fa1b9db2be77f40 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:39:04 -0400 Subject: [PATCH 150/302] refactor syntree filter --- scripts/03-generate-syntrees.py | 4 +- scripts/04-filter-synthetic-trees.py | 120 +++++++++++++++++++++++++++ 2 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 scripts/04-filter-synthetic-trees.py diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index ac6c8005..a4a388c5 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -1,6 +1,7 @@ import logging from collections import Counter from pathlib import Path +from tqdm import tqdm from rdkit import RDLogger @@ -70,7 +71,8 @@ def get_args(): logger.info(f"Start generation of {args.number_syntrees} SynTrees...") outcomes: dict[int, str] = dict() syntrees: list[Union[SyntheticTree, None]] = [] - for i in range(args.number_syntrees): + myrange = tqdm(range(args.number_syntrees)) if args.verbose else range(args.number_syntrees) + for i in myrange: st, e = wraps_syntreegenerator_generate(stgen) outcomes[i] = e.__class__.__name__ if e is not None else "success" syntrees.append(st) diff --git a/scripts/04-filter-synthetic-trees.py b/scripts/04-filter-synthetic-trees.py new file mode 100644 index 00000000..cbfc90ec --- /dev/null +++ b/scripts/04-filter-synthetic-trees.py @@ -0,0 +1,120 @@ +"""Filter Synthetic Trees. +""" + +import json +import logging +from collections import Counter +from pathlib import Path + +import numpy as np +from rdkit import Chem, RDLogger +from tqdm import tqdm + +from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES +from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet + +logger = logging.getLogger(__name__) + +RDLogger.DisableLog("rdApp.*") + + +class Filter: + def filter(self, st: SyntheticTree, **kwargs) -> bool: + ... + + +class ValidRootMolFilter(Filter): + def filter(self, st: SyntheticTree, **kwargs) -> bool: + return Chem.MolFromSmiles(st.root.smiles) is not None + + +class OracleFilter(Filter): + def __init__( + self, + name: str = "qed", + threshold: float = 0.5, + rng=np.random.default_rng(42), + ) -> None: + super().__init__() + from tdc import Oracle + + self.oracle_fct = Oracle(name=name) + self.threshold = threshold + self.rng = rng + + def _qed(self, st: SyntheticTree): + """Filter for molecules with a high qed.""" + return self.oracle_fct(st.root.smiles) > self.threshold + + def _random(self, st: SyntheticTree): + """Filter molecules that fail the `_qed` filter; i.e. randomly select low qed molecules.""" + return self.rng.random() < self.oracle_fct(st.root.smiles) / self.threshold + + def filter(self, st: SyntheticTree) -> bool: + return self._qed(st) or self._random(st) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-file", + type=str, + default="data/pre-process/synthetic-trees.json.gz", + help="Input file for the filtered generated synthetic trees (*.json.gz)", + ) + parser.add_argument( + "--output-file", + type=str, + default="data/pre-process/synthetic-trees-filtered.json.gz", + help="Output file for the filtered generated synthetic trees (*.json.gz)", + ) + + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load previously generated synthetic trees + syntree_collection = SyntheticTreeSet().load(args.input_file) + logger.info(f"Successfully loaded '{args.input_file}' with {len(syntree_collection)} syntrees.") + + # Filter trees + # TODO: Move to src/syn_net/data_generation/filters.py ? + valid_root_mol_filter = ValidRootMolFilter() + interesting_mol_filter = OracleFilter(threshold=0.5, rng=np.random.default_rng()) + + syntrees = [] + syntree_collection = [s for s in syntree_collection if s is not None] + syntree_collection = tqdm(syntree_collection) if args.verbose else syntree_collection + outcomes: dict[int, str] = dict() # TODO: think about what metrics to track here + for i, st in enumerate(syntree_collection): + + # Filter 1: Is root molecule valid? + keep_tree = valid_root_mol_filter.filter(st) + if not keep_tree: + continue + + # Filter 2: Is root molecule "pharmaceutically interesting?" + keep_tree = interesting_mol_filter.filter(st) + if not keep_tree: + continue + + # We passed all filters. This tree ascended to our dataset + syntrees.append(st) + + # Save filtered synthetic trees on disk + SyntheticTreeSet(syntrees).save(args.output_file) + logger.info(f"Successfully saved '{args.output_file}' with {len(syntrees)} syntrees.") + + logger.info(f"Completed.") From edc2091e327e54c8639ac13c4b7d7f0edd00a42f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:41:22 -0400 Subject: [PATCH 151/302] update instruction with syntree filters --- INSTRUCTIONS.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index f0dfc8a9..429a3368 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -62,7 +62,7 @@ Let's start. # Generate synthetic trees python scripts/03-generate-syntrees.py \ --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ - --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ --output-file "data/pre-process/synthetic-trees.json.gz" \ --number-syntrees 600000 ``` @@ -72,7 +72,9 @@ Let's start. ```bash # Filter - python scripts/04-sample_from_original.py + python scripts/04-filter-synthetic-trees.py \ + --input-file "data/pre-process/synthetic-trees.json.gz" \ + --output-file "data/pre-process/synthetic-trees-filtered.json.gz" ``` Each *synthetic tree* is serializable and so we save all trees in a compressed `.json` file. From 39ddc76511e1b5a54dfb2bc088e3c04db1716ee4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:42:09 -0400 Subject: [PATCH 152/302] delete old script (refactored in 2ea6c95) --- scripts/sample_from_original.py | 75 --------------------------------- 1 file changed, 75 deletions(-) delete mode 100644 scripts/sample_from_original.py diff --git a/scripts/sample_from_original.py b/scripts/sample_from_original.py deleted file mode 100644 index 9d32e79e..00000000 --- a/scripts/sample_from_original.py +++ /dev/null @@ -1,75 +0,0 @@ -""" -Filters the synthetic trees by the QEDs of the root molecules. -""" -from pathlib import Path - -import numpy as np -import pandas as pd -from rdkit import Chem -from tdc import Oracle -from tqdm import tqdm - -from syn_net.config import DATA_PREPROCESS_DIR -from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet - -DATA_DIR = "pool001/whgao/data/synth_net" -SYNTHETIC_TREES_FILE = "abc-st_data.json.gz" - -def _is_valid_mol(mol: Chem.rdchem.Mol): - return mol is not None - -if __name__ == '__main__': - reaction_template_id = "hb" # "pis" or "hb" - building_blocks_id = "enamine_us-2021-smiles" - qed = Oracle(name='qed') - - # Load generated synthetic trees - file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees_{reaction_template_id}-{building_blocks_id}.json.gz" - st_set = SyntheticTreeSet() - st_set.load(file) - synthetic_trees = st_set.sts - print(f'Finish reading, in total {len(synthetic_trees)} synthetic trees.') - - # Filter synthetic trees - # .. based on validity of root molecule - # .. based on drug-like quality - filtered_data: list[SyntheticTree] = [] - original_qed: list[float] = [] - qeds: list[float] = [] - generated_smiles: list[str] = [] - - threshold = 0.5 - - for t in tqdm(synthetic_trees): - try: - smiles = t.root.smiles - mol = Chem.MolFromSmiles(smiles) - if not _is_valid_mol(mol): - continue - if smiles in generated_smiles: - continue - - qed_value = qed(smiles) - original_qed.append(qed_value) - - # filter the trees based on their QEDs - if qed_value > threshold or np.random.random() < (qed_value/threshold): - generated_smiles.append(smiles) - filtered_data.append(t) - qeds.append(qed_value) - - except Exception as e: - print(e) - - print(f'Finish sampling, remaining {len(filtered_data)} synthetic trees.') - - # Save to local disk - st_set = SyntheticTreeSet(filtered_data) - file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees_{reaction_template_id}-{building_blocks_id}-filtered.json.gz" - st_set.save(file) - - df = pd.DataFrame({'SMILES': generated_smiles, 'qed': qeds}) - file = Path(DATA_PREPROCESS_DIR) / f"filtered-smiles_{reaction_template_id}-{building_blocks_id}-filtered.csv.gz" - df.to_csv(file, compression='gzip', index=False) - - print('Finish!') From 429d4ff56436043ab86612fa88f892f065127e67 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:51:00 -0400 Subject: [PATCH 153/302] do not save if syntree is `None` --- src/syn_net/utils/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 4f4f8ee4..217d13fb 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -673,7 +673,7 @@ def save(self, file: str) -> None: """Save a collection of synthetic trees to a `*.json.gz` file.""" assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - st_list = {"trees": [st.output_dict() if st is not None else None for st in self.sts]} + st_list = {"trees": [st.output_dict() for st in self.sts if st is not None]} with gzip.open(file, "wt") as f: f.write(json.dumps(st_list)) From cda281cbfc24932e90d72874e3e4c1183fb13d66 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:51:37 -0400 Subject: [PATCH 154/302] remove logging level --- src/syn_net/data_generation/syntrees.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 823d60bf..080ef666 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -10,7 +10,6 @@ from syn_net.config import MAX_PROCESSES logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) from syn_net.utils.data_utils import Reaction, SyntheticTree From 054f4b2a284827cb37b02e2f440f94819a294106 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 20 Sep 2022 16:09:08 -0400 Subject: [PATCH 155/302] bugfix --- src/syn_net/data_generation/syntrees.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 080ef666..de85531d 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -41,6 +41,12 @@ class NoReactionPossibleError(Exception): def __init__(self, message): super().__init__(message) +class MaxDepthError(Exception): + """Synthetic Tree has exceeded its maximum depth.""" + + def __init__(self, message): + super().__init__(message) + class SynTreeGenerator: @@ -56,7 +62,7 @@ def __init__( *, building_blocks: list[str], rxn_templates: list[str], - rng=np.random.default_rng(seed=42), + rng=np.random.default_rng(), processes: int = MAX_PROCESSES, verbose: bool = False, ) -> None: @@ -186,7 +192,7 @@ def _get_action_mask(self, syntree: SyntheticTree): elif nTrees == 1: canAdd = True canExpand = True - canEnd = True + canEnd = True # TODO: When syntree has reached max depth, only allow to end it. elif nTrees == 2: canExpand = True canMerge = any(self._get_rxn_mask(tuple(state))) @@ -241,12 +247,10 @@ def generate(self, max_depth: int = 15, retries: int = 3): raise NoReactionPossibleError( f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." ) - elif action == "add": mol = self._sample_molecule() r1, r2, p, idx_rxn = self._expand(mol) # Expand this subtree: reactant, reaction, reactant2 - elif action == "merge": # merge two subtrees: sample reaction, run it. @@ -262,6 +266,8 @@ def generate(self, max_depth: int = 15, retries: int = 3): raise NoReactionPossibleError( f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." ) + else: + raise ValueError(f"Invalid action {action}") # Prepare next iteration logger.debug(f" Ran reaction {r1} + {r2} -> {p}") @@ -274,6 +280,8 @@ def generate(self, max_depth: int = 15, retries: int = 3): if action == "end": break + if i==max_depth-1 and not action == "end": + raise MaxDepthError("Maximum depth {max_depth} exceeded.") logger.debug(f"🙌 SynTree completed.") return syntree From 2a117ae5a12b09fe7b1b8df55497528950fc0add Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:32:18 -0400 Subject: [PATCH 156/302] allow mp for generation of syntrees --- scripts/03-generate-syntrees.py | 58 ++++++++++++++++++++++++--------- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index a4a388c5..2603bbe2 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -1,11 +1,10 @@ import logging from collections import Counter -from pathlib import Path -from tqdm import tqdm from rdkit import RDLogger +from tqdm import tqdm -from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES +from syn_net.config import MAX_PROCESSES from syn_net.data_generation.preprocessing import ( BuildingBlockFileHandler, ReactionTemplateFileHandler, @@ -14,7 +13,7 @@ from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet logger = logging.getLogger(__name__) -from typing import Union +from typing import Tuple, Union RDLogger.DisableLog("rdApp.*") @@ -39,11 +38,13 @@ def get_args(): parser.add_argument( "--output-file", type=str, - default=Path(DATA_PREPROCESS_DIR) / f"synthetic-trees.json.gz", + default="data/pre-precess/synthetic-trees.json.gz", help="Output file for the generated synthetic trees (*.json.gz)", ) # Parameters - parser.add_argument("--number-syntrees", type=int, help="Number of SynTrees to generate.") + parser.add_argument( + "--number-syntrees", type=int, default=1000, help="Number of SynTrees to generate." + ) # Processing parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") @@ -51,9 +52,40 @@ def get_args(): return parser.parse_args() +def generate_mp() -> Tuple[dict[int, str], list[Union[SyntheticTree, None]]]: + from functools import partial + + import numpy as np + from pathos import multiprocessing as mp + + def wrapper(stgen, _): + stgen.rng = np.random.default_rng() + return wraps_syntreegenerator_generate(stgen) + + func = partial(wrapper, stgen) + with mp.Pool(processes=4) as pool: + results = pool.map(func, range(args.number_syntrees)) + outcomes = { + i: e.__class__.__name__ if e is not None else "success" for i, (_, e) in enumerate(results) + } + syntrees = [st for (st, e) in results if e is None] + return outcomes, syntrees + + +def generate() -> Tuple[dict[int, str], list[Union[SyntheticTree, None]]]: + outcomes: dict[int, str] = dict() + syntrees: list[Union[SyntheticTree, None]] = [] + myrange = tqdm(range(args.number_syntrees)) if args.verbose else range(args.number_syntrees) + for i in myrange: + st, e = wraps_syntreegenerator_generate(stgen) + outcomes[i] = e.__class__.__name__ if e is not None else "success" + syntrees.append(st) + + return outcomes, syntrees + + if __name__ == "__main__": logger.info("Start.") - # Parse input args args = get_args() logger.info(f"Arguments: {vars(args)}") @@ -66,16 +98,12 @@ def get_args(): stgen = SynTreeGenerator( building_blocks=bblocks, rxn_templates=rxn_templates, verbose=args.verbose ) - # Generate synthetic trees logger.info(f"Start generation of {args.number_syntrees} SynTrees...") - outcomes: dict[int, str] = dict() - syntrees: list[Union[SyntheticTree, None]] = [] - myrange = tqdm(range(args.number_syntrees)) if args.verbose else range(args.number_syntrees) - for i in myrange: - st, e = wraps_syntreegenerator_generate(stgen) - outcomes[i] = e.__class__.__name__ if e is not None else "success" - syntrees.append(st) + if args.ncpu > 1: + outcomes, syntrees = generate_mp() + else: + outcomes, syntrees = generate() logger.info(f"SynTree generation completed. Results: {Counter(outcomes.values())}") # Save synthetic trees on disk From b6b843395e6c8bce12dd5e15f6e235c3523ba93a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:43:53 -0400 Subject: [PATCH 157/302] add cli args to script --- scripts/st_split.py | 48 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/scripts/st_split.py b/scripts/st_split.py index 5a4f2d75..b1de5539 100644 --- a/scripts/st_split.py +++ b/scripts/st_split.py @@ -3,15 +3,46 @@ """ from syn_net.utils.data_utils import SyntheticTreeSet from pathlib import Path -from syn_net.config import DATA_PREPROCESS_DIR, DATA_PREPARED_DIR +from syn_net.config import DATA_PREPROCESS_DIR, DATA_PREPARED_DIR, MAX_PROCESSES +import json +import logging +logger = logging.getLogger(__name__) + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-file", + type=str, + default="data/pre-process/synthetic-trees.json.gz", + help="Input file for the filtered generated synthetic trees (*.json.gz)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=str(Path(DATA_PREPROCESS_DIR) / "split"), + help="Output directory for the splitted synthetic trees (*.json.gz)", + ) + + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + if __name__ == "__main__": - reaction_template_id = "hb" # "pis" or "hb" - building_blocks_id = "enamine_us-2021-smiles" + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + # Load filtered synthetic trees st_set = SyntheticTreeSet() - file = Path(DATA_PREPROCESS_DIR) / f"synthetic-trees_{reaction_template_id}-{building_blocks_id}-filtered.json.gz" + file = args.input_file print(f'Reading data from {file}') st_set.load(file) data = st_set.sts @@ -31,17 +62,18 @@ data_test = data[num_train + num_valid: ] # Save to local disk - + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True,exist_ok=True) print("Saving training dataset: ", len(data_train)) trees = SyntheticTreeSet(data_train) - trees.save(f'{DATA_PREPARED_DIR}/synthetic-trees-train.json.gz') + trees.save(out_dir / "synthetic-trees-train.json.gz") print("Saving validation dataset: ", len(data_valid)) trees = SyntheticTreeSet(data_valid) - trees.save(f'{DATA_PREPARED_DIR}/synthetic-trees-valid.json.gz') + trees.save(out_dir / "synthetic-trees-valid.json.gz") print("Saving testing dataset: ", len(data_test)) trees = SyntheticTreeSet(data_test) - trees.save(f'{DATA_PREPARED_DIR}/synthetic-trees-test.json.gz') + trees.save(out_dir / "synthetic-trees-test.json.gz") print("Finish!") From cce04baa34dd6422045b7eef7637881e7a740fdf Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:50:34 -0400 Subject: [PATCH 158/302] tidy up + logging --- scripts/st_split.py | 42 +++++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/scripts/st_split.py b/scripts/st_split.py index b1de5539..b62b22c6 100644 --- a/scripts/st_split.py +++ b/scripts/st_split.py @@ -3,7 +3,7 @@ """ from syn_net.utils.data_utils import SyntheticTreeSet from pathlib import Path -from syn_net.config import DATA_PREPROCESS_DIR, DATA_PREPARED_DIR, MAX_PROCESSES +from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES import json import logging logger = logging.getLogger(__name__) @@ -39,16 +39,13 @@ def get_args(): args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - # Load filtered synthetic trees - st_set = SyntheticTreeSet() - file = args.input_file - print(f'Reading data from {file}') - st_set.load(file) - data = st_set.sts - del st_set - num_total = len(data) - print(f"There are {len(data)} synthetic trees.") + logger.info(f'Reading data from {args.input_file}') + syntree_collection = SyntheticTreeSet().load(args.input_file) + syntrees = syntree_collection.sts + + num_total = len(syntrees) + logger.info(f"There are {len(syntrees)} synthetic trees.") # Split data SPLIT_RATIO = [0.6, 0.2, 0.2] @@ -57,23 +54,22 @@ def get_args(): num_valid = int(SPLIT_RATIO[1] * num_total) num_test = num_total - num_train - num_valid - data_train = data[:num_train] - data_valid = data[num_train: num_train + num_valid] - data_test = data[num_train + num_valid: ] + data_train = syntrees[:num_train] + data_valid = syntrees[num_train: num_train + num_valid] + data_test = syntrees[num_train + num_valid: ] # Save to local disk out_dir = Path(args.output_dir) out_dir.mkdir(parents=True,exist_ok=True) - print("Saving training dataset: ", len(data_train)) - trees = SyntheticTreeSet(data_train) - trees.save(out_dir / "synthetic-trees-train.json.gz") - print("Saving validation dataset: ", len(data_valid)) - trees = SyntheticTreeSet(data_valid) - trees.save(out_dir / "synthetic-trees-valid.json.gz") + logger.info(f"Saving training dataset. Number of syntrees: {len(data_train)}") + SyntheticTreeSet(data_train).save(out_dir / "synthetic-trees-train.json.gz") + + logger.info(f"Saving validation dataset. Number of syntrees: {len(data_valid)}") + SyntheticTreeSet(data_valid).save(out_dir / "synthetic-trees-valid.json.gz") + + logger.info(f"Saving testing dataset. Number of syntrees: {len(data_test)}") + SyntheticTreeSet(data_test).save(out_dir / "synthetic-trees-test.json.gz") - print("Saving testing dataset: ", len(data_test)) - trees = SyntheticTreeSet(data_test) - trees.save(out_dir / "synthetic-trees-test.json.gz") + logger.info(f"Completed.") - print("Finish!") From 3371263cc6915f287abad29767ece0fc41c1bfd2 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:51:04 -0400 Subject: [PATCH 159/302] format --- scripts/st_split.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/scripts/st_split.py b/scripts/st_split.py index b62b22c6..9062374d 100644 --- a/scripts/st_split.py +++ b/scripts/st_split.py @@ -1,13 +1,16 @@ """ Reads synthetic tree data and splits it into training, validation and testing sets. """ -from syn_net.utils.data_utils import SyntheticTreeSet -from pathlib import Path -from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES import json import logging +from pathlib import Path + +from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES +from syn_net.utils.data_utils import SyntheticTreeSet + logger = logging.getLogger(__name__) + def get_args(): import argparse @@ -40,7 +43,7 @@ def get_args(): logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") # Load filtered synthetic trees - logger.info(f'Reading data from {args.input_file}') + logger.info(f"Reading data from {args.input_file}") syntree_collection = SyntheticTreeSet().load(args.input_file) syntrees = syntree_collection.sts @@ -55,12 +58,12 @@ def get_args(): num_test = num_total - num_train - num_valid data_train = syntrees[:num_train] - data_valid = syntrees[num_train: num_train + num_valid] - data_test = syntrees[num_train + num_valid: ] + data_valid = syntrees[num_train : num_train + num_valid] + data_test = syntrees[num_train + num_valid :] # Save to local disk out_dir = Path(args.output_dir) - out_dir.mkdir(parents=True,exist_ok=True) + out_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving training dataset. Number of syntrees: {len(data_train)}") SyntheticTreeSet(data_train).save(out_dir / "synthetic-trees-train.json.gz") @@ -72,4 +75,3 @@ def get_args(): SyntheticTreeSet(data_test).save(out_dir / "synthetic-trees-test.json.gz") logger.info(f"Completed.") - From df8f336c9398c64b0bf0a436dccbf845dfa3be66 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:51:54 -0400 Subject: [PATCH 160/302] rename --- scripts/{st_split.py => 05-split-syntrees.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{st_split.py => 05-split-syntrees.py} (100%) diff --git a/scripts/st_split.py b/scripts/05-split-syntrees.py similarity index 100% rename from scripts/st_split.py rename to scripts/05-split-syntrees.py From be9e45e908b42c219cbade8a053e0443e6d14cdd Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 11:53:33 -0400 Subject: [PATCH 161/302] update instructions up to 5 --- INSTRUCTIONS.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 429a3368..e7087405 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -63,7 +63,7 @@ Let's start. python scripts/03-generate-syntrees.py \ --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ - --output-file "data/pre-process/synthetic-trees.json.gz" \ + --output-file "data/pre-process/synthetic-trees.json.gz" \ --number-syntrees 600000 ``` @@ -86,17 +86,19 @@ Let's start. The default split ratio is 6:2:2. ```bash - python scripts/05-st_split.py + python scripts/st_split.py \ + --input-file "data/pre-process/synthetic-trees-filtered.json.gz" + --output-dir "data/pre-process/split" ``` 4. Featurization and - > :bulb: All following steps depend on the representations for the data. Hence, you have to specify the parameters for the reprensations as input argument for most of the scripts so that it can operate on the right data. + > :bulb: All following steps depend on the representations for the data. Hence, you have to specify the parameters for the representations as input argument for most of the scripts so that it can operate on the right data. We organize each *synthetic tree* into states and actions. That is, we break down each tree to the action at each iteration ("Add", "Expand", "Extend", "End") and a corresponding "super state" vector. We call it "super state" here, as it contains all states for all networks. - However, recall that the input that the state vector is different for each network. + However, recall that the input state vector is different for each network. ```bash python scripts/06-st2steps.py From 79f2732f9cd2392d601af87c1b367dcc10705697 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:18:17 -0400 Subject: [PATCH 162/302] fix filename --- INSTRUCTIONS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index e7087405..941662e8 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -86,7 +86,7 @@ Let's start. The default split ratio is 6:2:2. ```bash - python scripts/st_split.py \ + python scripts/05-split-syntrees.py \ --input-file "data/pre-process/synthetic-trees-filtered.json.gz" --output-dir "data/pre-process/split" ``` From 08befe043a339b684b36869ee0e0efd1711d6483 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:23:07 -0400 Subject: [PATCH 163/302] separate argparse from main --- scripts/st2steps.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index eb07ee62..68baeb55 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -3,6 +3,7 @@ """ from pathlib import Path from tqdm import tqdm +import json from scipy import sparse from syn_net.utils.data_utils import SyntheticTreeSet from syn_net.utils.prep_utils import organize @@ -12,8 +13,7 @@ from syn_net.config import DATA_PREPARED_DIR, DATA_FEATURIZED_DIR -if __name__ == '__main__': - +def get_args():# import argparse parser = argparse.ArgumentParser() parser.add_argument("-e", "--targetembedding", type=str, default='fp', @@ -28,8 +28,16 @@ help="Choose from ['train', 'valid', 'test']") parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], help="Choose from ['hb', 'pis']") - args = parser.parse_args() - logger.info(vars(args)) + return parser.parse_args() + +if __name__ == '__main__': + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Parse & set inputs reaction_template_id = args.rxn_template @@ -45,7 +53,7 @@ logger.info("Number of synthetic trees: {len(st_set.sts}") data: list = st_set.sts del st_set - + # Set output directory save_dir = Path(DATA_FEATURIZED_DIR) / f'{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/' Path(save_dir).mkdir(parents=1,exist_ok=1) @@ -67,7 +75,7 @@ steps.append(step) - # Finally, save. + # Finally, save. logger.info(f"Saving to {save_dir}") states = sparse.vstack(states) steps = sparse.vstack(steps) From 55108b840c13e0450ee3cc7bc849b431dea96f16 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:43:26 -0400 Subject: [PATCH 164/302] add type hints --- src/syn_net/utils/prep_utils.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 9811dd27..97ce216b 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -1,7 +1,7 @@ """ This file contains various utils for data preparation and preprocessing. """ -from typing import Iterator, Union +from typing import Iterator, Union, Tuple import numpy as np from scipy import sparse from sklearn.preprocessing import OneHotEncoder @@ -45,23 +45,18 @@ def _fetch_gin_pretrained_model(model_name: str): return model -def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, - output_embedding='gin'): +def organize(st: SyntheticTree, d_mol: int=300, target_embedding: str='fp', radius: int=2, nBits:int=4096, + output_embedding: str ='gin') -> Tuple(sparse.csc_matrix,sparse.csc_matrix): """ - Organizes the states and steps from the input synthetic tree into sparse - matrices. + Organizes synthetic trees into states and node states at each step into sparse matrices. Args: - st (SyntheticTree): The input synthetic tree to organize. - d_mol (int, optional): The molecular embedding size. Defaults to 300. - target_embedding (str, optional): Indicates what embedding type to use - for the input target (Morgan fingerprint --> 'fp' or GIN --> 'gin'). - Defaults to 'fp'. - radius (int, optional): Morgan fingerprint radius to use. Defaults to 2. - nBits (int, optional): Number of bits to use in the Morgan fingerprints. - Defaults to 4096. - output_embedding (str, optional): Indicates what type of embedding to - use for the output node states. Defaults to 'gin'. + st: Synthetic tree to organize + d_mol: The molecular embedding size. Defaults to 300 + target_embedding: Embedding for the input node states. + radius: (if Morgan fingerprint) radius + nBits: (if Morgan fingerprint) bits + output_embedding: Embedding for the output node states Raises: ValueError: Raised if target embedding not supported. From bc68f29baaaaf0ecb53474fadbe114887c792844 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:43:39 -0400 Subject: [PATCH 165/302] remove unused imports --- scripts/04-filter-synthetic-trees.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/scripts/04-filter-synthetic-trees.py b/scripts/04-filter-synthetic-trees.py index cbfc90ec..486ce552 100644 --- a/scripts/04-filter-synthetic-trees.py +++ b/scripts/04-filter-synthetic-trees.py @@ -3,14 +3,13 @@ import json import logging -from collections import Counter -from pathlib import Path + import numpy as np from rdkit import Chem, RDLogger from tqdm import tqdm -from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES +from syn_net.config import MAX_PROCESSES from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet logger = logging.getLogger(__name__) From d0f45c57be07f37f20c1f1d249f6f7e317f1a87b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:44:02 -0400 Subject: [PATCH 166/302] fix: move import statement into if branch --- src/syn_net/utils/prep_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 97ce216b..2ff0e72e 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -9,7 +9,6 @@ from syn_net.utils.predict_utils import (can_react, get_action_mask, get_reaction_mask, mol_fp, ) -from syn_net.encoding.gins import get_mol_embedding from pathlib import Path from rdkit import Chem @@ -87,6 +86,7 @@ def organize(st: SyntheticTree, d_mol: int=300, target_embedding: str='fp', radi if target_embedding == 'fp': target = mol_fp(st.root.smiles, radius, nBits).tolist() elif target_embedding == 'gin': + from syn_net.encoding.gins import get_mol_embedding # define model to use for molecular embedding target = get_mol_embedding(st.root.smiles, model=model).tolist() else: From 4b7dbd31ca58dbb4586d1fd7e498926d73380fc3 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:44:54 -0400 Subject: [PATCH 167/302] fix --- src/syn_net/utils/prep_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 2ff0e72e..d82023c2 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -45,7 +45,7 @@ def _fetch_gin_pretrained_model(model_name: str): def organize(st: SyntheticTree, d_mol: int=300, target_embedding: str='fp', radius: int=2, nBits:int=4096, - output_embedding: str ='gin') -> Tuple(sparse.csc_matrix,sparse.csc_matrix): + output_embedding: str ='gin') -> Tuple[sparse.csc_matrix,sparse.csc_matrix]: """ Organizes synthetic trees into states and node states at each step into sparse matrices. @@ -100,7 +100,7 @@ def organize(st: SyntheticTree, d_mol: int=300, target_embedding: str='fp', radi other_root_mol_embedding = mol_fp(other_root_mol, radius, nBits).tolist() state = most_recent_mol_embedding + other_root_mol_embedding + target # (3d,1) - if action == 3: + if action == 3: #end step = [3] + [0]*d_mol + [-1] + [0]*d_mol + [0]*nBits else: From b92e9329b3ea9fe02ae00a4af80409f22c9d5045 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 13:47:42 -0400 Subject: [PATCH 168/302] bugfix --- scripts/st2steps.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index 68baeb55..e72e2f61 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -37,8 +37,6 @@ def get_args():# args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - - # Parse & set inputs reaction_template_id = args.rxn_template building_blocks_id = "enamine_us-2021-smiles" @@ -69,7 +67,7 @@ def get_args():# nBits=args.nbits, output_embedding=args.outputembedding) except Exception as e: - logger.exception(exc_info=e) + logger.exception(e,exc_info=e) continue states.append(state) steps.append(step) From 0e533ffd060ea78b34a822808cbea2978a4b84fb Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 21 Sep 2022 14:55:04 -0400 Subject: [PATCH 169/302] fix: direct import, not via 3rd file --- src/syn_net/utils/prep_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index d82023c2..c6789fbb 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -7,8 +7,8 @@ from sklearn.preprocessing import OneHotEncoder from syn_net.utils.data_utils import Reaction, SyntheticTree from syn_net.utils.predict_utils import (can_react, get_action_mask, - get_reaction_mask, mol_fp, - ) + get_reaction_mask, ) +from syn_net.encoding.fingerprints import mol_fp from pathlib import Path from rdkit import Chem From bcd39d7cafcda7694608d463d5fff57f64843f2b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 22 Sep 2022 11:13:46 -0400 Subject: [PATCH 170/302] rewrite `organize()` as class --- src/syn_net/data_generation/syntrees.py | 137 +++++++++++++++++++++++- 1 file changed, 135 insertions(+), 2 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index de85531d..03cfe441 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -5,6 +5,7 @@ import numpy as np from rdkit import Chem +from scipy import sparse from tqdm import tqdm from syn_net.config import MAX_PROCESSES @@ -41,6 +42,7 @@ class NoReactionPossibleError(Exception): def __init__(self, message): super().__init__(message) + class MaxDepthError(Exception): """Synthetic Tree has exceeded its maximum depth.""" @@ -192,7 +194,7 @@ def _get_action_mask(self, syntree: SyntheticTree): elif nTrees == 1: canAdd = True canExpand = True - canEnd = True # TODO: When syntree has reached max depth, only allow to end it. + canEnd = True # TODO: When syntree has reached max depth, only allow to end it. elif nTrees == 2: canExpand = True canMerge = any(self._get_rxn_mask(tuple(state))) @@ -280,7 +282,7 @@ def generate(self, max_depth: int = 15, retries: int = 3): if action == "end": break - if i==max_depth-1 and not action == "end": + if i == max_depth - 1 and not action == "end": raise MaxDepthError("Maximum depth {max_depth} exceeded.") logger.debug(f"🙌 SynTree completed.") return syntree @@ -329,3 +331,134 @@ def save_syntreegenerator(syntreegenerator: SynTreeGenerator, file: str) -> None with open(file, "wb") as f: pickle.dump(syntreegenerator, f) + + +# TODO: Move all these encoders to "from syn_net.encoding/" +# TODO: Evaluate if One-Hot-Encoder can be replaced with encoder from sklearn +class OneHotEncoder: + def __init__(self, d: int) -> None: + self.d = d + + def encode(self, ind: int, datatype: np.dtype = np.float64) -> np.ndarray: + """Returns a (1,d)-array with zeros and a 1 at index `ind`.""" + onehot = np.zeros((1, self.d), dtype=datatype) # (1,d) + onehot[0, ind] = 1.0 + return onehot # (1,d) + + +class MorganFingerprintEncoder: + def __init__(self, radius: int, nbits: int) -> None: + self.radius = radius + self.nbits = nbits + + def encode(self, smi: str) -> np.ndarray: + if smi is None: + fp = np.zeros((1, self.nbits)) # (1,d) + else: + mol = Chem.MolFromSmiles(smi) # TODO: sanity check mol here or use datmol? + bv = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, self.radius, self.nbits) + fp = np.empty(self.nbits) + Chem.DataStructs.ConvertToNumpyArray(bv, fp) + fp = fp[None, :] + return fp + + +class IdentityIntEncoder: + def __init__(self) -> None: + pass + + def encode(self, number: int): + return np.atleast_2d(number) + + +class SynTreeFeaturizer: + def __init__(self) -> None: + # Embedders + self.reactant_embedder = MorganFingerprintEncoder(2, 256) + self.mol_embedder = MorganFingerprintEncoder(2, 4096) + self.rxn_embedder = IdentityIntEncoder() + self.action_embedder = IdentityIntEncoder() + + def featurize(self, syntree: SyntheticTree): + """Featurize a synthetic tree at every state. + + Note: + - At each iteration of the syntree growth, an action is chosen + - Every action (except "end") comes with a reaction. + - For every action, we compute: + - a "state" + - a "step", a vector that encompasses all info we need for training the neural nets. + This step is: [action, z_rt1, reaction_id, z_rt2, z_root_mol_1] + """ + + states, steps = [], [] + + target_mol = syntree.root.smiles + z_target_mol = self.mol_embedder.encode(target_mol) + + # Recall: We can have at most 2 sub-trees, each with a root node. + root_mol_1 = None + root_mol_2 = None + for i, action in enumerate(syntree.actions): + + # 1. Encode "state" + z_root_mol_1 = self.mol_embedder.encode(root_mol_1) + z_root_mol_2 = self.mol_embedder.encode(root_mol_2) + state = np.concatenate((z_root_mol_1, z_root_mol_2, z_target_mol), axis=1) # (1,3d) + + # 2. Encode "super"-step + if action == 3: # end + step = np.concatenate( + ( + self.action_embedder.encode(action), + self.reactant_embedder.encode(mol1), + self.rxn_embedder.encode(rxn_node.rxn_id), + self.reactant_embedder.encode(mol2), + self.mol_embedder.encode(mol1), + ), + axis=1, + ) + else: + rxn_node = syntree.reactions[i] + + if len(rxn_node.child) == 1: + mol1 = rxn_node.child[0] + mol2 = None + elif len(rxn_node.child) == 2: + mol1 = rxn_node.child[0] + mol2 = rxn_node.child[1] + else: # TODO: Change `child` is stored in reaction node so we can just unpack via * + raise ValueError() + + step = np.concatenate( + ( + self.action_embedder.encode(action), + self.reactant_embedder.encode(mol1), + self.rxn_embedder.encode(rxn_node.rxn_id), + self.reactant_embedder.encode(mol2), + self.mol_embedder.encode(mol1), + ), + axis=1, + ) + + # 3. Prepare next iteration + if action == 2: # merge + root_mol_1 = rxn_node.parent + root_mol_2 = None + + elif action == 1: # expand + root_mol_1 = rxn_node.parent + + elif action == 0: # add + root_mol_2 = root_mol_1 + root_mol_1 = rxn_node.parent + + # 4. Keep track of data + states.append(state) + steps.append(step) + + # Some housekeeping on dimensions + states = np.atleast_2d(np.asarray(states).squeeze()) + steps = np.atleast_2d(np.asarray(steps).squeeze()) + + return sparse.csc_matrix(states), sparse.csc_matrix(steps) From 3c7dc278a01d466ae74a9e0cb34cbc948d88e79c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:16:49 -0400 Subject: [PATCH 171/302] replace `organize` with new `SynTreeFeaturizer` --- scripts/st2steps.py | 56 ++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index e72e2f61..7e21d4c2 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -6,14 +6,12 @@ import json from scipy import sparse from syn_net.utils.data_utils import SyntheticTreeSet -from syn_net.utils.prep_utils import organize - +from syn_net.data_generation.syntrees import SynTreeFeaturizer import logging logger = logging.getLogger(__file__) -from syn_net.config import DATA_PREPARED_DIR, DATA_FEATURIZED_DIR - -def get_args():# +from syn_net.config import DATA_FEATURIZED_DIR +def get_args(): import argparse parser = argparse.ArgumentParser() parser.add_argument("-e", "--targetembedding", type=str, default='fp', @@ -28,50 +26,55 @@ def get_args():# help="Choose from ['train', 'valid', 'test']") parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], help="Choose from ['hb', 'pis']") + # File I/O + parser.add_argument( + "--input-file", + type=str, + default="data/pre-process/split/synthetic-trees-valid.json.gz", # TODO think about filename vs dir + help="Input file for the splitted generated synthetic trees (*.json.gz)", + ) + parser.add_argument( + "--output-dir", + type=str, + default=str(Path(DATA_FEATURIZED_DIR)) + "debug-newversion", + help="Output directory for the splitted synthetic trees (*.json.gz)", + ) return parser.parse_args() +def _extract_dataset(filename: str) -> str: + stem = Path(filename).stem.split(".")[0] + return stem.split("-")[-1] # TODO: avoid hard coding + if __name__ == '__main__': logger.info("Start.") # Parse input args args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + dataset_type = _extract_dataset(args.input_file) - # Parse & set inputs - reaction_template_id = args.rxn_template - building_blocks_id = "enamine_us-2021-smiles" - dataset_type = args.datasettype - embedding = args.targetembedding - assert dataset_type is not None, "Must specify which dataset to use." - - # Load synthetic trees subset {train,valid,test} - file = f'{DATA_PREPARED_DIR}/synthetic-trees-{dataset_type}.json.gz' - st_set = SyntheticTreeSet() - st_set.load(file) - logger.info("Number of synthetic trees: {len(st_set.sts}") + st_set = SyntheticTreeSet().load(args.input_file) + logger.info(f"Number of synthetic trees: {len(st_set.sts)}") data: list = st_set.sts del st_set - # Set output directory - save_dir = Path(DATA_FEATURIZED_DIR) / f'{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/' - Path(save_dir).mkdir(parents=1,exist_ok=1) - # Start splitting synthetic trees in states and steps states = [] steps = [] - + stf = SynTreeFeaturizer() for st in tqdm(data): try: - state, step = organize(st, target_embedding=embedding, - radius=args.radius, - nBits=args.nbits, - output_embedding=args.outputembedding) + state, step = stf.featurize(st) except Exception as e: logger.exception(e,exc_info=e) continue states.append(state) steps.append(step) + # Set output directory + save_dir = Path(args.output_dir) / "hb_fp_2_4096_fp_256" # TODO: Save info as json in dir? + Path(save_dir).mkdir(parents=1,exist_ok=1) + dataset_type = _extract_dataset(args.input_file) # Finally, save. logger.info(f"Saving to {save_dir}") @@ -81,4 +84,5 @@ def get_args():# sparse.save_npz(save_dir / f"steps_{dataset_type}.npz", steps) logger.info("Save successful.") + logger.info("Completed.") From 5478827abdcfed36d8b510f566e0eaec10025d99 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:17:41 -0400 Subject: [PATCH 172/302] format --- scripts/st2steps.py | 70 ++++++++++++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 23 deletions(-) diff --git a/scripts/st2steps.py b/scripts/st2steps.py index 7e21d4c2..f52c1722 100644 --- a/scripts/st2steps.py +++ b/scripts/st2steps.py @@ -1,36 +1,59 @@ """ Splits a synthetic tree into states and steps. """ -from pathlib import Path -from tqdm import tqdm import json +import logging +from pathlib import Path + from scipy import sparse -from syn_net.utils.data_utils import SyntheticTreeSet +from tqdm import tqdm + from syn_net.data_generation.syntrees import SynTreeFeaturizer -import logging +from syn_net.utils.data_utils import SyntheticTreeSet + logger = logging.getLogger(__file__) from syn_net.config import DATA_FEATURIZED_DIR + + def get_args(): import argparse + parser = argparse.ArgumentParser() - parser.add_argument("-e", "--targetembedding", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-o", "--outputembedding", type=str, default='fp_256', - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") - parser.add_argument("-r", "--radius", type=int, default=2, - help="Radius for Morgan Fingerprint") - parser.add_argument("-b", "--nbits", type=int, default=4096, - help="Number of Bits for Morgan Fingerprint") - parser.add_argument("-d", "--datasettype", type=str, choices=["train","valid","test"], - help="Choose from ['train', 'valid', 'test']") - parser.add_argument("-rxn", "--rxn_template", type=str, default='hb', choices=["hb","pis"], - help="Choose from ['hb', 'pis']") + parser.add_argument( + "-e", "--targetembedding", type=str, default="fp", help="Choose from ['fp', 'gin']" + ) + parser.add_argument( + "-o", + "--outputembedding", + type=str, + default="fp_256", + help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", + ) + parser.add_argument("-r", "--radius", type=int, default=2, help="Radius for Morgan Fingerprint") + parser.add_argument( + "-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint" + ) + parser.add_argument( + "-d", + "--datasettype", + type=str, + choices=["train", "valid", "test"], + help="Choose from ['train', 'valid', 'test']", + ) + parser.add_argument( + "-rxn", + "--rxn_template", + type=str, + default="hb", + choices=["hb", "pis"], + help="Choose from ['hb', 'pis']", + ) # File I/O parser.add_argument( "--input-file", type=str, - default="data/pre-process/split/synthetic-trees-valid.json.gz", # TODO think about filename vs dir + default="data/pre-process/split/synthetic-trees-valid.json.gz", # TODO think about filename vs dir help="Input file for the splitted generated synthetic trees (*.json.gz)", ) parser.add_argument( @@ -41,11 +64,13 @@ def get_args(): ) return parser.parse_args() + def _extract_dataset(filename: str) -> str: stem = Path(filename).stem.split(".")[0] - return stem.split("-")[-1] # TODO: avoid hard coding + return stem.split("-")[-1] # TODO: avoid hard coding + -if __name__ == '__main__': +if __name__ == "__main__": logger.info("Start.") # Parse input args @@ -66,14 +91,14 @@ def _extract_dataset(filename: str) -> str: try: state, step = stf.featurize(st) except Exception as e: - logger.exception(e,exc_info=e) + logger.exception(e, exc_info=e) continue states.append(state) steps.append(step) # Set output directory - save_dir = Path(args.output_dir) / "hb_fp_2_4096_fp_256" # TODO: Save info as json in dir? - Path(save_dir).mkdir(parents=1,exist_ok=1) + save_dir = Path(args.output_dir) / "hb_fp_2_4096_fp_256" # TODO: Save info as json in dir? + Path(save_dir).mkdir(parents=1, exist_ok=1) dataset_type = _extract_dataset(args.input_file) # Finally, save. @@ -85,4 +110,3 @@ def _extract_dataset(filename: str) -> str: logger.info("Save successful.") logger.info("Completed.") - From 26285ce5f5fb16f0153fc23c05ddb903c472580d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:19:29 -0400 Subject: [PATCH 173/302] rename --- scripts/{st2steps.py => 06-featurize-syntrees.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{st2steps.py => 06-featurize-syntrees.py} (100%) diff --git a/scripts/st2steps.py b/scripts/06-featurize-syntrees.py similarity index 100% rename from scripts/st2steps.py rename to scripts/06-featurize-syntrees.py From c3b123510301c027379acbd711f3cc747599e2d1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:28:55 -0400 Subject: [PATCH 174/302] update instructions --- INSTRUCTIONS.md | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 941662e8..2b58f217 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -91,24 +91,28 @@ Let's start. --output-dir "data/pre-process/split" ``` -4. Featurization and +4. Featurization > :bulb: All following steps depend on the representations for the data. Hence, you have to specify the parameters for the representations as input argument for most of the scripts so that it can operate on the right data. - We organize each *synthetic tree* into states and actions. - That is, we break down each tree to the action at each iteration ("Add", "Expand", "Extend", "End") and a corresponding "super state" vector. - We call it "super state" here, as it contains all states for all networks. - However, recall that the input state vector is different for each network. + We featurize each *synthetic tree*. + That is, we break down each tree to each iteration step ("Add", "Expand", "Extend", "End") and featurize it. + This results in a "state" vector and a a corresponding "super step" vector. + We call it "super step" here, as it contains all (featurized) data for all networks. ```bash - python scripts/06-st2steps.py + python scripts/06-featurize-syntrees.py \ + --input-file "data/pre-process/split/synthetic-trees-train.json.gz" # or {train,valid,test} + --output-dir "data/featurized" ``` + This script will load the `input-file`, featurize it, and it in `/states_{train,valid,test}.np` and `/steps_{train,valid,test}.np`. + 5. Split features Up to this point, we worked with a (featurized) *synthetic tree* as a whole, - now we split it up to into "consumable" input data for each of the four networks. - This includes picking the right state(s) from the "super state" vector from the previous step. + now we split it up to into "consumable" input/output data for each of the four networks. + This includes picking the right featurized data from the "super step" vector from the previous step. ```bash python scripts/07-prepare_data.py From 4c40b560f0af8341a2742251baa6a30ebfa06169 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:31:45 -0400 Subject: [PATCH 175/302] fix numbering --- INSTRUCTIONS.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 2b58f217..a5c5b808 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -79,7 +79,7 @@ Let's start. Each *synthetic tree* is serializable and so we save all trees in a compressed `.json` file. -3. Split *synthetic trees* into train,valid,test-data +5. Split *synthetic trees* into train,valid,test-data We load the `.json`-file with all *synthetic trees* and straightforward split it into three files: `{train,test,valid}.json`. @@ -91,7 +91,7 @@ Let's start. --output-dir "data/pre-process/split" ``` -4. Featurization +6. Featurization > :bulb: All following steps depend on the representations for the data. Hence, you have to specify the parameters for the representations as input argument for most of the scripts so that it can operate on the right data. @@ -108,7 +108,7 @@ Let's start. This script will load the `input-file`, featurize it, and it in `/states_{train,valid,test}.np` and `/steps_{train,valid,test}.np`. -5. Split features +7. Split features Up to this point, we worked with a (featurized) *synthetic tree* as a whole, now we split it up to into "consumable" input/output data for each of the four networks. @@ -118,7 +118,7 @@ Let's start. python scripts/07-prepare_data.py ``` -6. Train the networks +8. Train the networks Finally, we can train each of the four networks in `src/syn_net/models/` separately: From a03b852727f4b6629e259a6769807936081bbb25 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:32:45 -0400 Subject: [PATCH 176/302] delete unused cli args --- scripts/06-featurize-syntrees.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/scripts/06-featurize-syntrees.py b/scripts/06-featurize-syntrees.py index f52c1722..9f1a93ce 100644 --- a/scripts/06-featurize-syntrees.py +++ b/scripts/06-featurize-syntrees.py @@ -20,35 +20,6 @@ def get_args(): import argparse parser = argparse.ArgumentParser() - parser.add_argument( - "-e", "--targetembedding", type=str, default="fp", help="Choose from ['fp', 'gin']" - ) - parser.add_argument( - "-o", - "--outputembedding", - type=str, - default="fp_256", - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", - ) - parser.add_argument("-r", "--radius", type=int, default=2, help="Radius for Morgan Fingerprint") - parser.add_argument( - "-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint" - ) - parser.add_argument( - "-d", - "--datasettype", - type=str, - choices=["train", "valid", "test"], - help="Choose from ['train', 'valid', 'test']", - ) - parser.add_argument( - "-rxn", - "--rxn_template", - type=str, - default="hb", - choices=["hb", "pis"], - help="Choose from ['hb', 'pis']", - ) # File I/O parser.add_argument( "--input-file", From c575b5b739f6ebf0c6c3f3ab666f38a4bc568fa9 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:32:55 -0400 Subject: [PATCH 177/302] add comments --- src/syn_net/utils/prep_utils.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index c6789fbb..b32133c5 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -243,12 +243,8 @@ def synthetic_tree_generator( return tree, action -def prep_data(main_dir, num_rxn, out_dim, datasets=None): - """ - Loads the states and steps from preprocessed *.npz files and saves data - specific to the Action, Reactant 1, Reaction, and Reactant 2 networks in - their own *.npz files. - +def prep_data(main_dir: str, num_rxn: int, out_dim: int, datasets=None): + """Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks. Args: main_dir (str): The path to the directory containing the *.npz files. num_rxn (int): Number of reactions in the dataset. @@ -270,17 +266,24 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): states = sparse.csc_matrix(sparse.vstack(states_list)) steps = sparse.csc_matrix(sparse.vstack(steps_list)) - # extract Action data + # Extract data for each network... + + # ... action data + # X: [z_state] + # y: [action id] (int) X = states y = steps[:, 0] sparse.save_npz(main_dir / f'X_act_{dataset}.npz', X) sparse.save_npz(main_dir / f'y_act_{dataset}.npz', y) + print(f' saved data for "Action"') + # Delete all data where tree was ended (i.e. tree expansion did not trigger reaction) states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) - print(f' saved data for "Action"') - # extract Reaction data + # ... reaction data + # X: [state, z_reactant_1] + # y: [reaction_id] (int) X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) y = steps[:, out_dim + 1] sparse.save_npz(main_dir / f'X_rxn_{dataset}.npz', X) @@ -292,9 +295,10 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): enc = OneHotEncoder(handle_unknown='ignore') enc.fit([[i] for i in range(num_rxn)]) - # import ipdb; ipdb.set_trace(context=9) - # extract Reactant 2 data + # ... reactant 2 data + # X: [z_state, z_reactant_1, reaction_id] + # y: [z'_reactant_2] X = sparse.hstack( [states, steps[:, (2 * out_dim + 2):], @@ -308,7 +312,9 @@ def prep_data(main_dir, num_rxn, out_dim, datasets=None): states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).reshape(-1, )]) steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).reshape(-1, )]) - # extract Reactant 1 data + # ... reactant 1 data + # X: [z_state] + # y: [z'_reactant_1] X = states y = steps[:, 1: (out_dim+1)] sparse.save_npz(main_dir / f'X_rt1_{dataset}.npz', X) From 7ef7cbc155e5aba1384e9c6adf2aca72447cc5e2 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:43:06 -0400 Subject: [PATCH 178/302] move script to `scripts/` --- {src/syn_net/models => scripts}/prepare_data.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename {src/syn_net/models => scripts}/prepare_data.py (100%) diff --git a/src/syn_net/models/prepare_data.py b/scripts/prepare_data.py similarity index 100% rename from src/syn_net/models/prepare_data.py rename to scripts/prepare_data.py From 1036ffbd4f7f390d2975bb6dcc9a61c9e2e5fead Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:53:54 -0400 Subject: [PATCH 179/302] move argparse stuff to fct --- scripts/prepare_data.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index d34cdb57..80caa694 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -11,10 +11,8 @@ logger = logging.getLogger(__file__) -if __name__ == "__main__": - +def get_args(): import argparse - parser = argparse.ArgumentParser() parser.add_argument( "-e", "--targetembedding", type=str, default="fp", help="Choose from ['fp', 'gin']" @@ -38,8 +36,16 @@ choices=["hb", "pis"], help="Choose from ['hb', 'pis']", ) + return parser.parse_args() + +if __name__ == "__main__": + logger.info("Start.") + # Parse input args + args = get_args() + logger.info(f"Arguments: {vars(args)}") + import argparse + - args = parser.parse_args() reaction_template_id = args.rxn_template embedding = args.targetembedding output_emb = args.outputembedding From d14f88f139836c921eeaefaba9650879fe9668f0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:55:01 -0400 Subject: [PATCH 180/302] delete unused cli args, dumb-down for now --- scripts/prepare_data.py | 54 +++++++---------------------------------- 1 file changed, 9 insertions(+), 45 deletions(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index 80caa694..d02c894a 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -15,26 +15,10 @@ def get_args(): import argparse parser = argparse.ArgumentParser() parser.add_argument( - "-e", "--targetembedding", type=str, default="fp", help="Choose from ['fp', 'gin']" - ) - parser.add_argument( - "-o", - "--outputembedding", - type=str, - default="fp_256", - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", - ) - parser.add_argument("-r", "--radius", type=int, default=2, help="Radius for Morgan Fingerprint") - parser.add_argument( - "-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint" - ) - parser.add_argument( - "-rxn", - "--rxn_template", + "--input-dir", type=str, - default="hb", - choices=["hb", "pis"], - help="Choose from ['hb', 'pis']", + default=str(Path(DATA_FEATURIZED_DIR)) + "hb_fp_2_4096_fp_256", # TODO: dont hardcode + help="Input directory for the featurized synthetic trees (with {train,valid,test}-data).", ) return parser.parse_args() @@ -43,33 +27,13 @@ def get_args(): # Parse input args args = get_args() logger.info(f"Arguments: {vars(args)}") - import argparse - - - reaction_template_id = args.rxn_template - embedding = args.targetembedding - output_emb = args.outputembedding - main_dir = ( - Path(DATA_FEATURIZED_DIR) - / f"{reaction_template_id}_{embedding}_{args.radius}_{args.nbits}_{args.outputembedding}/" - ) # must match with dir in `st2steps.py` - if reaction_template_id == "hb": - num_rxn = 91 - elif reaction_template_id == "pis": - num_rxn = 4700 + featurized_data_dir = args.input_dir - # Get dimension of output embedding - OUTPUT_EMBEDDINGS = { - "gin": 300, - "fp_4096": 4096, - "fp_256": 256, - "rdkit2d": 200, - } - out_dim = OUTPUT_EMBEDDINGS[output_emb] - - logger.info("Start splitting data.") # Split datasets for each MLP - prep_data(main_dir, num_rxn, out_dim) + logger.info("Start splitting data.") + num_rxn = 91 # Auxiliary var for indexing TODO: Dont hardcode + out_dim = 256 # Auxiliary var for indexing TODO: Dont hardcode + prep_data(featurized_data_dir, num_rxn, out_dim) - logger.info("Successfully splitted data.") + logger.info(f"Completed.") From 80f3901ec3688aca65579eb137bdbf4c1462f914 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:55:11 -0400 Subject: [PATCH 181/302] delete debug stuff --- scripts/06-featurize-syntrees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/06-featurize-syntrees.py b/scripts/06-featurize-syntrees.py index 9f1a93ce..7adf58ad 100644 --- a/scripts/06-featurize-syntrees.py +++ b/scripts/06-featurize-syntrees.py @@ -30,7 +30,7 @@ def get_args(): parser.add_argument( "--output-dir", type=str, - default=str(Path(DATA_FEATURIZED_DIR)) + "debug-newversion", + default=str(Path(DATA_FEATURIZED_DIR)), help="Output directory for the splitted synthetic trees (*.json.gz)", ) return parser.parse_args() From ce1e72a902dd36df8dafde063efebdc05f3b8402 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 11:58:54 -0400 Subject: [PATCH 182/302] fix path --- scripts/prepare_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py index d02c894a..c16929a9 100644 --- a/scripts/prepare_data.py +++ b/scripts/prepare_data.py @@ -17,7 +17,7 @@ def get_args(): parser.add_argument( "--input-dir", type=str, - default=str(Path(DATA_FEATURIZED_DIR)) + "hb_fp_2_4096_fp_256", # TODO: dont hardcode + default=str(Path(DATA_FEATURIZED_DIR)) + "/hb_fp_2_4096_fp_256", # TODO: dont hardcode help="Input directory for the featurized synthetic trees (with {train,valid,test}-data).", ) return parser.parse_args() From f128b55dbdb9f28a2bd305eb06b2dec443880e37 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 12:02:46 -0400 Subject: [PATCH 183/302] rename --- scripts/{prepare_data.py => 08-split-data-for-networks.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{prepare_data.py => 08-split-data-for-networks.py} (100%) diff --git a/scripts/prepare_data.py b/scripts/08-split-data-for-networks.py similarity index 100% rename from scripts/prepare_data.py rename to scripts/08-split-data-for-networks.py From 72b2214fb25310e85a78367bef29a890eff982f4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 12:03:08 -0400 Subject: [PATCH 184/302] update instructions --- INSTRUCTIONS.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index a5c5b808..22c21a40 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -106,7 +106,9 @@ Let's start. --output-dir "data/featurized" ``` - This script will load the `input-file`, featurize it, and it in `/states_{train,valid,test}.np` and `/steps_{train,valid,test}.np`. + This script will load the `input-file`, featurize it, and it in + - `/hb_fp_2_4096_fp_256/states_{train,valid,test}.np` and + - `/hb_fp_2_4096_fp_256/steps_{train,valid,test}.np`. 7. Split features @@ -115,7 +117,8 @@ Let's start. This includes picking the right featurized data from the "super step" vector from the previous step. ```bash - python scripts/07-prepare_data.py + python scripts/08-split-data-for-networks.py \ + --input-dir "data/featurized/hb_fp_2_4096_fp_256" ``` 8. Train the networks From 7ef0e9bef73b565943a1f1d0928395e23c98753f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 12:14:25 -0400 Subject: [PATCH 185/302] document how to visualize a syntree --- INSTRUCTIONS.md | 14 +++++++++++++- src/syn_net/visualize/visualizer.py | 12 ++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 22c21a40..ef96127e 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -139,7 +139,19 @@ Please refer to the [README.md](./README.md) for inference instructions. ### Visualizing trees -To be added. +To visualize trees, there is a hacky script that represents *Synthetic Trees* as [mermaid](https://github.com/mermaid-js/mermaid) diagrams. + +To demo it: + +```bash +python src/syn_net/visualize/visualizer.py +``` + +Still to be implemented: i) target molecule, ii) "end" action + +To render the markdown file incl. the diagram directly in VS Code, install the extension [vscode-markdown-mermaid](https://github.com/mjbvz/vscode-markdown-mermaid) and use the built-in markdown preview. + +*Info*: If the images of the molecules do not load, edit + save the markdown file anywhere. For example add and delete a character with the preview open. Not sure why this happens. ### Mean reciprocal rank diff --git a/src/syn_net/visualize/visualizer.py b/src/syn_net/visualize/visualizer.py index c629c741..d3135718 100644 --- a/src/syn_net/visualize/visualizer.py +++ b/src/syn_net/visualize/visualizer.py @@ -142,11 +142,12 @@ def __printer(): return text -def main(): +def demo(): """Demo syntree visualisation""" # 1. Load syntree import json - with open("tests/assets/syntree-small.json","rt") as f: + infile = "tests/assets/syntree-small.json" + with open(infile,"rt") as f: data = json.load(f) st = SyntheticTree() @@ -156,7 +157,7 @@ def main(): from syn_net.visualize.visualizer import SynTreeVisualizer from syn_net.visualize.writers import SynTreeWriter - outpath = Path("./0-figures/syntrees/generation/st") + outpath = Path("./figures/syntrees/generation/st") outpath.mkdir(parents=True, exist_ok=True) # 2. Plot & Write mermaid markup diagram @@ -165,7 +166,10 @@ def main(): # 3. Write everything to a markdown doc outfile = stviz.path / "syntree.md" SynTreeWriter().write(mermaid_txt).to_file(outfile) + print(f"Generated markdown file.") + print(f" Input file:", infile) + print(f" Output file:", outfile) return None if __name__ == "__main__": - main() + demo() From 201657f7bf0f68c6a8aa7025305c9c0a79e969a8 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 12:51:37 -0400 Subject: [PATCH 186/302] rename (fix order) --- ...8-split-data-for-networks.py => 07-split-data-for-networks.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename scripts/{08-split-data-for-networks.py => 07-split-data-for-networks.py} (100%) diff --git a/scripts/08-split-data-for-networks.py b/scripts/07-split-data-for-networks.py similarity index 100% rename from scripts/08-split-data-for-networks.py rename to scripts/07-split-data-for-networks.py From 45c8567250f59fd911282f256857acf3cd202280 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 12:57:51 -0400 Subject: [PATCH 187/302] rename (fix order) --- INSTRUCTIONS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index ef96127e..beedcfef 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -117,7 +117,7 @@ Let's start. This includes picking the right featurized data from the "super step" vector from the previous step. ```bash - python scripts/08-split-data-for-networks.py \ + python scripts/07-split-data-for-networks.py \ --input-dir "data/featurized/hb_fp_2_4096_fp_256" ``` From bec4eacb3e77a7587d7a9e0d4f409459f6cd4ba7 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 13:08:38 -0400 Subject: [PATCH 188/302] update readme --- README.md | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1d1063a7..e1bb4946 100644 --- a/README.md +++ b/README.md @@ -115,14 +115,8 @@ Use `python some_script.py --help` or check the source code to see the instructi ### Prerequisites -In addition to the necessary data, see [Data](#data), we pre-compute an embedding of the building blocks. Please double-check the filename of your building blocks. - -```bash -python scripts/compute_embedding_mp.py \ - --feature "fp_256" \ - --rxn-template "hb" \ - --ncpu 10 -``` +In addition to the necessary data, we will need to pre-compute an embedding of the building blocks. +To do so, please follow steps 0-2 from the [INSTRUCTIONS.md](INSTRUCTIONS.md). #### Synthesis Planning @@ -139,7 +133,7 @@ This script will feed a list of ten randomly selected molecules from the validat The decoded results, i.e. the predicted synthesis trees, are saved to `DATA_RESULT_DIR`. (Paths are defined in [src/syn_net/config.py](src/syn_net/config.py).) -*Note*: To do synthesis planning, you will need a list of target molecules, building blocks and compute their embedding. As mentioned, we cannot share the building blocks from enamine.net and you will have to request access yourselfs. +*Note*: To do synthesis planning, you will need a list of target molecules (provided), building blocks (need to download) and embeddings (need to compute). #### Synthesizable Molecular Design From 78c520f1e49e7ccb377978887a7bae517dd0c432 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 13:17:34 -0400 Subject: [PATCH 189/302] clean up --- scripts/predict_multireactant_mp.py | 57 ++++++++--------------------- 1 file changed, 16 insertions(+), 41 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 0bacfc7a..335253b5 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -3,7 +3,7 @@ """ import multiprocessing as mp from pathlib import Path -from typing import Union +from typing import Union, Tuple import logging logger = logging.getLogger(__name__) @@ -18,8 +18,9 @@ DATA_RESULT_DIR, ) from syn_net.models.chkpt_loader import load_modules_from_checkpoint -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet +from syn_net.utils.data_utils import SyntheticTreeSet, SyntheticTree from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_multireactant +from syn_net.data_generation.preprocessing import ReactionTemplateFileHandler, BuildingBlockFileHandler Path(DATA_RESULT_DIR).mkdir(exist_ok=True) @@ -51,22 +52,10 @@ def _fetch_data(name: str) -> list[str]: return smis_query -def _fetch_reaction_templates(file: str): - # Load reaction templates - rxn_set = ReactionSet().load(file) - return rxn_set.rxns - - def _fetch_building_blocks_embeddings(file: str): """Load the purchasable building block embeddings.""" return np.load(file) - -def _fetch_building_blocks(file: str): - """Load the building blocks""" - return pd.read_csv(file, compression="gzip")["SMILES"].tolist() - - def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils.py """Find checkpoint with lowest val_loss. @@ -97,28 +86,17 @@ def _load_pretrained_model(path_to_checkpoints: list[Path]): path_to_rt1=path_to_rt1, path_to_rxn=path_to_rxn, path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, + featurize=args.featurize, + rxn_template=args.rxn_template, out_dim=out_dim, nbits=nbits, - ncpu=ncpu, + ncpu=args.ncpu, ) return act_net, rt1_net, rxn_net, rt2_net -def func(smiles: str): - """ - Generates the synthetic tree for the input molecular embedding. - - Args: - smi (str): SMILES string corresponding to the molecule to decode. - - Returns: - smi (str): SMILES for the final chemical node in the tree. - similarity (float): Similarity measure between the final chemical node - and the input molecule. - tree (SyntheticTree): The generated synthetic tree. - """ +def func(smiles: str) -> Tuple[str,float,SyntheticTree]: + """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) try: smi, similarity, tree, action = synthetic_tree_decoder_multireactant( @@ -132,7 +110,7 @@ def func(smiles: str): rxn_net=rxn_net, reactant2_net=rt2_net, bb_emb=bb_emb, - rxn_template=rxn_template, + rxn_template=args.rxn_template, n_bits=nbits, beam_width=3, max_step=15, @@ -188,12 +166,8 @@ def get_args(): nbits = args.nbits out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits - rxn_template = args.rxn_template building_blocks_id = "enamine_us-2021-smiles" - featurize = args.featurize - radius = args.radius - ncpu = args.ncpu - param_dir = f"{rxn_template}_{featurize}_{radius}_{nbits}_{out_dim}" + param_dir = f"{args.rxn_template}_{args.featurize}_{args.radius}_{nbits}_{out_dim}" # Load data ... logger.info("Stat loading data...") @@ -203,18 +177,19 @@ def get_args(): smiles_queries = smiles_queries[:args.num] # ... building blocks - file = Path(DATA_PREPROCESS_DIR) / f"{rxn_template}-{building_blocks_id}-matched.csv.gz" - building_blocks = _fetch_building_blocks(file) + file = Path(DATA_PREPROCESS_DIR) / f"{args.rxn_template}-{building_blocks_id}-matched.csv.gz" + + building_blocks = BuildingBlockFileHandler().load(file) building_blocks_dict = { block: i for i, block in enumerate(building_blocks) } # dict is used as lookup table for 2nd reactant during inference # ... reaction templates - file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{rxn_template}_{building_blocks_id}.json.gz" - rxns = _fetch_reaction_templates(file) + file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{args.rxn_template}_{building_blocks_id}.json.gz" + rxns = ReactionTemplateFileHandler().load(file) # ... building block embedding - file = Path(DATA_EMBEDDINGS_DIR) / f"{rxn_template}-{building_blocks_id}-embeddings.npy" + file = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{building_blocks_id}-embeddings.npy" bb_emb = _fetch_building_blocks_embeddings(file) logger.info("...loading data completed.") From b399f69adf74c417925cb9a66fb2067877b4951c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 13:19:57 -0400 Subject: [PATCH 190/302] clean up --- scripts/predict_multireactant_mp.py | 26 +++++++++----------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 335253b5..7dd51142 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -1,28 +1,25 @@ """ Generate synthetic trees for a set of specified query molecules. Multiprocessing. """ +import logging import multiprocessing as mp from pathlib import Path -from typing import Union, Tuple -import logging +from typing import Tuple, Union logger = logging.getLogger(__name__) import numpy as np import pandas as pd -from syn_net.config import ( - CHECKPOINTS_DIR, - DATA_EMBEDDINGS_DIR, - DATA_PREPARED_DIR, - DATA_PREPROCESS_DIR, - DATA_RESULT_DIR, -) +from syn_net.config import (CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, DATA_PREPARED_DIR, + DATA_PREPROCESS_DIR, DATA_RESULT_DIR) +from syn_net.data_generation.preprocessing import (BuildingBlockFileHandler, + ReactionTemplateFileHandler) from syn_net.models.chkpt_loader import load_modules_from_checkpoint -from syn_net.utils.data_utils import SyntheticTreeSet, SyntheticTree +from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_multireactant -from syn_net.data_generation.preprocessing import ReactionTemplateFileHandler, BuildingBlockFileHandler Path(DATA_RESULT_DIR).mkdir(exist_ok=True) +from syn_net.MolEmbedder import MolEmbedder def _fetch_data_chembl(name: str) -> list[str]: @@ -51,11 +48,6 @@ def _fetch_data(name: str) -> list[str]: smis_query = _fetch_data_from_file(name) return smis_query - -def _fetch_building_blocks_embeddings(file: str): - """Load the purchasable building block embeddings.""" - return np.load(file) - def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils.py """Find checkpoint with lowest val_loss. @@ -190,7 +182,7 @@ def get_args(): # ... building block embedding file = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{building_blocks_id}-embeddings.npy" - bb_emb = _fetch_building_blocks_embeddings(file) + bb_emb = MolEmbedder.load(file) logger.info("...loading data completed.") # ... models From c06d86eeaf23c161832deb5ef3cebefda9fa18eb Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 14:31:37 -0400 Subject: [PATCH 191/302] add comments --- src/syn_net/utils/predict_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 1054e1ec..7039d0b2 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -545,20 +545,25 @@ def synthetic_tree_decoder_multireactant( rt1_index=i, ) - similarities_ = np.array( + # Find the chemical in this tree that is most similar to the target. + # Note: This does not have to be the final root mol, but any, as we can truncate tree to our liking. + similarities_in_tree = np.array( tanimoto_similarity(z_target, [node.smiles for node in tree.chemicals]) ) - max_simi_idx = np.where(similarities_ == np.max(similarities_))[0][0] + max_similar_idx = np.argmax(similarities_in_tree) + max_similarity = similarities_in_tree[max_similar_idx] - similarities.append(np.max(similarities_)) - smiles.append(tree.chemicals[max_simi_idx].smiles) + similarities.append(max_similarity) + # Keep track of generated trees + smiles.append(tree.chemicals[max_similar_idx].smiles) trees.append(tree) acts.append(act) - max_simi_idx = np.where(similarities == np.max(similarities))[0][0] - similarity = similarities[max_simi_idx] - tree = trees[max_simi_idx] - smi = smiles[max_simi_idx] - act = acts[max_simi_idx] + # Identify most similar among all trees + max_similar_idx = np.argmax(similarities) + similarity = similarities[max_similar_idx] + tree = trees[max_similar_idx] + smi = smiles[max_similar_idx] + act = acts[max_similar_idx] return smi, similarity, tree, act From 11694d84c550b6bf0f9a6db6cb1f45da6bf1b220 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 14:49:07 -0400 Subject: [PATCH 192/302] clean up, rename, and align function inputs --- scripts/predict_multireactant_mp.py | 4 +- src/syn_net/utils/predict_utils.py | 63 +++++++---------------------- tests/test_Predict.py | 4 +- 3 files changed, 19 insertions(+), 52 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 7dd51142..cc999585 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -16,7 +16,7 @@ ReactionTemplateFileHandler) from syn_net.models.chkpt_loader import load_modules_from_checkpoint from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet -from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_multireactant +from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_beam_search Path(DATA_RESULT_DIR).mkdir(exist_ok=True) from syn_net.MolEmbedder import MolEmbedder @@ -91,7 +91,7 @@ def func(smiles: str) -> Tuple[str,float,SyntheticTree]: """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) try: - smi, similarity, tree, action = synthetic_tree_decoder_multireactant( + smi, similarity, tree, action = synthetic_tree_decoder_beam_search( z_target=emb, building_blocks=building_blocks, bb_dict=building_blocks_dict, diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 7039d0b2..7393cbc0 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -482,66 +482,31 @@ def synthetic_tree_decoder_rt1( return tree, act -def synthetic_tree_decoder_multireactant( - z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, +def synthetic_tree_decoder_beam_search( beam_width: int = 3, - max_step: int = 15, -): + **kwargs +) -> Tuple[str, float, SyntheticTree, int]: """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. + Wrapper around `synthetic_tree_decoder_rt1` with a beam search. + Selects the k-th first reactant in the k-NN search and expands in a greedy manner. Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. beam_width (int): The beam width to use for Reactant 1 search. Defaults to 3. - max_step (int, optional): Maximum number of steps to include in the synthetic tree + kwargs: Identical to wrapped function. Returns: tree (SyntheticTree): The final synthetic tree act (int): The final action (to know if the tree was "properly" terminated) """ - trees = [] - smiles = [] - similarities = [] - acts = [] + z_target = kwargs["z_target"] + trees: list[SyntheticTree] = [] + smiles: list[str] = [] + similarities: list[float] = [] + acts: list[int] = [] for i in range(beam_width): tree, act = synthetic_tree_decoder_rt1( - z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=reaction_templates, - mol_embedder=mol_embedder, - action_net=action_net, - reactant1_net=reactant1_net, - rxn_net=rxn_net, - reactant2_net=reactant2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=n_bits, - max_step=max_step, + **kwargs, rt1_index=i, ) @@ -553,8 +518,10 @@ def synthetic_tree_decoder_multireactant( max_similar_idx = np.argmax(similarities_in_tree) max_similarity = similarities_in_tree[max_similar_idx] + # Keep track of max similarities (across syntrees) similarities.append(max_similarity) - # Keep track of generated trees + + # Keep track of generated syntrees smiles.append(tree.chemicals[max_similar_idx].smiles) trees.append(tree) acts.append(act) diff --git a/tests/test_Predict.py b/tests/test_Predict.py index 1aaa564a..4a0c2f0f 100644 --- a/tests/test_Predict.py +++ b/tests/test_Predict.py @@ -8,7 +8,7 @@ import pandas as pd from syn_net.utils.predict_utils import ( - synthetic_tree_decoder_multireactant, + synthetic_tree_decoder_beam_search, mol_fp, ) from syn_net.utils.data_utils import SyntheticTreeSet, ReactionSet @@ -83,7 +83,7 @@ def test_predict(self): trees = [] for smi in smis_query: emb = mol_fp(smi) - smi, similarity, tree, action = synthetic_tree_decoder_multireactant( + smi, similarity, tree, action = synthetic_tree_decoder_beam_search( z_target=emb, building_blocks=building_blocks, bb_dict=bb_dict, From 4cc66e654a158bff03ff8d6af48b7970e000d152 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 14:49:47 -0400 Subject: [PATCH 193/302] bug fix --- scripts/predict_multireactant_mp.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index cc999585..6f09e4d9 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -5,6 +5,7 @@ import multiprocessing as mp from pathlib import Path from typing import Tuple, Union +import json logger = logging.getLogger(__name__) import numpy as np @@ -15,7 +16,7 @@ from syn_net.data_generation.preprocessing import (BuildingBlockFileHandler, ReactionTemplateFileHandler) from syn_net.models.chkpt_loader import load_modules_from_checkpoint -from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet +from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet, ReactionSet from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_beam_search Path(DATA_RESULT_DIR).mkdir(exist_ok=True) @@ -53,7 +54,7 @@ def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils Poor man's regex: somepath/act/ckpts.epoch=70-val_loss=0.03.ckpt - ^^^^--extract this as float + ^^^^--extract this as float """ ckpts = Path(path).rglob("*.ckpt") best_model_ckpt = None @@ -132,7 +133,7 @@ def get_args(): parser.add_argument( "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']" ) - parser.add_argument("--ncpu", type=int, default=32, help="Number of cpus") + parser.add_argument("--ncpu", type=int, default=1, help="Number of cpus") parser.add_argument("-n", "--num", type=int, default=1, help="Number of molecules to predict.") parser.add_argument( "-d", @@ -154,7 +155,7 @@ def get_args(): if __name__ == "__main__": args = get_args() - logger.info(f"Args: {vars(args)}") + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") nbits = args.nbits out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits @@ -170,7 +171,6 @@ def get_args(): # ... building blocks file = Path(DATA_PREPROCESS_DIR) / f"{args.rxn_template}-{building_blocks_id}-matched.csv.gz" - building_blocks = BuildingBlockFileHandler().load(file) building_blocks_dict = { block: i for i, block in enumerate(building_blocks) @@ -178,11 +178,11 @@ def get_args(): # ... reaction templates file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{args.rxn_template}_{building_blocks_id}.json.gz" - rxns = ReactionTemplateFileHandler().load(file) + rxns = ReactionSet().load(file).rxns # ... building block embedding file = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{building_blocks_id}-embeddings.npy" - bb_emb = MolEmbedder.load(file) + bb_emb = MolEmbedder().load_precomputed(file).embeddings logger.info("...loading data completed.") # ... models From 54c56b49cfc5e7657016b27b64b0901dd6d84059 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 16:52:30 -0400 Subject: [PATCH 194/302] dumb but effective way to cache kdtree computation --- src/syn_net/utils/predict_utils.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 7393cbc0..826bf096 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -2,6 +2,7 @@ This file contains various utils for creating molecular embeddings and for decoding synthetic trees. """ +import functools from typing import Callable, Tuple import numpy as np @@ -323,6 +324,14 @@ def synthetic_tree_decoder( return tree, act +@functools.lru_cache(maxsize=1) +def _fetch_bb_embeddings_as_balltree(filename: str): # TODO: find more elegant way / use MolEmbedder-cls + """Helper function to cache computing BallTree. + Can hash string, but not numpy array easily, hence this workaround. + """ + from syn_net.MolEmbedder import MolEmbedder + molemedder = MolEmbedder().load_precomputed(filename) + return BallTree(molemedder.get_embeddings(), metric=cosine_distance) def synthetic_tree_decoder_rt1( z_target: np.ndarray, @@ -373,7 +382,10 @@ def synthetic_tree_decoder_rt1( # Initialization tree = SyntheticTree() mol_recent = None - kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class + if isinstance(bb_emb,str): + kdtree = _fetch_bb_embeddings_as_balltree(bb_emb) + else: + kdtree = BallTree(bb_emb, metric=cosine_distance) # Start iteration for i in range(max_step): From 5177593bd2d77905567ab5ed7f13619c65f53b87 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 16:53:22 -0400 Subject: [PATCH 195/302] wip: consolidate functions --- src/syn_net/utils/predict_utils.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 826bf096..e5452510 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -347,7 +347,7 @@ def synthetic_tree_decoder_rt1( rxn_template: str, n_bits: int, max_step: int = 15, - rt1_index=0, + k_reactant1=1, ) -> Tuple[SyntheticTree, int]: """ Computes the synthetic tree given an input molecule embedding, using the @@ -371,7 +371,7 @@ def synthetic_tree_decoder_rt1( to 3. max_step (int, optional): Maximum number of steps to include in the synthetic tree - rt1_index (int, optional): Index for molecule in the building blocks + k_reactant1 (int, optional): Index for molecule in the building blocks corresponding to reactant 1. Returns: @@ -409,14 +409,14 @@ def synthetic_tree_decoder_rt1( # Select first molecule if act == 0: # Add - if mol_recent is not None: - dist, ind = nn_search(z_mol1,_tree=kdtree) - mol1 = building_blocks[ind] - else: # no recent mol - dist, ind = nn_search_rt1( - z_mol1, _tree=kdtree, _k=rt1_index + 1 - ) # TODO: why is there an option to select the k-th? rt1_index (???) - mol1 = building_blocks[ind[rt1_index]] + if mol_recent is None: # proxy to determine if we have an empty syntree (<=> i==0) + k = k_reactant1 + else: + k = 1 + + _, idxs = kdtree.query(z_mol1,k=k) # idxs.shape = (1,k) + mol1 = building_blocks[idxs[0][k]] + elif act == 1 or act == 2: # Expand or Merge mol1 = mol_recent @@ -499,8 +499,8 @@ def synthetic_tree_decoder_beam_search( **kwargs ) -> Tuple[str, float, SyntheticTree, int]: """ - Wrapper around `synthetic_tree_decoder_rt1` with a beam search. - Selects the k-th first reactant in the k-NN search and expands in a greedy manner. + Wrapper around `synthetic_tree_decoder_rt1` with variable `k` for kNN search of 1st reactant. + Will keep the syntree that comprises of a molecule most similar to the target mol. Args: beam_width (int): The beam width to use for Reactant 1 search. Defaults to 3. @@ -517,10 +517,7 @@ def synthetic_tree_decoder_beam_search( acts: list[int] = [] for i in range(beam_width): - tree, act = synthetic_tree_decoder_rt1( - **kwargs, - rt1_index=i, - ) + tree, act = synthetic_tree_decoder_rt1(k_reactant1=i, **kwargs) # Find the chemical in this tree that is most similar to the target. # Note: This does not have to be the final root mol, but any, as we can truncate tree to our liking. From 3c0191519fbb5ae6a8ad1d00027c56449041b650 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:36:48 -0400 Subject: [PATCH 196/302] align comments from almost identical fcts --- src/syn_net/utils/predict_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index e5452510..eba45b2a 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -228,7 +228,7 @@ def synthetic_tree_decoder( # Predict action type, masked selection # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) + action_proba = action_net(torch.Tensor(z_state)) # (1,4) action_proba = action_proba.squeeze().detach().numpy() + 1e-10 action_mask = get_action_mask(tree.get_state(), reaction_templates) act = np.argmax(action_proba * action_mask) @@ -238,7 +238,7 @@ def synthetic_tree_decoder( break z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() + z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) # Select first molecule if act == 0: @@ -257,21 +257,21 @@ def synthetic_tree_decoder( # Select reaction z = np.concatenate([z_state, z_mol1], axis=1) reaction_proba = rxn_net(torch.Tensor(z)) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate) + reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate,) if act != 2: # add or expand reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] + available_list = [[] for rxn in reaction_templates] # TODO: if act=merge, this is not used at all # If we ended up in a state where no reaction is possible, end this iteration. if reaction_mask is None: - if len(state) == 1: + if len(state) == 1: # only a single root mol, so this syntree is valid act = 3 break else: - break + break # action != 3, so in our analysis we will see this tree as "invalid" # Select reaction template rxn_id = np.argmax(reaction_proba * reaction_mask) @@ -307,11 +307,11 @@ def synthetic_tree_decoder( # Run reaction mol_product = rxn.run_reaction((mol1, mol2)) if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(tree.get_state()) == 1: + if len(state) == 1: # only a single root mol, so this syntree is valid act = 3 break else: - break + break # action != 3, so in our analysis we will see this tree as "invalid" # Update tree.update(act, int(rxn_id), mol1, mol2, mol_product) From 84574c5b4e33fe718949fb980cb0cc9d96ac1157 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:37:57 -0400 Subject: [PATCH 197/302] more explicit ifs --- src/syn_net/utils/predict_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index eba45b2a..3b9c17db 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -259,7 +259,7 @@ def synthetic_tree_decoder( reaction_proba = rxn_net(torch.Tensor(z)) reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate,) - if act != 2: # add or expand + if act==0 or act==1: # add or expand reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) From daddfd5bf200463e5a8f0fe2ffebc0d028f3e77d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:38:24 -0400 Subject: [PATCH 198/302] also use stupid but working caching --- src/syn_net/utils/predict_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 3b9c17db..c890f5c0 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -218,7 +218,10 @@ def synthetic_tree_decoder( # Initialization tree = SyntheticTree() mol_recent = None - kdtree = BallTree(bb_emb, metric=cosine_distance) # TODO: cache this or use class + if isinstance(bb_emb,str): + kdtree = _fetch_bb_embeddings_as_balltree(bb_emb) + else: + kdtree = BallTree(bb_emb, metric=cosine_distance) # Start iteration for i in range(max_step): From 883f5781d02b315743b6447c9f9950083ca4c4fc Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:39:58 -0400 Subject: [PATCH 199/302] `k` as variable for very first reactant 1 --- src/syn_net/utils/predict_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index c890f5c0..fe6d41f4 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -188,6 +188,7 @@ def synthetic_tree_decoder( rxn_template: str, n_bits: int, max_step: int = 15, + k_reactant1: int = 1, ) -> Tuple[SyntheticTree, int]: """ Computes a synthetic tree given an input molecule embedding. @@ -245,9 +246,13 @@ def synthetic_tree_decoder( # Select first molecule if act == 0: - # Add - dist, ind = nn_search(z_mol1, _tree=kdtree) - mol1 = building_blocks[ind] + # Select `k` for kNN search of 1st reactant + # Use k>1 for the first action, and k==1 for all others. + # Idea: Increase the chances of generating a better tree. + k = k_reactant1 if mol_recent is None else 1 + + _, idxs = kdtree.query(z_mol1,k=k) # idxs.shape = (1,k) + mol1 = building_blocks[idxs[0][k]] elif act == 1 or act == 2: # Expand or Merge mol1 = mol_recent From 0ff8984f09406bb8e7c54bba7cdfef1d10fbfb6f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:41:06 -0400 Subject: [PATCH 200/302] fix bug: indices --- src/syn_net/utils/predict_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index fe6d41f4..ebf42546 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -252,7 +252,7 @@ def synthetic_tree_decoder( k = k_reactant1 if mol_recent is None else 1 _, idxs = kdtree.query(z_mol1,k=k) # idxs.shape = (1,k) - mol1 = building_blocks[idxs[0][k]] + mol1 = building_blocks[idxs[0][k-1]] elif act == 1 or act == 2: # Expand or Merge mol1 = mol_recent @@ -423,7 +423,7 @@ def synthetic_tree_decoder_rt1( k = 1 _, idxs = kdtree.query(z_mol1,k=k) # idxs.shape = (1,k) - mol1 = building_blocks[idxs[0][k]] + mol1 = building_blocks[idxs[0][k-1]] elif act == 1 or act == 2: # Expand or Merge From c747ac7564152e2d58ac09ab4dc5c86a0f6fa8ef Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:41:37 -0400 Subject: [PATCH 201/302] finally: change which function is called --- src/syn_net/utils/predict_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index ebf42546..5474c11d 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -525,7 +525,7 @@ def synthetic_tree_decoder_beam_search( acts: list[int] = [] for i in range(beam_width): - tree, act = synthetic_tree_decoder_rt1(k_reactant1=i, **kwargs) + tree, act = synthetic_tree_decoder(k_reactant1=i, **kwargs) # Find the chemical in this tree that is most similar to the target. # Note: This does not have to be the final root mol, but any, as we can truncate tree to our liking. From 10da75748d199c4ba6f9097b1651831cd6ed2201 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:42:16 -0400 Subject: [PATCH 202/302] delete now obsolete function --- src/syn_net/utils/predict_utils.py | 159 ----------------------------- 1 file changed, 159 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 5474c11d..4a103c27 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -341,165 +341,6 @@ def _fetch_bb_embeddings_as_balltree(filename: str): # TODO: find more elegant w molemedder = MolEmbedder().load_precomputed(filename) return BallTree(molemedder.get_embeddings(), metric=cosine_distance) -def synthetic_tree_decoder_rt1( - z_target: np.ndarray, - building_blocks: list[str], - bb_dict: dict[str, int], - reaction_templates: list[Reaction], - mol_embedder, - action_net: pl.LightningModule, - reactant1_net: pl.LightningModule, - rxn_net: pl.LightningModule, - reactant2_net: pl.LightningModule, - bb_emb: np.ndarray, - rxn_template: str, - n_bits: int, - max_step: int = 15, - k_reactant1=1, -) -> Tuple[SyntheticTree, int]: - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining - molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - beam_width (int): The beam width to use for Reactant 1 search. Defaults - to 3. - max_step (int, optional): Maximum number of steps to include in the - synthetic tree - k_reactant1 (int, optional): Index for molecule in the building blocks - corresponding to reactant 1. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" - terminated). - """ - # Initialization - tree = SyntheticTree() - mol_recent = None - if isinstance(bb_emb,str): - kdtree = _fetch_bb_embeddings_as_balltree(bb_emb) - else: - kdtree = BallTree(bb_emb, metric=cosine_distance) - - # Start iteration - for i in range(max_step): - # Encode current state - state = tree.get_state() # a list - z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) # (1,4) - action_proba = action_proba.squeeze().detach().numpy() + 1e-10 - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - # Continue growing tree? - if act == 3: # End - break - - z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) - - # Select first molecule - if act == 0: # Add - if mol_recent is None: # proxy to determine if we have an empty syntree (<=> i==0) - k = k_reactant1 - else: - k = 1 - - _, idxs = kdtree.query(z_mol1,k=k) # idxs.shape = (1,k) - mol1 = building_blocks[idxs[0][k-1]] - - elif act == 1 or act == 2: - # Expand or Merge - mol1 = mol_recent - else: - raise ValueError(f"Unexpected action {act}.") - - z_mol1 = mol_fp(mol1) # (dimension_input_embedding=d), default (4096,) - z_mol1 = np.atleast_2d(z_mol1) # (1,4096) - - # Select reaction - z = np.concatenate([z_state, z_mol1], axis=1) # (1,4d) - reaction_proba = rxn_net(torch.Tensor(z)) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate) - - if act != 2: # add or expand - reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: # merge - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] # TODO: if act=merge, this is not used at all - - # If we ended up in a state where no reaction is possible, end this iteration. - if reaction_mask is None: - if len(state) == 1: - act = 3 - break - else: - break - - # Select reaction template - rxn_id = np.argmax(reaction_proba * reaction_mask) - rxn = reaction_templates[rxn_id] - - NUMBER_OF_REACTION_TEMPLATES = { - "hb": 91, - "pis": 4700, - "unittest": 3, - } # TODO: Refactor / use class - - # Select 2nd reactant - if rxn.num_reactant == 2: - if act == 2: # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: # Add or Expand - x_rxn = one_hot_encoder(rxn_id, NUMBER_OF_REACTION_TEMPLATES[rxn_template]) - x_rct2 = np.concatenate([z_state, z_mol1, x_rxn], axis=1) - z_mol2 = reactant2_net(torch.Tensor(x_rct2)) - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] # list[str], list of reactants for this rxn - available = [bb_dict[smiles] for smiles in available] # list[int] - temp_emb = bb_emb[available] - available_tree = BallTree( - temp_emb, metric=cosine_distance - ) # TODO: evaluate if distance matrix is faster/feasible as this BallTree is discarded immediately. - dist, ind = nn_search(z_mol2, _tree=available_tree) - mol2 = building_blocks[available[ind]] - else: - mol2 = None - - # Run reaction - mol_product = rxn.run_reaction((mol1, mol2)) - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - act = 3 - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act def synthetic_tree_decoder_beam_search( From 7c1a9c99ad153f8e79d07689f76d0a9637caf688 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:51:06 -0400 Subject: [PATCH 203/302] add some todos --- src/syn_net/data_generation/syntrees.py | 2 +- src/syn_net/encoding/fingerprints.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 03cfe441..39776b1e 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -374,7 +374,7 @@ def encode(self, number: int): class SynTreeFeaturizer: def __init__(self) -> None: # Embedders - self.reactant_embedder = MorganFingerprintEncoder(2, 256) + self.reactant_embedder = MorganFingerprintEncoder(2, 256) # TODO: pass these in script, not here self.mol_embedder = MorganFingerprintEncoder(2, 4096) self.rxn_embedder = IdentityIntEncoder() self.action_embedder = IdentityIntEncoder() diff --git a/src/syn_net/encoding/fingerprints.py b/src/syn_net/encoding/fingerprints.py index 4c64e223..74659fbc 100644 --- a/src/syn_net/encoding/fingerprints.py +++ b/src/syn_net/encoding/fingerprints.py @@ -21,7 +21,7 @@ def mol_fp(smi, _radius=2, _nBits=4096): else: mol = Chem.MolFromSmiles(smi) features_vec = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) - return np.array(features_vec) + return np.array(features_vec) # TODO: much slower compared to `DataStructs.ConvertToNumpyArray` (20x?) so deprecates def fp_embedding(smi, _radius=2, _nBits=4096): """ From 39683fb8aee32458bfc1cc6a7b1a43980760cabe Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Mon, 26 Sep 2022 17:51:20 -0400 Subject: [PATCH 204/302] bug fix: indices.. --- src/syn_net/utils/predict_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 4a103c27..b167555b 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -366,7 +366,7 @@ def synthetic_tree_decoder_beam_search( acts: list[int] = [] for i in range(beam_width): - tree, act = synthetic_tree_decoder(k_reactant1=i, **kwargs) + tree, act = synthetic_tree_decoder(k_reactant1=i+1, **kwargs) # Find the chemical in this tree that is most similar to the target. # Note: This does not have to be the final root mol, but any, as we can truncate tree to our liking. From 03cd457b73f4c156025c1b1bda1bb0943d978888 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:30:49 -0400 Subject: [PATCH 205/302] New: Clamp cosine distance to [0,2] --- src/syn_net/encoding/distances.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/syn_net/encoding/distances.py b/src/syn_net/encoding/distances.py index 55c6cf36..74045114 100644 --- a/src/syn_net/encoding/distances.py +++ b/src/syn_net/encoding/distances.py @@ -1,20 +1,14 @@ import numpy as np from syn_net.encoding.fingerprints import mol_fp -def cosine_distance(v1, v2, eps=1e-15): - """Computes the cosine similarity between two vectors. +def cosine_distance(v1, v2): + """Compute the cosine distance between two 1d-vectors. - Args: - v1 (np.ndarray): First vector. - v2 (np.ndarray): Second vector. - eps (float, optional): Small value, for numerical stability. Defaults - to 1e-15. - - Returns: - float: The cosine similarity. + Note: + cosine_similarity = x'y / (||x|| ||y||) in [-1,1] + cosine_distance = 1 - cosine_similarity in [0,2] """ - return (1 - np.dot(v1, v2) - / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps)) + return max(0,min( 1-np.dot(v1,v2)/(np.linalg.norm(v1)*np.linalg.norm(v2)),2)) def ce_distance(y, y_pred, eps=1e-15): """Computes the cross-entropy between two vectors. From 945ee474b983d13d0482e3b58d9fb518cb4cb483 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:31:14 -0400 Subject: [PATCH 206/302] bug fix --- src/syn_net/utils/data_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 217d13fb..33b7b0fa 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -188,6 +188,8 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo if self.num_reactant == 1: + if len(r)==2: # Provided two reactants for unimolecular reaction -> no rxn possible + return None if not self.is_reactant(r[0]): return None elif self.num_reactant == 2: @@ -200,6 +202,7 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo return None else: raise ValueError('This reaction is neither uni- nor bi-molecular.') + # Run reaction with rdkit magic ps = rxn.RunReactants(r) From 14c998e2d84f36d0e6fb6dcbb601f693d24383ad Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:32:20 -0400 Subject: [PATCH 207/302] fix typo --- src/syn_net/MolEmbedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py index cd8159cc..1b335a4b 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/syn_net/MolEmbedder.py @@ -78,7 +78,7 @@ def init_balltree(self, metric: Union[Callable, str]): Note: Can take a couple of minutes.""" if self.embeddings is None: - raise ValueError("Neeed emebddings to compute kdtree.") + raise ValueError("Need emebddings to compute kdtree.") X = self.embeddings self.kdtree_metric = metric.__name__ if not isinstance(metric,str) else metric self.kdtree = BallTree(X, metric=metric) From 329b957e31c17b3bbad538cd24f4c1694806194a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:38:16 -0400 Subject: [PATCH 208/302] update gitignore --- .gitignore | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.gitignore b/.gitignore index 9f6319f1..f1e406b5 100644 --- a/.gitignore +++ b/.gitignore @@ -181,3 +181,10 @@ synth_net/params tmp/ scripts/oracle temp.py + +.dev/ +.old/ +.notes/ +.aliases +figures/ +*.html \ No newline at end of file From 8064c7a1256aa7c460856802dc996900f9fe4ec9 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:45:51 -0400 Subject: [PATCH 209/302] add: empty folder with gitkeep as placeholder for building blocks --- data/assets/building-blocks/.gitkeep | 1 + 1 file changed, 1 insertion(+) create mode 100644 data/assets/building-blocks/.gitkeep diff --git a/data/assets/building-blocks/.gitkeep b/data/assets/building-blocks/.gitkeep new file mode 100644 index 00000000..ef7c1b53 --- /dev/null +++ b/data/assets/building-blocks/.gitkeep @@ -0,0 +1 @@ +Placeholder for building block molecules. From efcf95df8d4f07d1f02341dda64dc9bcc350de52 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:56:21 -0400 Subject: [PATCH 210/302] clean up --- INSTRUCTIONS.md | 3 ++- scripts/00-extract-smiles-from-sdf.py | 33 +++++++++++++++++++-------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index beedcfef..b7290195 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -19,7 +19,8 @@ Let's start. Extract SMILES from the `.sdf` file from enamine.net. ```shell - python scripts/00-extract-smiles-from-sdf.py --file="data/assets/building-blocks/enamine-us.sdf" + python scripts/00-extract-smiles-from-sdf.py \ + --input-file="data/assets/building-blocks/enamine-us.sdf" ``` 1. Filter building blocks. diff --git a/scripts/00-extract-smiles-from-sdf.py b/scripts/00-extract-smiles-from-sdf.py index 5c37729a..c9c52a77 100644 --- a/scripts/00-extract-smiles-from-sdf.py +++ b/scripts/00-extract-smiles-from-sdf.py @@ -1,24 +1,37 @@ -from syn_net.utils.prep_utils import Sdf2SmilesExtractor -from pathlib import Path +import json import logging +from pathlib import Path + +from syn_net.utils.prep_utils import Sdf2SmilesExtractor logger = logging.getLogger(__name__) -def main(file): +def main(file): file = Path(file) if not file.exists(): raise FileNotFoundError(file) + logger.info(f"Start parsing {file}") outfile = file.with_name(f"{file.name}-smiles").with_suffix(".csv.gz") Sdf2SmilesExtractor().from_sdf(file).to_file(outfile) + logger.info(f"Parsed file. Output written to {outfile}.") + -if __name__=="__main__": +def get_args(): import argparse + parser = argparse.ArgumentParser() - parser.add_argument("-f", "--file", type=str, help="An *.sdf file") - args = parser.parse_args() - logger.info(f"Arguments: {vars(args)}") - file = args.file - main(file) - logger.info(f"Success.") + parser.add_argument("--input-file", type=str, help="An *.sdf file") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + main(args.input_file) + logger.info(f"Complete.") From c7d6d7bf32fa1c0809323a6219d62be549111770 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 09:55:41 -0400 Subject: [PATCH 211/302] clean up scripts --- INSTRUCTIONS.md | 5 +--- scripts/00-extract-smiles-from-sdf.py | 23 +++++++++++------- scripts/01-filter-building-blocks.py | 19 +++++++++++---- scripts/02-compute-embeddings.py | 34 +++++++++++++-------------- scripts/03-generate-syntrees.py | 26 ++++++++++++++++---- scripts/04-filter-synthetic-trees.py | 5 ++-- 6 files changed, 71 insertions(+), 41 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index b7290195..41ff4834 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -36,8 +36,7 @@ Let's start. python scripts/01-filter-building-blocks.py \ --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ - --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ - --verbose + --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" --verbose ``` > :bulb: All following steps use this matched building blocks <-> reaction template data. You have to specify the correct files for every script to that it can load the right data. It can save some time to store these as environment variables. @@ -94,8 +93,6 @@ Let's start. 6. Featurization - > :bulb: All following steps depend on the representations for the data. Hence, you have to specify the parameters for the representations as input argument for most of the scripts so that it can operate on the right data. - We featurize each *synthetic tree*. That is, we break down each tree to each iteration step ("Add", "Expand", "Extend", "End") and featurize it. This results in a "state" vector and a a corresponding "super step" vector. diff --git a/scripts/00-extract-smiles-from-sdf.py b/scripts/00-extract-smiles-from-sdf.py index c9c52a77..c6ac4ee9 100644 --- a/scripts/00-extract-smiles-from-sdf.py +++ b/scripts/00-extract-smiles-from-sdf.py @@ -1,3 +1,5 @@ +"""Extract chemicals as SMILES from a downloaded `*.sdf*` file. +""" import json import logging from pathlib import Path @@ -7,13 +9,11 @@ logger = logging.getLogger(__name__) -def main(file): - file = Path(file) - if not file.exists(): - raise FileNotFoundError(file) - logger.info(f"Start parsing {file}") - outfile = file.with_name(f"{file.name}-smiles").with_suffix(".csv.gz") - Sdf2SmilesExtractor().from_sdf(file).to_file(outfile) +def main(): + if not input_file.exists(): + raise FileNotFoundError(input_file) + logger.info(f"Start parsing {input_file}") + Sdf2SmilesExtractor().from_sdf(input_file).to_file(outfile) logger.info(f"Parsed file. Output written to {outfile}.") @@ -22,6 +22,11 @@ def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--input-file", type=str, help="An *.sdf file") + parser.add_argument( + "--output-file", + type=str, + help="File with SMILES strings (First row `SMILES`, then one per line).", + ) return parser.parse_args() @@ -32,6 +37,8 @@ def get_args(): args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - main(args.input_file) + input_file = Path(args.input_file) + outfile = Path(args.output_file) + main() logger.info(f"Complete.") diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py index b85ee1e6..c36d86a8 100644 --- a/scripts/01-filter-building-blocks.py +++ b/scripts/01-filter-building-blocks.py @@ -3,10 +3,17 @@ import logging from rdkit import RDLogger -from syn_net.data_generation.preprocessing import BuildingBlockFileHandler, BuildingBlockFilter + from syn_net.config import MAX_PROCESSES +from syn_net.data_generation.preprocessing import ( + BuildingBlockFileHandler, + BuildingBlockFilter, + ReactionTemplateFileHandler, +) + RDLogger.DisableLog("rdApp.*") logger = logging.getLogger(__name__) +import json def get_args(): @@ -17,7 +24,7 @@ def get_args(): parser.add_argument( "--building-blocks-file", type=str, - help="Input file with SMILES strings (First row `SMILES`, then one per line).", + help="File with SMILES strings (First row `SMILES`, then one per line).", ) parser.add_argument( "--rxn-templates-file", @@ -36,13 +43,15 @@ def get_args(): if __name__ == "__main__": - args = get_args() logger.info("Start.") + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + # Load assets bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) - with open(args.rxn_templates_file, "rt") as f: - rxn_templates = f.readlines() + rxn_templates = ReactionTemplateFileHandler().load(args.rxn_templates_file) bbf = BuildingBlockFilter( building_blocks=bblocks, diff --git a/scripts/02-compute-embeddings.py b/scripts/02-compute-embeddings.py index 53d8c796..e274800e 100644 --- a/scripts/02-compute-embeddings.py +++ b/scripts/02-compute-embeddings.py @@ -5,28 +5,24 @@ In the embedding space, a kNN-search will identify the 1st or 2nd reactant. """ +import json import logging +from functools import partial +from syn_net.config import MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -from syn_net.encoding.fingerprints import fp_256, fp_512, fp_1024, fp_2048, fp_4096 +from syn_net.encoding.fingerprints import mol_fp from syn_net.MolEmbedder import MolEmbedder -from syn_net.config import MAX_PROCESSES -# from syn_net.encoding.gins import mol_embedding -# from syn_net.utils.prep_utils import rdkit2d_embedding - logger = logging.getLogger(__file__) - FUNCTIONS = { - # "gin": mol_embedding, - "fp_4096": fp_4096, - "fp_2048": fp_2048, - "fp_1024": fp_1024, - "fp_512": fp_512, - "fp_256": fp_256, - # "rdkit2d": rdkit2d_embedding, -} + "fp_4096": partial(mol_fp, _radius=2, _nBits=4096), + "fp_2048": partial(mol_fp, _radius=2, _nBits=2048), + "fp_1024": partial(mol_fp, _radius=2, _nBits=1024), + "fp_512": partial(mol_fp, _radius=2, _nBits=512), + "fp_256": partial(mol_fp, _radius=2, _nBits=256), +} # TODO: think about refactor/merge with `MorganFingerprintEncoder` def get_args(): @@ -42,7 +38,7 @@ def get_args(): parser.add_argument( "--rxn-templates-file", type=str, - help="Input file with reaction templates as SMARTS(No header, one per line).", + help="Input file with reaction templates as SMARTS (No header, one per line).", ) parser.add_argument( "--output-file", @@ -52,9 +48,8 @@ def get_args(): parser.add_argument( "--featurization-fct", type=str, - default="fp_256", choices=FUNCTIONS.keys(), - help="Objective function to optimize", + help="Featurization function applied to each molecule.", ) # Processing parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") @@ -63,8 +58,11 @@ def get_args(): if __name__ == "__main__": + logger.info("Start.") + # Parse input args args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") # Load building blocks bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) @@ -77,3 +75,5 @@ def get_args(): # Save? molembedder.save_precomputed(args.output_file) + + logger.info("Completed.") diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index 2603bbe2..53986643 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -1,5 +1,9 @@ +"""Generate synthetic trees. +""" # TODO: clean up this mess +import json import logging from collections import Counter +from pathlib import Path from rdkit import RDLogger from tqdm import tqdm @@ -43,7 +47,7 @@ def get_args(): ) # Parameters parser.add_argument( - "--number-syntrees", type=int, default=1000, help="Number of SynTrees to generate." + "--number-syntrees", type=int, default=100, help="Number of SynTrees to generate." ) # Processing @@ -59,12 +63,14 @@ def generate_mp() -> Tuple[dict[int, str], list[Union[SyntheticTree, None]]]: from pathos import multiprocessing as mp def wrapper(stgen, _): - stgen.rng = np.random.default_rng() + stgen.rng = np.random.default_rng() # TODO: Think about this... return wraps_syntreegenerator_generate(stgen) func = partial(wrapper, stgen) - with mp.Pool(processes=4) as pool: + + with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, range(args.number_syntrees)) + outcomes = { i: e.__class__.__name__ if e is not None else "success" for i, (_, e) in enumerate(results) } @@ -86,25 +92,35 @@ def generate() -> Tuple[dict[int, str], list[Union[SyntheticTree, None]]]: if __name__ == "__main__": logger.info("Start.") + # Parse input args args = get_args() - logger.info(f"Arguments: {vars(args)}") + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") # Load assets bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) rxn_templates = ReactionTemplateFileHandler().load(args.rxn_templates_file) + logger.info("Loaded building block & rxn-template assets.") # Init SynTree Generator + logger.info("Start initializing SynTreeGenerator...") stgen = SynTreeGenerator( building_blocks=bblocks, rxn_templates=rxn_templates, verbose=args.verbose ) + logger.info("Successfully initialized SynTreeGenerator.") + # Generate synthetic trees logger.info(f"Start generation of {args.number_syntrees} SynTrees...") if args.ncpu > 1: outcomes, syntrees = generate_mp() else: outcomes, syntrees = generate() - logger.info(f"SynTree generation completed. Results: {Counter(outcomes.values())}") + result_summary = Counter(outcomes.values()) + logger.info(f"SynTree generation completed. Results: {result_summary}") + + summary_file = Path(args.output_file).parent / "results-summary.txt" + summary_file.parent.mkdir(parents=True, exist_ok=True) + summary_file.write_text(json.dumps(result_summary, indent=2)) # Save synthetic trees on disk syntree_collection = SyntheticTreeSet(syntrees) diff --git a/scripts/04-filter-synthetic-trees.py b/scripts/04-filter-synthetic-trees.py index 486ce552..012a45ff 100644 --- a/scripts/04-filter-synthetic-trees.py +++ b/scripts/04-filter-synthetic-trees.py @@ -4,7 +4,6 @@ import json import logging - import numpy as np from rdkit import Chem, RDLogger from tqdm import tqdm @@ -47,7 +46,7 @@ def _qed(self, st: SyntheticTree): def _random(self, st: SyntheticTree): """Filter molecules that fail the `_qed` filter; i.e. randomly select low qed molecules.""" - return self.rng.random() < self.oracle_fct(st.root.smiles) / self.threshold + return self.rng.random() < (self.oracle_fct(st.root.smiles) / self.threshold) def filter(self, st: SyntheticTree) -> bool: return self._qed(st) or self._random(st) @@ -93,6 +92,7 @@ def get_args(): valid_root_mol_filter = ValidRootMolFilter() interesting_mol_filter = OracleFilter(threshold=0.5, rng=np.random.default_rng()) + logger.info(f"Start filtering syntrees...") syntrees = [] syntree_collection = [s for s in syntree_collection if s is not None] syntree_collection = tqdm(syntree_collection) if args.verbose else syntree_collection @@ -111,6 +111,7 @@ def get_args(): # We passed all filters. This tree ascended to our dataset syntrees.append(st) + logger.info(f"Successfully filtered syntrees.") # Save filtered synthetic trees on disk SyntheticTreeSet(syntrees).save(args.output_file) From fbefdfc6773cb0e541ac0da47cdcaf99dc95bcfc Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 11:35:08 -0400 Subject: [PATCH 212/302] refactor + cleanup --- INSTRUCTIONS.md | 19 +++-- scripts/05-split-syntrees.py | 13 ++- scripts/06-featurize-syntrees.py | 103 ++++++++++++++++-------- src/syn_net/data_generation/syntrees.py | 39 ++++++--- 4 files changed, 114 insertions(+), 60 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 41ff4834..81fc3389 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -87,8 +87,8 @@ Let's start. ```bash python scripts/05-split-syntrees.py \ - --input-file "data/pre-process/synthetic-trees-filtered.json.gz" - --output-dir "data/pre-process/split" + --input-file "data/pre-process/syntrees/synthetic-trees-filtered.json.gz" \ + --output-dir "data/pre-process/syntrees/" ``` 6. Featurization @@ -96,17 +96,20 @@ Let's start. We featurize each *synthetic tree*. That is, we break down each tree to each iteration step ("Add", "Expand", "Extend", "End") and featurize it. This results in a "state" vector and a a corresponding "super step" vector. - We call it "super step" here, as it contains all (featurized) data for all networks. + We call it "super step" here, as it contains all featurized data for all networks. ```bash python scripts/06-featurize-syntrees.py \ - --input-file "data/pre-process/split/synthetic-trees-train.json.gz" # or {train,valid,test} - --output-dir "data/featurized" + --input-dir "data/pre-process/syntrees/" + --output-dir "data/featurized" --verbose ``` - This script will load the `input-file`, featurize it, and it in - - `/hb_fp_2_4096_fp_256/states_{train,valid,test}.np` and - - `/hb_fp_2_4096_fp_256/steps_{train,valid,test}.np`. + This script will load the `{train,valid,test}` data, featurize it, and save it in + - `/{train,valid,test}_states.npz` and + - `/{train,valid,test}_steps.npz`. + + The encoders for the molecules must be provided in the script. + A short text summary of the encoders will be saved as well. 7. Split features diff --git a/scripts/05-split-syntrees.py b/scripts/05-split-syntrees.py index 9062374d..9fbef7b5 100644 --- a/scripts/05-split-syntrees.py +++ b/scripts/05-split-syntrees.py @@ -1,11 +1,10 @@ -""" -Reads synthetic tree data and splits it into training, validation and testing sets. +"""Reads synthetic tree data and splits it into training, validation and testing sets. """ import json import logging from pathlib import Path -from syn_net.config import DATA_PREPROCESS_DIR, MAX_PROCESSES +from syn_net.config import MAX_PROCESSES from syn_net.utils.data_utils import SyntheticTreeSet logger = logging.getLogger(__name__) @@ -19,13 +18,11 @@ def get_args(): parser.add_argument( "--input-file", type=str, - default="data/pre-process/synthetic-trees.json.gz", help="Input file for the filtered generated synthetic trees (*.json.gz)", ) parser.add_argument( "--output-dir", type=str, - default=str(Path(DATA_PREPROCESS_DIR) / "split"), help="Output directory for the splitted synthetic trees (*.json.gz)", ) @@ -66,12 +63,12 @@ def get_args(): out_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving training dataset. Number of syntrees: {len(data_train)}") - SyntheticTreeSet(data_train).save(out_dir / "synthetic-trees-train.json.gz") + SyntheticTreeSet(data_train).save(out_dir / "synthetic-trees-filtered-train.json.gz") logger.info(f"Saving validation dataset. Number of syntrees: {len(data_valid)}") - SyntheticTreeSet(data_valid).save(out_dir / "synthetic-trees-valid.json.gz") + SyntheticTreeSet(data_valid).save(out_dir / "synthetic-trees-filtered-valid.json.gz") logger.info(f"Saving testing dataset. Number of syntrees: {len(data_test)}") - SyntheticTreeSet(data_test).save(out_dir / "synthetic-trees-test.json.gz") + SyntheticTreeSet(data_test).save(out_dir / "synthetic-trees-filtered-test.json.gz") logger.info(f"Completed.") diff --git a/scripts/06-featurize-syntrees.py b/scripts/06-featurize-syntrees.py index 7adf58ad..4a641df0 100644 --- a/scripts/06-featurize-syntrees.py +++ b/scripts/06-featurize-syntrees.py @@ -1,5 +1,4 @@ -""" -Splits a synthetic tree into states and steps. +"""Splits a synthetic tree into states and steps. """ import json import logging @@ -8,12 +7,16 @@ from scipy import sparse from tqdm import tqdm -from syn_net.data_generation.syntrees import SynTreeFeaturizer +from syn_net.data_generation.syntrees import ( + IdentityIntEncoder, + MorganFingerprintEncoder, + SynTreeFeaturizer, +) from syn_net.utils.data_utils import SyntheticTreeSet logger = logging.getLogger(__file__) -from syn_net.config import DATA_FEATURIZED_DIR +from syn_net.config import MAX_PROCESSES def get_args(): @@ -22,62 +25,92 @@ def get_args(): parser = argparse.ArgumentParser() # File I/O parser.add_argument( - "--input-file", + "--input-dir", type=str, - default="data/pre-process/split/synthetic-trees-valid.json.gz", # TODO think about filename vs dir - help="Input file for the splitted generated synthetic trees (*.json.gz)", + help="Directory with `*{train,valid,test}*.json.gz`-data of synthetic trees", ) parser.add_argument( "--output-dir", type=str, - default=str(Path(DATA_FEATURIZED_DIR)), - help="Output directory for the splitted synthetic trees (*.json.gz)", + help="Directory for the splitted synthetic trees ({train,valid,test}_{steps,states}.npz", ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") return parser.parse_args() -def _extract_dataset(filename: str) -> str: - stem = Path(filename).stem.split(".")[0] - return stem.split("-")[-1] # TODO: avoid hard coding - +def _match_dataset_filename(path: str, dataset_type: str) -> Path: + """Helper to find the exact filename for {train,valid,test} file.""" + files = list(Path(path).glob(f"*{dataset_type}*.json.gz")) + if len(files) != 1: + raise ValueError(f"Can not find unique '{dataset_type} 'file, got {files}") + return files[0] -if __name__ == "__main__": - logger.info("Start.") - # Parse input args - args = get_args() - logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - dataset_type = _extract_dataset(args.input_file) +def featurize_data( + syntree_featurizer: SynTreeFeaturizer, input_dir: str, output_dir: Path, verbose: bool = False +): + """Wrapper method to featurize synthetic tree data.""" - st_set = SyntheticTreeSet().load(args.input_file) - logger.info(f"Number of synthetic trees: {len(st_set.sts)}") - data: list = st_set.sts - del st_set + # Load syntree data + logger.info(f"Start loading {input_dir}") + syntree_collection = SyntheticTreeSet().load(input_dir) + logger.info(f"Successfully loaded synthetic trees.") + logger.info(f" Number of trees: {len(syntree_collection.sts)}") # Start splitting synthetic trees in states and steps states = [] steps = [] - stf = SynTreeFeaturizer() - for st in tqdm(data): + unsuccessfuls = [] + it = tqdm(syntree_collection) if verbose else syntree_collection + for i, syntree in enumerate(it): try: - state, step = stf.featurize(st) + state, step = syntree_featurizer.featurize(syntree) except Exception as e: logger.exception(e, exc_info=e) + unsuccessfuls += [i] continue states.append(state) steps.append(step) - - # Set output directory - save_dir = Path(args.output_dir) / "hb_fp_2_4096_fp_256" # TODO: Save info as json in dir? - Path(save_dir).mkdir(parents=1, exist_ok=1) - dataset_type = _extract_dataset(args.input_file) + logger.info(f"Completed featurizing syntrees.") + if len(unsuccessfuls) > 0: + logger.warning(f"Unsuccessfully attempted to featurize syntrees: {unsuccessfuls}.") # Finally, save. - logger.info(f"Saving to {save_dir}") + logger.info(f"Saving to directory {output_dir}") states = sparse.vstack(states) steps = sparse.vstack(steps) - sparse.save_npz(save_dir / f"states_{dataset_type}.npz", states) - sparse.save_npz(save_dir / f"steps_{dataset_type}.npz", steps) - + sparse.save_npz(output_dir / f"{dataset_type}_states.npz", states) + sparse.save_npz(output_dir / f"{dataset_type}_steps.npz", steps) logger.info("Save successful.") + return None + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + stfeat = SynTreeFeaturizer( + reactant_embedder=MorganFingerprintEncoder(2, 256), + mol_embedder=MorganFingerprintEncoder(2, 4096), + rxn_embedder=IdentityIntEncoder(), + action_embedder=IdentityIntEncoder(), + ) + + # Ensure output dir exists + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=1, exist_ok=1) + + for dataset_type in "train valid test".split(): + + input_file = _match_dataset_filename(args.input_dir, dataset_type) + featurize_data(stfeat, input_file, output_dir=output_dir, verbose=args.verbose) + + # Save information + (output_dir / "summary.txt").write_text(f"{stfeat}") # TODO: Parse as proper json? + logger.info("Completed.") diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index 39776b1e..f29fdaf8 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -335,7 +335,19 @@ def save_syntreegenerator(syntreegenerator: SynTreeGenerator, file: str) -> None # TODO: Move all these encoders to "from syn_net.encoding/" # TODO: Evaluate if One-Hot-Encoder can be replaced with encoder from sklearn -class OneHotEncoder: + +from abc import abstractmethod, ABC + +class Encoder(ABC): + + @abstractmethod + def encode(self,*args,**kwargs): + ... + + def __repr__(self) -> str: + return f"'{self.__class__.__name__}': {self.__dict__}" + +class OneHotEncoder(Encoder): def __init__(self, d: int) -> None: self.d = d @@ -346,7 +358,7 @@ def encode(self, ind: int, datatype: np.dtype = np.float64) -> np.ndarray: return onehot # (1,d) -class MorganFingerprintEncoder: +class MorganFingerprintEncoder(Encoder): def __init__(self, radius: int, nbits: int) -> None: self.radius = radius self.nbits = nbits @@ -360,10 +372,11 @@ def encode(self, smi: str) -> np.ndarray: fp = np.empty(self.nbits) Chem.DataStructs.ConvertToNumpyArray(bv, fp) fp = fp[None, :] - return fp + return fp # (1,d) -class IdentityIntEncoder: + +class IdentityIntEncoder(Encoder): def __init__(self) -> None: pass @@ -372,12 +385,20 @@ def encode(self, number: int): class SynTreeFeaturizer: - def __init__(self) -> None: + def __init__(self, *, + reactant_embedder: Encoder, + mol_embedder: Encoder, + rxn_embedder: Encoder, + action_embedder: Encoder, + ) -> None: # Embedders - self.reactant_embedder = MorganFingerprintEncoder(2, 256) # TODO: pass these in script, not here - self.mol_embedder = MorganFingerprintEncoder(2, 4096) - self.rxn_embedder = IdentityIntEncoder() - self.action_embedder = IdentityIntEncoder() + self.reactant_embedder = reactant_embedder + self.mol_embedder = mol_embedder + self.rxn_embedder = rxn_embedder + self.action_embedder = action_embedder + + def __repr__(self) -> str: + return f"{self.__dict__}" def featurize(self, syntree: SyntheticTree): """Featurize a synthetic tree at every state. From 71d6744f5306c602ef7fcd4b01fb51ee4975d37a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:33:06 -0400 Subject: [PATCH 213/302] refactor + cleanup --- INSTRUCTIONS.md | 5 +- scripts/07-split-data-for-networks.py | 28 +++-- src/syn_net/utils/prep_utils.py | 147 +++++++++++++------------- 3 files changed, 94 insertions(+), 86 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 81fc3389..6cb75bbe 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -119,9 +119,12 @@ Let's start. ```bash python scripts/07-split-data-for-networks.py \ - --input-dir "data/featurized/hb_fp_2_4096_fp_256" + --input-dir "data/featurized/" ``` + This will create 24 new files (3 splits, 4 networks, X + y). + All new files will be saved in `/Xy`. + 8. Train the networks Finally, we can train each of the four networks in `src/syn_net/models/` separately: diff --git a/scripts/07-split-data-for-networks.py b/scripts/07-split-data-for-networks.py index c16929a9..0e494519 100644 --- a/scripts/07-split-data-for-networks.py +++ b/scripts/07-split-data-for-networks.py @@ -1,13 +1,10 @@ -""" -Prepares the training, testing, and validation data by reading in the states -and steps for the reaction data and re-writing it as separate one-hot encoded -Action, Reactant 1, Reactant 2, and Reaction files. +"""Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks """ import logging from pathlib import Path +import json -from syn_net.config import DATA_FEATURIZED_DIR -from syn_net.utils.prep_utils import prep_data +from syn_net.utils.prep_utils import split_data_into_Xy logger = logging.getLogger(__file__) @@ -17,23 +14,32 @@ def get_args(): parser.add_argument( "--input-dir", type=str, - default=str(Path(DATA_FEATURIZED_DIR)) + "/hb_fp_2_4096_fp_256", # TODO: dont hardcode help="Input directory for the featurized synthetic trees (with {train,valid,test}-data).", ) return parser.parse_args() if __name__ == "__main__": logger.info("Start.") + # Parse input args args = get_args() - logger.info(f"Arguments: {vars(args)}") - - featurized_data_dir = args.input_dir + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") # Split datasets for each MLP logger.info("Start splitting data.") num_rxn = 91 # Auxiliary var for indexing TODO: Dont hardcode out_dim = 256 # Auxiliary var for indexing TODO: Dont hardcode - prep_data(featurized_data_dir, num_rxn, out_dim) + input_dir = Path(args.input_dir) + output_dir = input_dir / "Xy" + for dataset_type in "train valid test".split(): + logger.info("Split {dataset_type}-data...") + split_data_into_Xy( + dataset_type=dataset_type, + steps_file=input_dir / f"{dataset_type}_steps.npz", + states_file=input_dir / f"{dataset_type}_states.npz", + output_dir=input_dir / "Xy", + num_rxn=num_rxn, + out_dim=out_dim, + ) logger.info(f"Completed.") diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index b32133c5..30eb903e 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -243,83 +243,82 @@ def synthetic_tree_generator( return tree, action -def prep_data(main_dir: str, num_rxn: int, out_dim: int, datasets=None): +def split_data_into_Xy( + dataset_type: str, + steps_file: str, + states_file: str, + output_dir: Path, + num_rxn: int, + out_dim: int, + ) -> None: """Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks. + Args: - main_dir (str): The path to the directory containing the *.npz files. num_rxn (int): Number of reactions in the dataset. - out_dim (int): Size of the output feature vectors. + out_dim (int): Size of the output feature vectors (used in kNN-search for rt1,rt2) """ - if datasets is None: - datasets = ['train', 'valid', 'test'] - main_dir = Path(main_dir) - - for dataset in datasets: - - print(f'Reading {dataset} data ...') - states_list = [] - steps_list = [] - - states_list.append(sparse.load_npz(main_dir / f'states_{dataset}.npz')) - steps_list.append(sparse.load_npz(main_dir / f'steps_{dataset}.npz')) - - states = sparse.csc_matrix(sparse.vstack(states_list)) - steps = sparse.csc_matrix(sparse.vstack(steps_list)) - - # Extract data for each network... - - # ... action data - # X: [z_state] - # y: [action id] (int) - X = states - y = steps[:, 0] - sparse.save_npz(main_dir / f'X_act_{dataset}.npz', X) - sparse.save_npz(main_dir / f'y_act_{dataset}.npz', y) - print(f' saved data for "Action"') - - # Delete all data where tree was ended (i.e. tree expansion did not trigger reaction) - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) - - # ... reaction data - # X: [state, z_reactant_1] - # y: [reaction_id] (int) - X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) - y = steps[:, out_dim + 1] - sparse.save_npz(main_dir / f'X_rxn_{dataset}.npz', X) - sparse.save_npz(main_dir / f'y_rxn_{dataset}.npz', y) - print(f' saved data for "Reaction"') - - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).reshape(-1, )]) - - enc = OneHotEncoder(handle_unknown='ignore') - enc.fit([[i] for i in range(num_rxn)]) - - # ... reactant 2 data - # X: [z_state, z_reactant_1, reaction_id] - # y: [z'_reactant_2] - X = sparse.hstack( - [states, - steps[:, (2 * out_dim + 2):], - sparse.csc_matrix(enc.transform(steps[:, out_dim+1].A.reshape((-1, 1))).toarray())] - ) - y = steps[:, (out_dim+2): (2 * out_dim + 2)] - sparse.save_npz(main_dir / f'X_rt2_{dataset}.npz', X) - sparse.save_npz(main_dir / f'y_rt2_{dataset}.npz', y) - print(f' saved data for "Reactant 2"') - - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).reshape(-1, )]) - - # ... reactant 1 data - # X: [z_state] - # y: [z'_reactant_1] - X = states - y = steps[:, 1: (out_dim+1)] - sparse.save_npz(main_dir / f'X_rt1_{dataset}.npz', X) - sparse.save_npz(main_dir / f'y_rt1_{dataset}.npz', y) - print(f' saved data for "Reactant 1"') + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True,parents=True) + + # Load data # TODO: separate functionality? + states = sparse.load_npz(states_file) + steps = sparse.load_npz(steps_file) + + # Extract data for each network... + + # ... action data + # X: [z_state] + # y: [action id] (int) + X = states + y = steps[:, 0] + sparse.save_npz(output_dir / f'X_act_{dataset_type}.npz', X) + sparse.save_npz(output_dir / f'y_act_{dataset_type}.npz', y) + logger.info(f' saved data for "Action" to {output_dir}') + + # Delete all data where tree was ended (i.e. tree expansion did not trigger reaction) + # TODO: Look into simpler slicing with boolean indices, perhabs consider CSR for row slicing + states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) + steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) + + # ... reaction data + # X: [state, z_reactant_1] + # y: [reaction_id] (int) + X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) + y = steps[:, out_dim + 1] + sparse.save_npz(output_dir / f'X_rxn_{dataset_type}.npz', X) + sparse.save_npz(output_dir / f'y_rxn_{dataset_type}.npz', y) + logger.info(f' saved data for "Reaction" to {output_dir}') + + states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).reshape(-1, )]) + steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).reshape(-1, )]) + + enc = OneHotEncoder(handle_unknown='ignore') + enc.fit([[i] for i in range(num_rxn)]) + + # ... reactant 2 data + # X: [z_state, z_reactant_1, reaction_id] + # y: [z'_reactant_2] + X = sparse.hstack( + [states, + steps[:, (2 * out_dim + 2):], + sparse.csc_matrix(enc.transform(steps[:, out_dim+1].A.reshape((-1, 1))).toarray())] + ) + y = steps[:, (out_dim+2): (2 * out_dim + 2)] + sparse.save_npz(output_dir / f'X_rt2_{dataset_type}.npz', X) + sparse.save_npz(output_dir / f'y_rt2_{dataset_type}.npz', y) + logger.info(f' saved data for "Reactant 2" to {output_dir}') + + states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).reshape(-1, )]) + steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).reshape(-1, )]) + + # ... reactant 1 data + # X: [z_state] + # y: [z'_reactant_1] + X = states + y = steps[:, 1: (out_dim+1)] + sparse.save_npz(output_dir / f'X_rt1_{dataset_type}.npz', X) + sparse.save_npz(output_dir / f'y_rt1_{dataset_type}.npz', y) + logger.info(f' saved data for "Reactant 1" to {output_dir}') return None @@ -349,7 +348,7 @@ def _to_csv_gz(self, file: Path) -> None: f.writelines("SMILES\n") f.writelines((s + "\n" for s in self.smiles)) - def _to_csv_gz(self, file: Path) -> None: + def _to_txt(self, file: Path) -> None: with open(file, "wt") as f: f.writelines("SMILES\n") f.writelines((s + "\n" for s in self.smiles)) From 38b048fa156bd931ae4387f045c12135ec63b69f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:50:16 -0400 Subject: [PATCH 214/302] bug fix --- src/syn_net/data_generation/preprocessing.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/syn_net/data_generation/preprocessing.py b/src/syn_net/data_generation/preprocessing.py index 0ac55af4..840de40c 100644 --- a/src/syn_net/data_generation/preprocessing.py +++ b/src/syn_net/data_generation/preprocessing.py @@ -1,6 +1,8 @@ +from pathlib import Path + from tqdm import tqdm -from syn_net.config import MAX_PROCESSES +from syn_net.config import MAX_PROCESSES from syn_net.utils.data_utils import Reaction @@ -25,7 +27,7 @@ def __init__( self.rxn_templates = rxn_templates # Init reactions - self.rxns = [Reaction(template=template.strip()) for template in self.rxn_templates] + self.rxns = [Reaction(template=template) for template in self.rxn_templates] # Init other stuff self.processes = processes self.verbose = verbose @@ -66,9 +68,6 @@ def filter(self): return self -from pathlib import Path - - class BuildingBlockFileHandler: def _load_csv(self, file: str) -> list[str]: """Load building blocks as smiles from `*.csv` or `*.csv.gz`.""" @@ -85,7 +84,7 @@ def load(self, file: str) -> list[str]: raise NotImplementedError def _save_csv(self, file: Path, building_blocks: list[str]): - """Save building blocks to `*.csv`""" + """Save building blocks to `*.csv.gz`""" import pandas as pd # remove possible 1 or more extensions, i.e. @@ -100,18 +99,21 @@ def _save_csv(self, file: Path, building_blocks: list[str]): def save(self, file: str, building_blocks: list[str]): """Save building blocks to file.""" file = Path(file) + file.parent.mkdir(parents=True, exist_ok=True) if ".csv" in file.suffixes: self._save_csv(file, building_blocks) else: raise NotImplementedError -class ReactionTemplateFileHandler: +class ReactionTemplateFileHandler: def load(self, file: str) -> list[str]: """Load reaction templates from file.""" with open(file, "rt") as f: rxn_templates = f.readlines() + rxn_templates = [tmplt.strip() for tmplt in rxn_templates] + if not all([self._validate(t)] for t in rxn_templates): raise ValueError("Not all reaction templates are valid.") From 1c7a192de136509fa296ddd6da67a7c8997888cf Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:52:16 -0400 Subject: [PATCH 215/302] delete code with bug, leaves TODO --- src/syn_net/utils/predict_utils.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index b167555b..11125486 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -219,10 +219,7 @@ def synthetic_tree_decoder( # Initialization tree = SyntheticTree() mol_recent = None - if isinstance(bb_emb,str): - kdtree = _fetch_bb_embeddings_as_balltree(bb_emb) - else: - kdtree = BallTree(bb_emb, metric=cosine_distance) + kdtree = mol_embedder # TODO: dont mis-use this arg # Start iteration for i in range(max_step): @@ -332,15 +329,6 @@ def synthetic_tree_decoder( return tree, act -@functools.lru_cache(maxsize=1) -def _fetch_bb_embeddings_as_balltree(filename: str): # TODO: find more elegant way / use MolEmbedder-cls - """Helper function to cache computing BallTree. - Can hash string, but not numpy array easily, hence this workaround. - """ - from syn_net.MolEmbedder import MolEmbedder - molemedder = MolEmbedder().load_precomputed(filename) - return BallTree(molemedder.get_embeddings(), metric=cosine_distance) - def synthetic_tree_decoder_beam_search( From c71cb43e76d2a02478dcc6e6d8f3303286ad8c65 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:53:07 -0400 Subject: [PATCH 216/302] fix --- src/syn_net/MolEmbedder.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py index 1b335a4b..ee9c61fc 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/syn_net/MolEmbedder.py @@ -30,7 +30,6 @@ def _compute_mp(self, data): def compute_embeddings(self, func: Callable, building_blocks: list[str]): logger.info(f"Will compute embedding with {self.processes} processes.") - logger.info(f"Embedding function: {func.__name__}") self.func = func if self.processes == 1: embeddings = list(map(self.func, building_blocks)) @@ -46,7 +45,7 @@ def _save_npy(self, file: str): embeddings = np.asarray(self.embeddings) # assume at least 2d np.save(file, embeddings) - logger.info(f"Successfully saved data (shape={embeddings.shape}) to {file}.") + logger.info(f"Successfully saved data (shape={embeddings.shape}) to {file} .") return self def save_precomputed(self, file: str): @@ -56,7 +55,7 @@ def save_precomputed(self, file: str): if file.suffixes == [".npy"]: self._save_npy(file) else: - raise NotImplementedError(f"File have 'npy' extension, not {file.suffixes}") + raise NotImplementedError(f"File must have 'npy' extension, not {file.suffixes}") return self def _load_npy(self, file: Path): From 4571c649a4a1a74012d1b9c7803e2e3b7ec25268 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:53:31 -0400 Subject: [PATCH 217/302] add TODOs --- src/syn_net/data_generation/syntrees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index f29fdaf8..c7099add 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -64,7 +64,7 @@ def __init__( *, building_blocks: list[str], rxn_templates: list[str], - rng=np.random.default_rng(), + rng=np.random.default_rng(), # TODO: Think about this... processes: int = MAX_PROCESSES, verbose: bool = False, ) -> None: From 862eeb2fd61da5af2e8dcc4c3fd4b5f4fd613a9e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:55:31 -0400 Subject: [PATCH 218/302] "cache" balltree, clean up --- scripts/predict_multireactant_mp.py | 61 +++++++++++++++++++---------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 6f09e4d9..6afee6a0 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -1,22 +1,28 @@ """ Generate synthetic trees for a set of specified query molecules. Multiprocessing. -""" +""" # TODO: Clean up +import json import logging import multiprocessing as mp from pathlib import Path from typing import Tuple, Union -import json + +from syn_net.encoding.distances import cosine_distance logger = logging.getLogger(__name__) import numpy as np import pandas as pd -from syn_net.config import (CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, DATA_PREPARED_DIR, - DATA_PREPROCESS_DIR, DATA_RESULT_DIR) -from syn_net.data_generation.preprocessing import (BuildingBlockFileHandler, - ReactionTemplateFileHandler) +from syn_net.config import ( + CHECKPOINTS_DIR, + DATA_EMBEDDINGS_DIR, + DATA_PREPARED_DIR, + DATA_PREPROCESS_DIR, + DATA_RESULT_DIR, +) +from syn_net.data_generation.preprocessing import BuildingBlockFileHandler from syn_net.models.chkpt_loader import load_modules_from_checkpoint -from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet, ReactionSet +from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_beam_search Path(DATA_RESULT_DIR).mkdir(exist_ok=True) @@ -49,6 +55,7 @@ def _fetch_data(name: str) -> list[str]: smis_query = _fetch_data_from_file(name) return smis_query + def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils.py """Find checkpoint with lowest val_loss. @@ -88,7 +95,7 @@ def _load_pretrained_model(path_to_checkpoints: list[Path]): return act_net, rt1_net, rxn_net, rt2_net -def func(smiles: str) -> Tuple[str,float,SyntheticTree]: +def func(smiles: str) -> Tuple[str, float, SyntheticTree]: """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) try: @@ -97,7 +104,7 @@ def func(smiles: str) -> Tuple[str,float,SyntheticTree]: building_blocks=building_blocks, bb_dict=building_blocks_dict, reaction_templates=rxns, - mol_embedder=mol_fp, + mol_embedder=bblocks_molembedder.kdtree, # TODO: fix this, currently misused action_net=act_net, reactant1_net=rt1_net, rxn_net=rxn_net, @@ -109,7 +116,7 @@ def func(smiles: str) -> Tuple[str,float,SyntheticTree]: max_step=15, ) except Exception as e: - logger.error(e,exc_info=e) + logger.error(e, exc_info=e) action = -1 if action != 3: # aka tree has not been properly ended @@ -119,6 +126,7 @@ def func(smiles: str) -> Tuple[str,float,SyntheticTree]: return smi, similarity, tree + def get_args(): import argparse @@ -134,7 +142,7 @@ def get_args(): "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']" ) parser.add_argument("--ncpu", type=int, default=1, help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=1, help="Number of molecules to predict.") + parser.add_argument("-n", "--num", type=int, default=-1, help="Number of molecules to predict.") parser.add_argument( "-d", "--data", @@ -163,11 +171,11 @@ def get_args(): param_dir = f"{args.rxn_template}_{args.featurize}_{args.radius}_{nbits}_{out_dim}" # Load data ... - logger.info("Stat loading data...") + logger.info("Start loading data...") # ... query molecules (i.e. molecules to decode) smiles_queries = _fetch_data(args.data) if args.num > 0: # Select only n queries - smiles_queries = smiles_queries[:args.num] + smiles_queries = smiles_queries[: args.num] # ... building blocks file = Path(DATA_PREPROCESS_DIR) / f"{args.rxn_template}-{building_blocks_id}-matched.csv.gz" @@ -175,14 +183,22 @@ def get_args(): building_blocks_dict = { block: i for i, block in enumerate(building_blocks) } # dict is used as lookup table for 2nd reactant during inference + logger.info("...loading building blocks completed.") # ... reaction templates - file = Path(DATA_PREPROCESS_DIR) / f"reaction-sets_{args.rxn_template}_{building_blocks_id}.json.gz" + file = ( + Path(DATA_PREPROCESS_DIR) + / f"reaction-sets_{args.rxn_template}_{building_blocks_id}.json.gz" + ) rxns = ReactionSet().load(file).rxns + logger.info("...loading reaction collection completed.") # ... building block embedding file = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{building_blocks_id}-embeddings.npy" - bb_emb = MolEmbedder().load_precomputed(file).embeddings + bblocks_molembedder = MolEmbedder().load_precomputed(file).init_balltree(cosine_distance) + bb_emb = bblocks_molembedder.get_embeddings() + + logger.info("...loading building block embeddings completed.") logger.info("...loading data completed.") # ... models @@ -197,8 +213,12 @@ def get_args(): # Decode queries, i.e. the target molecules. logger.info(f"Start to decode {len(smiles_queries)} target molecules.") - with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(func, smiles_queries) + if args.ncpu == 1: + results = [func(smi) for smi in smiles_queries] + else: + with mp.Pool(processes=args.ncpu) as pool: + logger.info(f"Starting MP with ncpu={args.ncpu}") + results = pool.map(func, smiles_queries) logger.info("Finished decoding.") # Print some results from the prediction @@ -209,13 +229,14 @@ def get_args(): recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) avg_similarity = np.mean(similarities) logger.info(f"For {args.data}:") - logger.info(f" {len(smiles_queries)=}") + logger.info(f" Total number of attempted reconstructions: {len(smiles_queries)}") + logger.info(f" Total number of successful reconstructions: {len(smis_decoded)}") logger.info(f" {recovery_rate=}") logger.info(f" {avg_similarity=}") # Save to local dir output_dir = DATA_RESULT_DIR if args.output_dir is None else args.output_dir - logger.info("Saving results to {output_dir} ...") + logger.info(f"Saving results to {output_dir} ...") df = pd.DataFrame( {"query SMILES": smiles_queries, "decode SMILES": smis_decoded, "similarity": similarities} ) @@ -224,4 +245,4 @@ def get_args(): synthetic_tree_set = SyntheticTreeSet(sts=trees) synthetic_tree_set.save(f"{output_dir}/decoded_st_{args.data}.json.gz") - logger.info("Finish!") + logger.info("Completed.") From 3900a9cbbc59045ed8149d2e23222877fc4499b0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 16:01:51 -0400 Subject: [PATCH 219/302] fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e1bb4946..bff53731 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ The model consists of four modules, each containing a multi-layer perceptron (ML 2. A *First Reactant* selection function that selects the first reactant. A MLP predicts a molecular embedding and a first reactant is identified from the pool of building blocks through a k-nearest neighbors (k-NN) search. -3. A *Reaction* selection function that select reaction. The whose output is a probability distribution over available reaction templates, from which inapplicable reactions are masked (based on reactant 1) and a suitable template is then sampled using a greedy search. +3. A *Reaction* selection function whose output is a probability distribution over available reaction templates. Inapplicable reactions are masked based on reactant 1. A suitable template is then sampled using a greedy search. 4. A *Second Reactant* selection function that identifies the second reactant if the sampled template is bi-molecular. The model predicts an embedding for the second reactant, and a candidate is then sampled via a k-NN search from the masked set of building blocks. From 8f84e486985cfea7f01f901d7d3c31b94582f60b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 16:25:11 -0400 Subject: [PATCH 220/302] log to `sys.stderr` with level INFO --- src/syn_net/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/syn_net/__init__.py b/src/syn_net/__init__.py index e69de29b..2d371afe 100644 --- a/src/syn_net/__init__.py +++ b/src/syn_net/__init__.py @@ -0,0 +1,11 @@ +import logging + +logging.basicConfig( + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + datefmt="%H:%M:%S", + handlers=[logging.StreamHandler()], + # handlers=[logging.FileHandler(".log"),logging.StreamHandler()], + level=logging.INFO, +) + +logger = logging.getLogger(__name__) From b33be4c2bc33479e09657228256171807690e6de Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:01:41 -0400 Subject: [PATCH 221/302] remove code; new `SynTreeGenerator` --- src/syn_net/utils/prep_utils.py | 91 --------------------------------- 1 file changed, 91 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 30eb903e..5280a0fd 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -152,97 +152,6 @@ def organize(st: SyntheticTree, d_mol: int=300, target_embedding: str='fp', radi return sparse.csc_matrix(np.array(states)), sparse.csc_matrix(np.array(steps)) -def synthetic_tree_generator( - building_blocks: list[str], reaction_templates: list[Reaction], max_step: int = 15 -) -> tuple[SyntheticTree, int]: - """ - Generates a synthetic tree from the available building blocks and reaction - templates. Used in preparing the training/validation/testing data. - - Args: - building_blocks (list): Contains SMILES strings for purchasable building - blocks. - reaction_templates (list): Contains `Reaction` objects. - max_step (int, optional): Indicates the maximum number of reaction steps - to use for building the synthetic tree data. Defaults to 15. - - Returns: - tree (SyntheticTree): The built up synthetic tree. - action (int): Index corresponding to a specific action. - """ - # Initialization - tree = SyntheticTree() - mol_recent = None - building_blocks = np.asarray(building_blocks) - - try: - for i in range(max_step): - # Encode current state - state = tree.get_state() - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = np.random.rand(4) - action_mask = get_action_mask(tree.get_state(), reaction_templates) - action = np.argmax(action_proba * action_mask) - - # Select first molecule - if action == 3: # End - break - elif action == 0: # Add - mol1 = np.random.choice(building_blocks) - else: # Expand or Merge - mol1 = mol_recent - - # Select reaction - reaction_proba = np.random.rand(len(reaction_templates)) - - if action != 2: # = action == 0 or action == 1 - rxn_mask, available = get_reaction_mask(smi=mol1, - rxns=reaction_templates) - else: # merge tree - _, rxn_mask = can_react(tree.get_state(), reaction_templates) - available = [[] for rxn in reaction_templates] - - if rxn_mask is None: - if len(state) == 1: - action = 3 - break - else: - break - - rxn_id = np.argmax(reaction_proba * rxn_mask) - rxn = reaction_templates[rxn_id] - - # Select second molecule - if rxn.num_reactant == 2: - if action == 2: # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: # Add or Expand - mol2 = np.random.choice(available[rxn_id]) - else: - mol2 = None - - # Run reaction - mol_product = rxn.run_reaction([mol1, mol2]) - - # Update - tree.update(action, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - except Exception as e: - print(e) - action = -1 - tree = None - - if action != 3: - tree = None - else: - tree.update(action, None, None, None, None) - - return tree, action - def split_data_into_Xy( dataset_type: str, steps_file: str, From 51e85ba49dc6ce6ae451567162c8de6c1bd02bd1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:03:26 -0400 Subject: [PATCH 222/302] remove code; new `SynTreeFeaturizer` --- src/syn_net/utils/prep_utils.py | 108 -------------------------------- 1 file changed, 108 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 5280a0fd..b49c0d64 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -44,114 +44,6 @@ def _fetch_gin_pretrained_model(model_name: str): return model -def organize(st: SyntheticTree, d_mol: int=300, target_embedding: str='fp', radius: int=2, nBits:int=4096, - output_embedding: str ='gin') -> Tuple[sparse.csc_matrix,sparse.csc_matrix]: - """ - Organizes synthetic trees into states and node states at each step into sparse matrices. - - Args: - st: Synthetic tree to organize - d_mol: The molecular embedding size. Defaults to 300 - target_embedding: Embedding for the input node states. - radius: (if Morgan fingerprint) radius - nBits: (if Morgan fingerprint) bits - output_embedding: Embedding for the output node states - - Raises: - ValueError: Raised if target embedding not supported. - - Returns: - sparse.csc_matrix: Node states pulled from the tree. - sparse.csc_matrix: Actions pulled from the tree. - """ - - - states = [] - steps = [] - - OUTPUT_EMBEDDINGS_DIMS = { - "gin": 300, - "fp_4096": 4096, - "fp_256": 256, - "rdkit2d": 200, - } - - d_mol = OUTPUT_EMBEDDINGS_DIMS[output_embedding] - - # Do we need a gin embedder? - if output_embedding == "gin" or target_embedding == "gin": - model = _fetch_gin_pretrained_model("gin_supervised_contextpred") - - # Compute embedding of target molecule, i.e. the root of the synthetic tree - if target_embedding == 'fp': - target = mol_fp(st.root.smiles, radius, nBits).tolist() - elif target_embedding == 'gin': - from syn_net.encoding.gins import get_mol_embedding - # define model to use for molecular embedding - target = get_mol_embedding(st.root.smiles, model=model).tolist() - else: - raise ValueError('Target embedding only supports fp and gin.') - - most_recent_mol = None - other_root_mol = None - for i, action in enumerate(st.actions): - - most_recent_mol_embedding = mol_fp(most_recent_mol, radius, nBits).tolist() - other_root_mol_embedding = mol_fp(other_root_mol, radius, nBits).tolist() - state = most_recent_mol_embedding + other_root_mol_embedding + target # (3d,1) - - if action == 3: #end - step = [3] + [0]*d_mol + [-1] + [0]*d_mol + [0]*nBits - - else: - r = st.reactions[i] - mol1 = r.child[0] - if len(r.child) == 2: - mol2 = r.child[1] - else: - mol2 = None - - if output_embedding == 'gin': - step = ([action] - + get_mol_embedding(mol1, model=model).tolist() - + [r.rxn_id] - + get_mol_embedding(mol2, model=model).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - elif output_embedding == 'fp_4096': - step = ([action] - + mol_fp(mol1, 2, 4096).tolist() - + [r.rxn_id] - + mol_fp(mol2, 2, 4096).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - elif output_embedding == 'fp_256': - step = ([action] - + mol_fp(mol1, 2, 256).tolist() - + [r.rxn_id] - + mol_fp(mol2, 2, 256).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - elif output_embedding == 'rdkit2d': - step = ([action] - + rdkit2d_embedding(mol1).tolist() - + [r.rxn_id] - + rdkit2d_embedding(mol2).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - - if action == 2: - most_recent_mol = r.parent - other_root_mol = None - - elif action == 1: - most_recent_mol = r.parent - - elif action == 0: - other_root_mol = most_recent_mol - most_recent_mol = r.parent - - states.append(state) - steps.append(step) - - return sparse.csc_matrix(np.array(states)), sparse.csc_matrix(np.array(steps)) - def split_data_into_Xy( dataset_type: str, steps_file: str, From 93c9ce97a0bf70dfd3903b4e00463bbc5849ad1e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:04:11 -0400 Subject: [PATCH 223/302] remove code; new `scripts/03-generate-syntrees.py` --- .../data_generation/make_dataset_mp.py | 64 ------------------- 1 file changed, 64 deletions(-) delete mode 100644 src/syn_net/data_generation/make_dataset_mp.py diff --git a/src/syn_net/data_generation/make_dataset_mp.py b/src/syn_net/data_generation/make_dataset_mp.py deleted file mode 100644 index 8ce861ba..00000000 --- a/src/syn_net/data_generation/make_dataset_mp.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -This file generates synthetic tree data in a multi-thread fashion. - -Usage: - python make_dataset_mp.py -""" -import logging -import multiprocessing as mp -from pathlib import Path - -import numpy as np - -from syn_net.config import BUILDING_BLOCKS_RAW_DIR, DATA_PREPROCESS_DIR, MAX_PROCESSES -from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from syn_net.utils.prep_utils import synthetic_tree_generator - -logger = logging.getLogger(__name__) - - -def func(_x): - np.random.seed(_x) # dummy input to generate "unique" seed - tree, action = synthetic_tree_generator(building_blocks, rxns) - return tree, action - - -if __name__ == "__main__": - - reaction_template_id = "hb" # "pis" or "hb" - building_blocks_id = "enamine_us-2021-smiles" - NUM_TREES = 600_000 - - # Load building blocks - building_blocks_file = Path(BUILDING_BLOCKS_RAW_DIR) / f"{building_blocks_id}.csv.gz" - building_blocks = BuildingBlockFileHandler.load(building_blocks_file) - - # Load genearted reactions (matched reactions <=> building blocks) - reactions_dir = Path(DATA_PREPROCESS_DIR) - reactions_file = f"reaction-sets_{reaction_template_id}_{building_blocks_id}.json.gz" - r_set = ReactionSet().load(reactions_dir / reactions_file) - rxns = r_set.rxns - - # Generate synthetic trees - with mp.Pool(processes=MAX_PROCESSES) as pool: - results = pool.map(func, np.arange(NUM_TREES).tolist()) - - # Filter out trees that were completed with action="end" - trees = [r[0] for r in results if r[1] == 3] - actions = [r[1] for r in results] - - num_finish = actions.count(3) - num_error = actions.count(-1) - num_unfinish = NUM_TREES - num_finish - num_error - - logging.info(f"Total trial {NUM_TREES}") - logging.info(f"Number of finished trees: {num_finish}") - logging.info(f"Number of of unfinished tree: {num_unfinish}") - logging.info(f"Number of error processes: {num_error}") - - # Save to local disk - tree_set = SyntheticTreeSet(trees) - outfile = f"synthetic-trees_{reaction_template_id}-{building_blocks_id}.json.gz" - file = Path(DATA_PREPROCESS_DIR) / outfile - tree_set.save(file) From b05b0a06ae8f34b4aa3d32fc44c348c115b9a9e5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:04:26 -0400 Subject: [PATCH 224/302] add TODO --- src/syn_net/utils/predict_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 11125486..a78745aa 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -2,7 +2,6 @@ This file contains various utils for creating molecular embeddings and for decoding synthetic trees. """ -import functools from typing import Callable, Tuple import numpy as np @@ -121,7 +120,7 @@ def get_reaction_mask(smi: str, rxns: list[Reaction]): return reaction_mask, available_list -def nn_search(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[float, float]: +def nn_search(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[float, float]: # TODO: merge w `nn_search_rt1` """ Conducts a nearest neighbor search to find the molecule from the tree most simimilar to the input embedding. From 4684224ff0b2bd2f16cf2d59a126813bffdee837 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:08:31 -0400 Subject: [PATCH 225/302] delete unused imports --- src/syn_net/utils/prep_utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index b49c0d64..c675c0e4 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -1,14 +1,10 @@ """ This file contains various utils for data preparation and preprocessing. """ -from typing import Iterator, Union, Tuple +from typing import Iterator, Union import numpy as np from scipy import sparse from sklearn.preprocessing import OneHotEncoder -from syn_net.utils.data_utils import Reaction, SyntheticTree -from syn_net.utils.predict_utils import (can_react, get_action_mask, - get_reaction_mask, ) -from syn_net.encoding.fingerprints import mol_fp from pathlib import Path from rdkit import Chem From 4bfc0988f1affb66ac4b98ac498e27d7754b1ae6 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:12:13 -0400 Subject: [PATCH 226/302] add todo --- src/syn_net/config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 424af0bf..6c62884c 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -1,8 +1,9 @@ """Central place for all configuration, paths, and parameter.""" import multiprocessing # Multiprocessing -MAX_PROCESSES = min(32,multiprocessing.cpu_count()-1) +MAX_PROCESSES = min(32,multiprocessing.cpu_count())-1 +# TODO: Remove these paths bit by bit (not used except for decoing as of now) # Paths DATA_DIR = "data" ASSETS_DIR = "data/assets" From d191e654e1f2c8ff2071cd1d266b066be3f2da67 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 17:18:07 -0400 Subject: [PATCH 227/302] format --- src/syn_net/MolEmbedder.py | 6 +- src/syn_net/config.py | 5 +- src/syn_net/data_generation/syntrees.py | 14 +- src/syn_net/encoding/distances.py | 9 +- src/syn_net/encoding/fingerprints.py | 17 +- src/syn_net/encoding/gins.py | 94 +++-- src/syn_net/encoding/utils.py | 3 +- src/syn_net/models/act.py | 8 +- src/syn_net/models/chkpt_loader.py | 8 +- src/syn_net/models/common.py | 12 +- src/syn_net/models/mlp.py | 4 +- src/syn_net/models/rxn.py | 2 +- src/syn_net/utils/data_utils.py | 440 +++++++++++++----------- src/syn_net/utils/predict_utils.py | 42 ++- src/syn_net/utils/prep_utils.py | 112 ++++-- src/syn_net/visualize/visualizer.py | 4 +- 16 files changed, 468 insertions(+), 312 deletions(-) diff --git a/src/syn_net/MolEmbedder.py b/src/syn_net/MolEmbedder.py index ee9c61fc..2298b9fc 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/syn_net/MolEmbedder.py @@ -4,7 +4,9 @@ import numpy as np from sklearn.neighbors import BallTree + from syn_net.config import MAX_PROCESSES + logger = logging.getLogger(__name__) @@ -79,9 +81,7 @@ def init_balltree(self, metric: Union[Callable, str]): if self.embeddings is None: raise ValueError("Need emebddings to compute kdtree.") X = self.embeddings - self.kdtree_metric = metric.__name__ if not isinstance(metric,str) else metric + self.kdtree_metric = metric.__name__ if not isinstance(metric, str) else metric self.kdtree = BallTree(X, metric=metric) return self - - diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 6c62884c..e6b65dc2 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -1,7 +1,8 @@ """Central place for all configuration, paths, and parameter.""" import multiprocessing + # Multiprocessing -MAX_PROCESSES = min(32,multiprocessing.cpu_count())-1 +MAX_PROCESSES = min(32, multiprocessing.cpu_count()) - 1 # TODO: Remove these paths bit by bit (not used except for decoing as of now) # Paths @@ -26,4 +27,4 @@ DATA_RESULT_DIR = "results" # Checkpoints (& pre-trained weights) -CHECKPOINTS_DIR = "checkpoints" # \ No newline at end of file +CHECKPOINTS_DIR = "checkpoints" diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index c7099add..bf28a416 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -64,7 +64,7 @@ def __init__( *, building_blocks: list[str], rxn_templates: list[str], - rng=np.random.default_rng(), # TODO: Think about this... + rng=np.random.default_rng(), # TODO: Think about this... processes: int = MAX_PROCESSES, verbose: bool = False, ) -> None: @@ -336,17 +336,18 @@ def save_syntreegenerator(syntreegenerator: SynTreeGenerator, file: str) -> None # TODO: Move all these encoders to "from syn_net.encoding/" # TODO: Evaluate if One-Hot-Encoder can be replaced with encoder from sklearn -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod -class Encoder(ABC): +class Encoder(ABC): @abstractmethod - def encode(self,*args,**kwargs): + def encode(self, *args, **kwargs): ... def __repr__(self) -> str: return f"'{self.__class__.__name__}': {self.__dict__}" + class OneHotEncoder(Encoder): def __init__(self, d: int) -> None: self.d = d @@ -375,7 +376,6 @@ def encode(self, smi: str) -> np.ndarray: return fp # (1,d) - class IdentityIntEncoder(Encoder): def __init__(self) -> None: pass @@ -385,7 +385,9 @@ def encode(self, number: int): class SynTreeFeaturizer: - def __init__(self, *, + def __init__( + self, + *, reactant_embedder: Encoder, mol_embedder: Encoder, rxn_embedder: Encoder, diff --git a/src/syn_net/encoding/distances.py b/src/syn_net/encoding/distances.py index 74045114..41d34429 100644 --- a/src/syn_net/encoding/distances.py +++ b/src/syn_net/encoding/distances.py @@ -1,6 +1,8 @@ import numpy as np + from syn_net.encoding.fingerprints import mol_fp + def cosine_distance(v1, v2): """Compute the cosine distance between two 1d-vectors. @@ -8,7 +10,8 @@ def cosine_distance(v1, v2): cosine_similarity = x'y / (||x|| ||y||) in [-1,1] cosine_distance = 1 - cosine_similarity in [0,2] """ - return max(0,min( 1-np.dot(v1,v2)/(np.linalg.norm(v1)*np.linalg.norm(v2)),2)) + return max(0, min(1 - np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)), 2)) + def ce_distance(y, y_pred, eps=1e-15): """Computes the cross-entropy between two vectors. @@ -23,7 +26,7 @@ def ce_distance(y, y_pred, eps=1e-15): float: The cross-entropy. """ y_pred = np.clip(y_pred, eps, 1 - eps) - return - np.sum((y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))) + return -np.sum((y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))) def _tanimoto_similarity(fp1: np.ndarray, fp2: np.ndarray): @@ -39,6 +42,7 @@ def _tanimoto_similarity(fp1: np.ndarray, fp2: np.ndarray): """ return np.sum(fp1 * fp2) / (np.sum(fp1) + np.sum(fp2) - np.sum(fp1 * fp2)) + def tanimoto_similarity(target_fp: np.ndarray, smis: list[str]): """ Returns the Tanimoto similarities between a target fingerprint and molecules @@ -53,4 +57,3 @@ def tanimoto_similarity(target_fp: np.ndarray, smis: list[str]): """ fps = [mol_fp(smi, 2, 4096) for smi in smis] return [_tanimoto_similarity(target_fp, fp) for fp in fps] - diff --git a/src/syn_net/encoding/fingerprints.py b/src/syn_net/encoding/fingerprints.py index 74659fbc..ca60c4fb 100644 --- a/src/syn_net/encoding/fingerprints.py +++ b/src/syn_net/encoding/fingerprints.py @@ -2,6 +2,7 @@ from rdkit import Chem from rdkit.Chem import AllChem, DataStructs + ## Morgan fingerprints def mol_fp(smi, _radius=2, _nBits=4096): """ @@ -21,7 +22,10 @@ def mol_fp(smi, _radius=2, _nBits=4096): else: mol = Chem.MolFromSmiles(smi) features_vec = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) - return np.array(features_vec) # TODO: much slower compared to `DataStructs.ConvertToNumpyArray` (20x?) so deprecates + return np.array( + features_vec + ) # TODO: much slower compared to `DataStructs.ConvertToNumpyArray` (20x?) so deprecates + def fp_embedding(smi, _radius=2, _nBits=4096): """ @@ -36,25 +40,30 @@ def fp_embedding(smi, _radius=2, _nBits=4096): np.ndarray: A Morgan fingerprint generated using the specified parameters. """ if smi is None: - return np.zeros(_nBits).reshape((-1, )).tolist() + return np.zeros(_nBits).reshape((-1,)).tolist() else: mol = Chem.MolFromSmiles(smi) features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) features = np.zeros((1,)) DataStructs.ConvertToNumpyArray(features_vec, features) - return features.reshape((-1, )).tolist() + return features.reshape((-1,)).tolist() + def fp_4096(smi): return fp_embedding(smi, _radius=2, _nBits=4096) + def fp_2048(smi): return fp_embedding(smi, _radius=2, _nBits=2048) + def fp_1024(smi): return fp_embedding(smi, _radius=2, _nBits=1024) + def fp_512(smi): return fp_embedding(smi, _radius=2, _nBits=512) + def fp_256(smi): - return fp_embedding(smi, _radius=2, _nBits=256) \ No newline at end of file + return fp_embedding(smi, _radius=2, _nBits=256) diff --git a/src/syn_net/encoding/gins.py b/src/syn_net/encoding/gins.py index 82f19845..a95a1d1e 100644 --- a/src/syn_net/encoding/gins.py +++ b/src/syn_net/encoding/gins.py @@ -5,16 +5,15 @@ import tqdm from dgl.nn.pytorch.glob import AvgPooling from dgllife.model import load_pretrained -from dgllife.utils import (PretrainAtomFeaturizer, PretrainBondFeaturizer, - mol_to_bigraph) +from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, mol_to_bigraph from rdkit import Chem @functools.lru_cache(1) def _fetch_gin_pretrained_model(model_name: str): """Get a GIN pretrained model to use for creating molecular embeddings""" - device = 'cpu' - model = load_pretrained(model_name).to(device) # used to learn embedding + device = "cpu" + model = load_pretrained(model_name).to(device) # used to learn embedding model.eval() return model @@ -39,10 +38,13 @@ def graph_construction_and_featurization(smiles): if mol is None: success.append(False) continue - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) graphs.append(g) success.append(True) except: @@ -50,7 +52,8 @@ def graph_construction_and_featurization(smiles): return graphs, success -def mol_embedding(smi, device='cpu', readout=AvgPooling()): + +def mol_embedding(smi, device="cpu", readout=AvgPooling()): """ Constructs a graph embedding using the GIN network for an input SMILES. @@ -61,7 +64,7 @@ def mol_embedding(smi, device='cpu', readout=AvgPooling()): Returns: np.ndarray: Either a zeros array or the graph embedding. """ - name = 'gin_supervised_contextpred' + name = "gin_supervised_contextpred" gin_pretrained_model = _fetch_gin_pretrained_model(name) # get the embedding @@ -70,21 +73,37 @@ def mol_embedding(smi, device='cpu', readout=AvgPooling()): else: mol = Chem.MolFromSmiles(smi) # convert RDKit.Mol into featurized bi-directed DGLGraph - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) bg = g.to(device) - nfeats = [bg.ndata.pop('atomic_number').to(device), - bg.ndata.pop('chirality_type').to(device)] - efeats = [bg.edata.pop('bond_type').to(device), - bg.edata.pop('bond_direction_type').to(device)] + nfeats = [ + bg.ndata.pop("atomic_number").to(device), + bg.ndata.pop("chirality_type").to(device), + ] + efeats = [ + bg.edata.pop("bond_type").to(device), + bg.edata.pop("bond_direction_type").to(device), + ] with torch.no_grad(): node_repr = gin_pretrained_model(bg, nfeats, efeats) - return readout(bg, node_repr).detach().cpu().numpy().reshape(-1, ).tolist() - - -def get_mol_embedding(smi, model, device='cpu', readout=AvgPooling()): + return ( + readout(bg, node_repr) + .detach() + .cpu() + .numpy() + .reshape( + -1, + ) + .tolist() + ) + + +def get_mol_embedding(smi, model, device="cpu", readout=AvgPooling()): """ Computes the molecular graph embedding for the input SMILES. @@ -100,21 +119,21 @@ def get_mol_embedding(smi, model, device='cpu', readout=AvgPooling()): torch.Tensor: Learned embedding for the input molecule. """ mol = Chem.MolFromSmiles(smi) - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) bg = g.to(device) - nfeats = [bg.ndata.pop('atomic_number').to(device), - bg.ndata.pop('chirality_type').to(device)] - efeats = [bg.edata.pop('bond_type').to(device), - bg.edata.pop('bond_direction_type').to(device)] + nfeats = [bg.ndata.pop("atomic_number").to(device), bg.ndata.pop("chirality_type").to(device)] + efeats = [bg.edata.pop("bond_type").to(device), bg.edata.pop("bond_direction_type").to(device)] with torch.no_grad(): node_repr = model(bg, nfeats, efeats) return readout(bg, node_repr).detach().cpu().numpy()[0] - def graph_construction_and_featurization(smiles): """ Constructs graphs from SMILES and featurizes them. @@ -127,7 +146,7 @@ def graph_construction_and_featurization(smiles): success (list of bool): Indicators for whether the SMILES string can be parsed by RDKit. """ - graphs = [] + graphs = [] success = [] for smi in tqdm(smiles): try: @@ -135,10 +154,13 @@ def graph_construction_and_featurization(smiles): if mol is None: success.append(False) continue - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) graphs.append(g) success.append(True) except: diff --git a/src/syn_net/encoding/utils.py b/src/syn_net/encoding/utils.py index 1b06828d..e5ffc995 100644 --- a/src/syn_net/encoding/utils.py +++ b/src/syn_net/encoding/utils.py @@ -1,5 +1,6 @@ import numpy as np + def one_hot_encoder(dim, space): """ Create a one-hot encoded vector of length=`space`, with a non-zero element @@ -14,4 +15,4 @@ def one_hot_encoder(dim, space): """ vec = np.zeros((1, space)) vec[0, dim] = 1 - return vec \ No newline at end of file + return vec diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index c43afc89..adc6ce35 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -32,7 +32,7 @@ X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - task = "classification", + task="classification", batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -43,7 +43,7 @@ X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, - task = "classification", + task="classification", batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -83,7 +83,7 @@ save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - csv_logger = pl_loggers.CSVLogger(save_dir,name="") + csv_logger = pl_loggers.CSVLogger(save_dir, name="") logger.info(f"Log dir set to: {tb_logger.log_dir}") checkpoint_callback = ModelCheckpoint( @@ -101,7 +101,7 @@ max_epochs=max_epochs, progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], - logger=[tb_logger,csv_logger], + logger=[tb_logger, csv_logger], fast_dev_run=args.fast_dev_run, ) diff --git a/src/syn_net/models/chkpt_loader.py b/src/syn_net/models/chkpt_loader.py index c3f3231b..eea75b76 100644 --- a/src/syn_net/models/chkpt_loader.py +++ b/src/syn_net/models/chkpt_loader.py @@ -1,7 +1,9 @@ -from typing import Tuple -from syn_net.models.mlp import MLP +from typing import List, Tuple + import pytorch_lightning as pl -from typing import List + +from syn_net.models.mlp import MLP + def load_modules_from_checkpoint( path_to_act: str, diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index fb96164f..212ca647 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -4,8 +4,8 @@ # Helper to select validation func based on output dim from typing import Union -import torch import numpy as np +import torch from scipy import sparse VALIDATION_OPTS = { @@ -43,7 +43,9 @@ def get_args(): return parser.parse_args() -def xy_to_dataloader(X_file: str, y_file: str, task: str = "regression", n: Union[int, float] = 1.0, **kwargs): +def xy_to_dataloader( + X_file: str, y_file: str, task: str = "regression", n: Union[int, float] = 1.0, **kwargs +): """Loads featurized X,y `*.npz`-data into a `DataLoader`""" X = sparse.load_npz(X_file) y = sparse.load_npz(y_file) @@ -60,7 +62,11 @@ def xy_to_dataloader(X_file: str, y_file: str, task: str = "regression", n: Unio else: pass # X = np.atleast_2d(np.asarray(X.todense())) - y = np.atleast_2d(np.asarray(y.todense())) if task == "regression" else np.asarray(y.todense()).squeeze() + y = ( + np.atleast_2d(np.asarray(y.todense())) + if task == "regression" + else np.asarray(y.todense()).squeeze() + ) dataset = torch.utils.data.TensorDataset( torch.Tensor(X), torch.Tensor(y), diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 45e05b8f..539c71f5 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -110,7 +110,9 @@ def validation_step(self, batch, batch_idx): elif self.valid_loss == "huber": loss = F.huber_loss(y_hat, y) else: - raise ValueError("Not specified validation loss function for '%s'" % self.valid_loss) + raise ValueError( + "Not specified validation loss function for '%s'" % self.valid_loss + ) self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) else: pass diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index c319370b..7bde26ee 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -104,7 +104,7 @@ ncpu=args.ncpu, ) else: # load from checkpt -> only for fp, not gin - # TODO: Use `ckpt_path`, c.f. https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html#pytorch_lightning.trainer.trainer.Trainer.fit + # TODO: Use `ckpt_path`, c.f. https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html#pytorch_lightning.trainer.trainer.Trainer.fit mlp = MLP.load_from_checkpoint( path_to_rxn, input_dim=input_dim, diff --git a/src/syn_net/utils/data_utils.py b/src/syn_net/utils/data_utils.py index 33b7b0fa..bc043653 100644 --- a/src/syn_net/utils/data_utils.py +++ b/src/syn_net/utils/data_utils.py @@ -29,15 +29,16 @@ class Reaction: smiles: (str): A reaction SMILES string that macthes the SMARTS pattern. reference (str): Reference information for the reaction. """ - smirks: str # SMARTS pattern + + smirks: str # SMARTS pattern rxn: Chem.rdChemReactions.ChemicalReaction num_reactant: int num_agent: int num_product: int - reactant_template: Tuple[str,str] + reactant_template: Tuple[str, str] product_template: str agent_template: str - available_reactants: Tuple[list[str],Optional[list[str]]] + available_reactants: Tuple[list[str], Optional[list[str]]] rxnname: str smiles: Any reference: Any @@ -46,9 +47,9 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): if template is not None: # define a few attributes based on the input - self.smirks = template.strip() - self.rxnname = rxnname - self.smiles = smiles + self.smirks = template.strip() + self.rxnname = rxnname + self.smiles = smiles self.reference = reference # compute a few additional attributes @@ -56,8 +57,8 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): # Extract number of ... self.num_reactant = self.rxn.GetNumReactantTemplates() - if self.num_reactant not in (1,2): - raise ValueError('Reaction is neither uni- nor bi-molecular.') + if self.num_reactant not in (1, 2): + raise ValueError("Reaction is neither uni- nor bi-molecular.") self.num_agent = self.rxn.GetNumAgentTemplates() self.num_product = self.rxn.GetNumProductTemplates() @@ -65,7 +66,7 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): reactants, agents, products = self.smirks.split(">") if self.num_reactant == 1: - self.reactant_template = list((reactants, )) + self.reactant_template = list((reactants,)) else: self.reactant_template = list(reactants.split(".")) self.product_template = products @@ -73,33 +74,45 @@ def __init__(self, template=None, rxnname=None, smiles=None, reference=None): else: self.smirks = None - def __init_reaction(self,smirks: str) -> Chem.rdChemReactions.ChemicalReaction: + def __init_reaction(self, smirks: str) -> Chem.rdChemReactions.ChemicalReaction: """Initializes a reaction by converting the SMARTS-pattern to an `rdkit` object.""" rxn = AllChem.ReactionFromSmarts(smirks) rdChemReactions.ChemicalReaction.Initialize(rxn) return rxn - def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, - product_template, agent_template, available_reactants, rxnname, smiles, reference): + def load( + self, + smirks, + num_reactant, + num_agent, + num_product, + reactant_template, + product_template, + agent_template, + available_reactants, + rxnname, + smiles, + reference, + ): """ This function loads a set of elements and reconstructs a `Reaction` object. """ - self.smirks = smirks - self.num_reactant = num_reactant - self.num_agent = num_agent - self.num_product = num_product - self.reactant_template = list(reactant_template) - self.product_template = product_template - self.agent_template = agent_template - self.available_reactants = list(available_reactants) # TODO: use Tuple[list,list] here - self.rxnname = rxnname - self.smiles = smiles - self.reference = reference + self.smirks = smirks + self.num_reactant = num_reactant + self.num_agent = num_agent + self.num_product = num_product + self.reactant_template = list(reactant_template) + self.product_template = product_template + self.agent_template = agent_template + self.available_reactants = list(available_reactants) # TODO: use Tuple[list,list] here + self.rxnname = rxnname + self.smiles = smiles + self.reference = reference self.rxn = self.__init_reaction(self.smirks) return self @functools.lru_cache(maxsize=20) - def get_mol(self, smi: Union[str,Chem.Mol]) -> Chem.Mol: + def get_mol(self, smi: Union[str, Chem.Mol]) -> Chem.Mol: """ A internal function that returns an `RDKit.Chem.Mol` object. @@ -117,8 +130,7 @@ def get_mol(self, smi: Union[str,Chem.Mol]) -> Chem.Mol: else: raise TypeError(f"{type(smi)} not supported, only `str` or `rdkit.Chem.Mol`") - - def visualize(self, name='./reaction1_highlight.o.png'): + def visualize(self, name="./reaction1_highlight.o.png"): """ A function that plots the chemical translation into a PNG figure. One can use "from IPython.display import Image ; Image(name)" to see it @@ -131,41 +143,43 @@ def visualize(self, name='./reaction1_highlight.o.png'): name (str): The path to the figure. """ rxn = AllChem.ReactionFromSmarts(self.smirks) - d2d = Draw.MolDraw2DCairo(800,300) + d2d = Draw.MolDraw2DCairo(800, 300) d2d.DrawReaction(rxn, highlightByReactant=True) png = d2d.GetDrawingText() - open(name,'wb+').write(png) + open(name, "wb+").write(png) del rxn return name - def is_reactant(self, smi: Union[str,Chem.Mol]) -> bool: + def is_reactant(self, smi: Union[str, Chem.Mol]) -> bool: """Checks if `smi` is a reactant of this reaction.""" - smi = self.get_mol(smi) + smi = self.get_mol(smi) return self.rxn.IsMoleculeReactant(smi) - def is_agent(self, smi: Union[str,Chem.Mol]) -> bool: + def is_agent(self, smi: Union[str, Chem.Mol]) -> bool: """Checks if `smi` is an agent of this reaction.""" - smi = self.get_mol(smi) + smi = self.get_mol(smi) return self.rxn.IsMoleculeAgent(smi) def is_product(self, smi): """Checks if `smi` is a product of this reaction.""" - smi = self.get_mol(smi) + smi = self.get_mol(smi) return self.rxn.IsMoleculeProduct(smi) def is_reactant_first(self, smi: Union[str, Chem.Mol]) -> bool: - """Check if `smi` is the first reactant in this reaction """ + """Check if `smi` is the first reactant in this reaction""" mol = self.get_mol(smi) pattern = Chem.MolFromSmarts(self.reactant_template[0]) return mol.HasSubstructMatch(pattern) - def is_reactant_second(self, smi: Union[str,Chem.Mol]) -> bool: - """Check if `smi` the second reactant in this reaction """ + def is_reactant_second(self, smi: Union[str, Chem.Mol]) -> bool: + """Check if `smi` the second reactant in this reaction""" mol = self.get_mol(smi) pattern = Chem.MolFromSmarts(self.reactant_template[1]) return mol.HasSubstructMatch(pattern) - def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bool=True) -> Union[str,None]: + def run_reaction( + self, reactants: Tuple[Union[str, Chem.Mol, None]], keep_main: bool = True + ) -> Union[str, None]: """Run this reactions with reactants and return corresponding product. Args: @@ -178,7 +192,7 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo # Input validation. if not isinstance(reactants, tuple): raise TypeError(f"Unsupported type '{type(reactants)}' for `reactants`.") - if not len(reactants) in (1,2): + if not len(reactants) in (1, 2): raise ValueError(f"Can only run reactions with 1 or 2 reactants, not {len(reactants)}.") rxn = self.rxn # TODO: investigate if this is necessary (if not, delete "delete rxn below") @@ -186,9 +200,8 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo # Convert all reactants to `Chem.Mol` r: Tuple = tuple(self.get_mol(smiles) for smiles in reactants if smiles is not None) - if self.num_reactant == 1: - if len(r)==2: # Provided two reactants for unimolecular reaction -> no rxn possible + if len(r) == 2: # Provided two reactants for unimolecular reaction -> no rxn possible return None if not self.is_reactant(r[0]): return None @@ -198,10 +211,10 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo pass elif self.is_reactant_first(r[1]) and self.is_reactant_second(r[0]): r = tuple(reversed(r)) - else: # No reaction possible + else: # No reaction possible return None else: - raise ValueError('This reaction is neither uni- nor bi-molecular.') + raise ValueError("This reaction is neither uni- nor bi-molecular.") # Run reaction with rdkit magic ps = rxn.RunReactants(r) @@ -224,7 +237,9 @@ def run_reaction(self, reactants: Tuple[Union[str,Chem.Mol,None]], keep_main: bo # <<< ^ delete this line if resolved. return uniqps - def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list[str],list[str]]: + def _filter_reactants( + self, smiles: list[str], verbose: bool = False + ) -> Tuple[list[str], list[str]]: """ Filters reactants which do not match the reaction. @@ -242,7 +257,7 @@ def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list if self.num_reactant == 1: # uni-molecular reaction reactants_1 = [smi for smi in smiles if self.is_reactant_first(smi)] - return (reactants_1, ) + return (reactants_1,) elif self.num_reactant == 2: # bi-molecular reaction reactants_1 = [smi for smi in smiles if self.is_reactant_first(smi)] @@ -250,9 +265,9 @@ def _filter_reactants(self, smiles: list[str],verbose: bool=False) -> Tuple[list return (reactants_1, reactants_2) else: - raise ValueError('This reaction is neither uni- nor bi-molecular.') + raise ValueError("This reaction is neither uni- nor bi-molecular.") - def set_available_reactants(self, building_blocks: list[str],verbose: bool=False): + def set_available_reactants(self, building_blocks: list[str], verbose: bool = False): """ Finds applicable reactants from a list of building blocks. Sets `self.available_reactants`. @@ -260,7 +275,7 @@ def set_available_reactants(self, building_blocks: list[str],verbose: bool=False Args: building_blocks: Building blocks as SMILES strings. """ - self.available_reactants = self._filter_reactants(building_blocks,verbose=verbose) + self.available_reactants = self._filter_reactants(building_blocks, verbose=verbose) return self @property @@ -271,23 +286,28 @@ def asdict(self) -> dict(): """Returns serializable fields as new dictionary mapping. *Excludes* Not-easily-serializable `self.rxn: rdkit.Chem.ChemicalReaction`.""" import copy - out = copy.deepcopy(self.__dict__) # TODO: + + out = copy.deepcopy(self.__dict__) # TODO: _ = out.pop("rxn") return out + class ReactionSet: """Represents a collection of reactions, for saving and loading purposes.""" - def __init__(self, rxns: Optional[list[Reaction]]=None): + + def __init__(self, rxns: Optional[list[Reaction]] = None): self.rxns = rxns if rxns is not None else [] def load(self, file: str): """Load a collection of reactions from a `*.json.gz` file.""" assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - with gzip.open(file, 'r') as f: - data = json.loads(f.read().decode('utf-8')) + with gzip.open(file, "r") as f: + data = json.loads(f.read().decode("utf-8")) - for r in data['reactions']: - rxn = Reaction().load(**r) # TODO: `load()` relies on postional args, hence we cannot load a reaction that has no `available_reactants` for extample (or no template) + for r in data["reactions"]: + rxn = Reaction().load( + **r + ) # TODO: `load()` relies on postional args, hence we cannot load a reaction that has no `available_reactants` for extample (or no template) self.rxns.append(rxn) return self @@ -296,9 +316,9 @@ def save(self, file: str) -> None: assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" - r_list = {'reactions': [r.asdict() for r in self.rxns]} - with gzip.open(file, 'w') as f: - f.write(json.dumps(r_list).encode('utf-8')) + r_list = {"reactions": [r.asdict() for r in self.rxns]} + with gzip.open(file, "w") as f: + f.write(json.dumps(r_list).encode("utf-8")) def __len__(self): return len(self.rxns) @@ -308,7 +328,7 @@ def _print(self, x=3): for i, r in enumerate(self.rxns): if i >= x: break - print(json.dumps(r.asdict(),indent=2)) + print(json.dumps(r.asdict(), indent=2)) # the definition of classes for defining synthetic trees below @@ -324,6 +344,7 @@ class NodeChemical: depth: Depth this node is in tree (+1 for an action, +.5 for a reaction) index: Incremental index for all chemical nodes in the tree. """ + def __init__( self, smiles: Union[str, None] = None, @@ -388,12 +409,13 @@ class SyntheticTree: rxn_id2type (dict): A dictionary that maps reaction indices to reaction type (uni- or bi-molecular). """ + def __init__(self, tree=None): self.chemicals: list[NodeChemical] = [] self.reactions: list[NodeRxn] = [] - self.root = None - self.depth: float= 0 - self.actions = [] + self.root = None + self.depth: float = 0 + self.actions = [] self.rxn_id2type = None if tree is not None: @@ -406,16 +428,16 @@ def read(self, data): Args: data (dict): A dictionary representing a synthetic tree. """ - self.root = NodeChemical(**data['root']) - self.depth = data['depth'] - self.actions = data['actions'] - self.rxn_id2type = data['rxn_id2type'] + self.root = NodeChemical(**data["root"]) + self.depth = data["depth"] + self.actions = data["actions"] + self.rxn_id2type = data["rxn_id2type"] - for r_dict in data['reactions']: + for r_dict in data["reactions"]: r = NodeRxn(**r_dict) self.reactions.append(r) - for m_dict in data['chemicals']: + for m_dict in data["chemicals"]: r = NodeChemical(**m_dict) self.chemicals.append(r) @@ -426,24 +448,26 @@ def output_dict(self): Returns: data (dict): A dictionary representing a synthetic tree. """ - return {'reactions': [r.__dict__ for r in self.reactions], - 'chemicals': [m.__dict__ for m in self.chemicals], - 'root': self.root.__dict__, - 'depth': self.depth, - 'actions': self.actions, - 'rxn_id2type': self.rxn_id2type} + return { + "reactions": [r.__dict__ for r in self.reactions], + "chemicals": [m.__dict__ for m in self.chemicals], + "root": self.root.__dict__, + "depth": self.depth, + "actions": self.actions, + "rxn_id2type": self.rxn_id2type, + } def _print(self): """ A function that prints the contents of the synthetic tree. """ - print('===============Stored Molecules===============') + print("===============Stored Molecules===============") for node in self.chemicals: print(node.smiles, node.is_root) - print('===============Stored Reactions===============') + print("===============Stored Reactions===============") for node in self.reactions: print(node.rxn_id, node.rtype) - print('===============Followed Actions===============') + print("===============Followed Actions===============") print(self.actions) def get_node_index(self, smi): @@ -472,7 +496,7 @@ def get_state(self) -> list[str]: state = [node.smiles for node in self.chemicals if node.is_root] return state[::-1] - def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str): + def update(self, action: int, rxn_id: int, mol1: str, mol2: str, mol_product: str): """Update this synthetic tree by adding a reaction step. Args: @@ -486,151 +510,178 @@ def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str) """ self.actions.append(int(action)) - if action == 3: # End + if action == 3: # End self.root = self.chemicals[-1] self.depth = self.root.depth - elif action == 2: # Merge (with bi-mol rxn) + elif action == 2: # Merge (with bi-mol rxn) node_mol1 = self.chemicals[self.get_node_index(mol1)] node_mol2 = self.chemicals[self.get_node_index(mol2)] - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=2, - parent=None, - child=[node_mol1.smiles, node_mol2.smiles], - depth=max(node_mol1.depth, node_mol2.depth)+0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=node_rxn.depth+0.5, - index=len(self.chemicals)) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - node_mol2.parent = node_rxn.rxn_id + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=2, + parent=None, + child=[node_mol1.smiles, node_mol2.smiles], + depth=max(node_mol1.depth, node_mol2.depth) + 0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=node_rxn.depth + 0.5, + index=len(self.chemicals), + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + node_mol2.parent = node_rxn.rxn_id node_mol1.is_root = False node_mol2.is_root = False self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 1 and mol2 is None: # Expand with uni-mol rxn + elif action == 1 and mol2 is None: # Expand with uni-mol rxn node_mol1 = self.chemicals[self.get_node_index(mol1)] - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=1, - parent=None, - child=[node_mol1.smiles], - depth=node_mol1.depth+0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=node_rxn.depth+0.5, - index=len(self.chemicals)) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=1, + parent=None, + child=[node_mol1.smiles], + depth=node_mol1.depth + 0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=node_rxn.depth + 0.5, + index=len(self.chemicals), + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id node_mol1.is_root = False self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 1 and mol2 is not None: # Expand with bi-mol rxn - node_mol1 = self.chemicals[self.get_node_index(mol1)] - node_mol2 = NodeChemical(smiles=mol2, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)) - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=2, - parent=None, - child=[node_mol1.smiles, - node_mol2.smiles], - depth=max(node_mol1.depth, node_mol2.depth)+0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=node_rxn.depth+0.5, - index=len(self.chemicals)+1) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - node_mol2.parent = node_rxn.rxn_id + elif action == 1 and mol2 is not None: # Expand with bi-mol rxn + node_mol1 = self.chemicals[self.get_node_index(mol1)] + node_mol2 = NodeChemical( + smiles=mol2, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals), + ) + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=2, + parent=None, + child=[node_mol1.smiles, node_mol2.smiles], + depth=max(node_mol1.depth, node_mol2.depth) + 0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=node_rxn.depth + 0.5, + index=len(self.chemicals) + 1, + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + node_mol2.parent = node_rxn.rxn_id node_mol1.is_root = False self.chemicals.append(node_mol2) self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 0 and mol2 is None: # Add with uni-mol rxn - node_mol1 = NodeChemical(smiles=mol1, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)) - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=1, - parent=None, - child=[node_mol1.smiles], - depth=0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=1, - index=len(self.chemicals)+1) - - node_rxn.parent = node_product.smiles + elif action == 0 and mol2 is None: # Add with uni-mol rxn + node_mol1 = NodeChemical( + smiles=mol1, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals), + ) + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=1, + parent=None, + child=[node_mol1.smiles], + depth=0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=1, + index=len(self.chemicals) + 1, + ) + + node_rxn.parent = node_product.smiles node_mol1.parent = node_rxn.rxn_id self.chemicals.append(node_mol1) self.chemicals.append(node_product) self.reactions.append(node_rxn) - elif action == 0 and mol2 is not None: # Add with bi-mol rxn - node_mol1 = NodeChemical(smiles=mol1, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)) - node_mol2 = NodeChemical(smiles=mol2, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)+1) - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=2, - parent=None, - child=[node_mol1.smiles, node_mol2.smiles], - depth=0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=1, - index=len(self.chemicals)+2) - - node_rxn.parent = node_product.smiles + elif action == 0 and mol2 is not None: # Add with bi-mol rxn + node_mol1 = NodeChemical( + smiles=mol1, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals), + ) + node_mol2 = NodeChemical( + smiles=mol2, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals) + 1, + ) + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=2, + parent=None, + child=[node_mol1.smiles, node_mol2.smiles], + depth=0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=1, + index=len(self.chemicals) + 2, + ) + + node_rxn.parent = node_product.smiles node_mol1.parent = node_rxn.rxn_id node_mol2.parent = node_rxn.rxn_id @@ -640,7 +691,7 @@ def update(self, action: int, rxn_id:int, mol1: str, mol2: str, mol_product:str) self.reactions.append(node_rxn) else: - raise ValueError('Check input') + raise ValueError("Check input") return None @@ -687,5 +738,6 @@ def _print(self, x=3): break print(r.output_dict()) -if __name__ == '__main__': + +if __name__ == "__main__": pass diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index a78745aa..3fa6d1ce 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -10,6 +10,7 @@ import torch from rdkit import Chem from sklearn.neighbors import BallTree + from syn_net.encoding.distances import cosine_distance, tanimoto_similarity from syn_net.encoding.fingerprints import mol_fp from syn_net.encoding.utils import one_hot_encoder @@ -120,7 +121,9 @@ def get_reaction_mask(smi: str, rxns: list[Reaction]): return reaction_mask, available_list -def nn_search(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[float, float]: # TODO: merge w `nn_search_rt1` +def nn_search( + _e: np.ndarray, _tree: BallTree, _k: int = 1 +) -> Tuple[float, float]: # TODO: merge w `nn_search_rt1` """ Conducts a nearest neighbor search to find the molecule from the tree most simimilar to the input embedding. @@ -144,7 +147,9 @@ def nn_search_rt1(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[np.ndar return dist[0], ind[0] -def set_embedding(z_target: np.ndarray, state: list[str], nbits: int, _mol_embedding: Callable) -> np.ndarray: +def set_embedding( + z_target: np.ndarray, state: list[str], nbits: int, _mol_embedding: Callable +) -> np.ndarray: """ Computes embeddings for all molecules in the input space. Embedding = [z_mol1, z_mol2, z_target] @@ -158,7 +163,7 @@ def set_embedding(z_target: np.ndarray, state: list[str], nbits: int, _mol_embed Returns: embedding (np.ndarray): shape (1,d+2*nbits) """ - z_target = np.atleast_2d(z_target) # (1,d) + z_target = np.atleast_2d(z_target) # (1,d) if len(state) == 0: z_mol1 = np.zeros((1, nbits)) z_mol2 = np.zeros((1, nbits)) @@ -171,7 +176,8 @@ def set_embedding(z_target: np.ndarray, state: list[str], nbits: int, _mol_embed else: raise ValueError embedding = np.concatenate([z_mol1, z_mol2, z_target], axis=1) - return embedding # (1,d+2*nbits) + return embedding # (1,d+2*nbits) + def synthetic_tree_decoder( z_target: np.ndarray, @@ -218,7 +224,7 @@ def synthetic_tree_decoder( # Initialization tree = SyntheticTree() mol_recent = None - kdtree = mol_embedder # TODO: dont mis-use this arg + kdtree = mol_embedder # TODO: dont mis-use this arg # Start iteration for i in range(max_step): @@ -238,7 +244,7 @@ def synthetic_tree_decoder( break z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) + z_mol1 = z_mol1.detach().numpy() # (1,dimension_output_embedding), default: (1,256) # Select first molecule if act == 0: @@ -247,8 +253,8 @@ def synthetic_tree_decoder( # Idea: Increase the chances of generating a better tree. k = k_reactant1 if mol_recent is None else 1 - _, idxs = kdtree.query(z_mol1,k=k) # idxs.shape = (1,k) - mol1 = building_blocks[idxs[0][k-1]] + _, idxs = kdtree.query(z_mol1, k=k) # idxs.shape = (1,k) + mol1 = building_blocks[idxs[0][k - 1]] elif act == 1 or act == 2: # Expand or Merge mol1 = mol_recent @@ -263,19 +269,21 @@ def synthetic_tree_decoder( reaction_proba = rxn_net(torch.Tensor(z)) reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 # (nReactionTemplate,) - if act==0 or act==1: # add or expand + if act == 0 or act == 1: # add or expand reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) else: # merge _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] # TODO: if act=merge, this is not used at all + available_list = [ + [] for rxn in reaction_templates + ] # TODO: if act=merge, this is not used at all # If we ended up in a state where no reaction is possible, end this iteration. if reaction_mask is None: - if len(state) == 1: # only a single root mol, so this syntree is valid + if len(state) == 1: # only a single root mol, so this syntree is valid act = 3 break else: - break # action != 3, so in our analysis we will see this tree as "invalid" + break # action != 3, so in our analysis we will see this tree as "invalid" # Select reaction template rxn_id = np.argmax(reaction_proba * reaction_mask) @@ -311,11 +319,11 @@ def synthetic_tree_decoder( # Run reaction mol_product = rxn.run_reaction((mol1, mol2)) if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(state) == 1: # only a single root mol, so this syntree is valid + if len(state) == 1: # only a single root mol, so this syntree is valid act = 3 break else: - break # action != 3, so in our analysis we will see this tree as "invalid" + break # action != 3, so in our analysis we will see this tree as "invalid" # Update tree.update(act, int(rxn_id), mol1, mol2, mol_product) @@ -329,10 +337,8 @@ def synthetic_tree_decoder( return tree, act - def synthetic_tree_decoder_beam_search( - beam_width: int = 3, - **kwargs + beam_width: int = 3, **kwargs ) -> Tuple[str, float, SyntheticTree, int]: """ Wrapper around `synthetic_tree_decoder_rt1` with variable `k` for kNN search of 1st reactant. @@ -353,7 +359,7 @@ def synthetic_tree_decoder_beam_search( acts: list[int] = [] for i in range(beam_width): - tree, act = synthetic_tree_decoder(k_reactant1=i+1, **kwargs) + tree, act = synthetic_tree_decoder(k_reactant1=i + 1, **kwargs) # Find the chemical in this tree that is most similar to the target. # Note: This does not have to be the final root mol, but any, as we can truncate tree to our liking. diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index c675c0e4..7f1214c1 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -1,16 +1,18 @@ """ This file contains various utils for data preparation and preprocessing. """ +import logging +from pathlib import Path from typing import Iterator, Union + import numpy as np +from rdkit import Chem from scipy import sparse from sklearn.preprocessing import OneHotEncoder -from pathlib import Path -from rdkit import Chem -import logging logger = logging.getLogger(__name__) + def rdkit2d_embedding(smi): """ Computes an embedding using RDKit 2D descriptors. @@ -22,20 +24,27 @@ def rdkit2d_embedding(smi): np.ndarray: A molecular embedding corresponding to the input molecule. """ from tdc.chem_utils import MolConvert + if smi is None: - return np.zeros(200).reshape((-1, )) + return np.zeros(200).reshape((-1,)) else: # define the RDKit 2D descriptor - rdkit2d = MolConvert(src = 'SMILES', dst = 'RDKit2D') - return rdkit2d(smi).reshape(-1, ) + rdkit2d = MolConvert(src="SMILES", dst="RDKit2D") + return rdkit2d(smi).reshape( + -1, + ) + import functools + + @functools.lru_cache(maxsize=1) def _fetch_gin_pretrained_model(model_name: str): from dgllife.model import load_pretrained + """Get a GIN pretrained model to use for creating molecular embeddings""" - device = 'cpu' - model = load_pretrained(model_name).to(device) + device = "cpu" + model = load_pretrained(model_name).to(device) model.eval() return model @@ -47,7 +56,7 @@ def split_data_into_Xy( output_dir: Path, num_rxn: int, out_dim: int, - ) -> None: +) -> None: """Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks. Args: @@ -55,7 +64,7 @@ def split_data_into_Xy( out_dim (int): Size of the output feature vectors (used in kNN-search for rt1,rt2) """ output_dir = Path(output_dir) - output_dir.mkdir(exist_ok=True,parents=True) + output_dir.mkdir(exist_ok=True, parents=True) # Load data # TODO: separate functionality? states = sparse.load_npz(states_file) @@ -68,57 +77,96 @@ def split_data_into_Xy( # y: [action id] (int) X = states y = steps[:, 0] - sparse.save_npz(output_dir / f'X_act_{dataset_type}.npz', X) - sparse.save_npz(output_dir / f'y_act_{dataset_type}.npz', y) + sparse.save_npz(output_dir / f"X_act_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_act_{dataset_type}.npz", y) logger.info(f' saved data for "Action" to {output_dir}') # Delete all data where tree was ended (i.e. tree expansion did not trigger reaction) # TODO: Look into simpler slicing with boolean indices, perhabs consider CSR for row slicing - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) + states = sparse.csc_matrix( + states.A[ + (steps[:, 0].A != 3).reshape( + -1, + ) + ] + ) + steps = sparse.csc_matrix( + steps.A[ + (steps[:, 0].A != 3).reshape( + -1, + ) + ] + ) # ... reaction data # X: [state, z_reactant_1] # y: [reaction_id] (int) - X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) + X = sparse.hstack([states, steps[:, (2 * out_dim + 2) :]]) y = steps[:, out_dim + 1] - sparse.save_npz(output_dir / f'X_rxn_{dataset_type}.npz', X) - sparse.save_npz(output_dir / f'y_rxn_{dataset_type}.npz', y) + sparse.save_npz(output_dir / f"X_rxn_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_rxn_{dataset_type}.npz", y) logger.info(f' saved data for "Reaction" to {output_dir}') - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).reshape(-1, )]) + states = sparse.csc_matrix( + states.A[ + (steps[:, 0].A != 2).reshape( + -1, + ) + ] + ) + steps = sparse.csc_matrix( + steps.A[ + (steps[:, 0].A != 2).reshape( + -1, + ) + ] + ) - enc = OneHotEncoder(handle_unknown='ignore') + enc = OneHotEncoder(handle_unknown="ignore") enc.fit([[i] for i in range(num_rxn)]) # ... reactant 2 data # X: [z_state, z_reactant_1, reaction_id] # y: [z'_reactant_2] X = sparse.hstack( - [states, - steps[:, (2 * out_dim + 2):], - sparse.csc_matrix(enc.transform(steps[:, out_dim+1].A.reshape((-1, 1))).toarray())] + [ + states, + steps[:, (2 * out_dim + 2) :], + sparse.csc_matrix(enc.transform(steps[:, out_dim + 1].A.reshape((-1, 1))).toarray()), + ] ) - y = steps[:, (out_dim+2): (2 * out_dim + 2)] - sparse.save_npz(output_dir / f'X_rt2_{dataset_type}.npz', X) - sparse.save_npz(output_dir / f'y_rt2_{dataset_type}.npz', y) + y = steps[:, (out_dim + 2) : (2 * out_dim + 2)] + sparse.save_npz(output_dir / f"X_rt2_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_rt2_{dataset_type}.npz", y) logger.info(f' saved data for "Reactant 2" to {output_dir}') - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).reshape(-1, )]) + states = sparse.csc_matrix( + states.A[ + (steps[:, 0].A != 1).reshape( + -1, + ) + ] + ) + steps = sparse.csc_matrix( + steps.A[ + (steps[:, 0].A != 1).reshape( + -1, + ) + ] + ) # ... reactant 1 data # X: [z_state] # y: [z'_reactant_1] X = states - y = steps[:, 1: (out_dim+1)] - sparse.save_npz(output_dir / f'X_rt1_{dataset_type}.npz', X) - sparse.save_npz(output_dir / f'y_rt1_{dataset_type}.npz', y) + y = steps[:, 1 : (out_dim + 1)] + sparse.save_npz(output_dir / f"X_rt1_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_rt1_{dataset_type}.npz", y) logger.info(f' saved data for "Reactant 1" to {output_dir}') return None + class Sdf2SmilesExtractor: """Helper class for data generation.""" @@ -156,4 +204,4 @@ def to_file(self, file: Union[str, Path]) -> None: self._to_csv_gz(file) else: self._to_txt(file) - logger.info(f"Saved data to {file}") \ No newline at end of file + logger.info(f"Saved data to {file}") diff --git a/src/syn_net/visualize/visualizer.py b/src/syn_net/visualize/visualizer.py index d3135718..df9079aa 100644 --- a/src/syn_net/visualize/visualizer.py +++ b/src/syn_net/visualize/visualizer.py @@ -146,8 +146,9 @@ def demo(): """Demo syntree visualisation""" # 1. Load syntree import json + infile = "tests/assets/syntree-small.json" - with open(infile,"rt") as f: + with open(infile, "rt") as f: data = json.load(f) st = SyntheticTree() @@ -171,5 +172,6 @@ def demo(): print(f" Output file:", outfile) return None + if __name__ == "__main__": demo() From 540903146db23930e97a8bb2625b7440ec518665 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 28 Sep 2022 18:31:02 -0400 Subject: [PATCH 228/302] fix fstring --- scripts/07-split-data-for-networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/07-split-data-for-networks.py b/scripts/07-split-data-for-networks.py index 0e494519..ecad73b0 100644 --- a/scripts/07-split-data-for-networks.py +++ b/scripts/07-split-data-for-networks.py @@ -32,7 +32,7 @@ def get_args(): input_dir = Path(args.input_dir) output_dir = input_dir / "Xy" for dataset_type in "train valid test".split(): - logger.info("Split {dataset_type}-data...") + logger.info(f"Split {dataset_type}-data...") split_data_into_Xy( dataset_type=dataset_type, steps_file=input_dir / f"{dataset_type}_steps.npz", From ee926f2ff01bd023723d18fd50dc8fbf9c441632 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 12:20:08 -0400 Subject: [PATCH 229/302] fix: save `Reaction`s matched to bblocks --- INSTRUCTIONS.md | 19 ++++++++++--------- scripts/01-filter-building-blocks.py | 16 +++++++++++++--- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 6cb75bbe..348b748e 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -29,14 +29,15 @@ Let's start. In other words, filter out all building blocks that do not match any reaction template. There is no need to keep them, as they cannot act as reactant. In a first step, we match all building blocks with each reaction template. - In a second step, we save all matched building blocks. + In a second step, we save all matched building blocks + and a collection of `Reaction`s with their available building blocks. ```bash - # Match python scripts/01-filter-building-blocks.py \ - --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ - --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ - --output-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" --verbose + --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --output-bblock-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ + --output-rxns-file "data/pre-process/building-blocks-rxns/rxns-hb-enamine-us.json.gz" --verbose ``` > :bulb: All following steps use this matched building blocks <-> reaction template data. You have to specify the correct files for every script to that it can load the right data. It can save some time to store these as environment variables. @@ -79,7 +80,7 @@ Let's start. Each *synthetic tree* is serializable and so we save all trees in a compressed `.json` file. -5. Split *synthetic trees* into train,valid,test-data +4. Split *synthetic trees* into train,valid,test-data We load the `.json`-file with all *synthetic trees* and straightforward split it into three files: `{train,test,valid}.json`. @@ -91,7 +92,7 @@ Let's start. --output-dir "data/pre-process/syntrees/" ``` -6. Featurization +5. Featurization We featurize each *synthetic tree*. That is, we break down each tree to each iteration step ("Add", "Expand", "Extend", "End") and featurize it. @@ -111,7 +112,7 @@ Let's start. The encoders for the molecules must be provided in the script. A short text summary of the encoders will be saved as well. -7. Split features +6. Split features Up to this point, we worked with a (featurized) *synthetic tree* as a whole, now we split it up to into "consumable" input/output data for each of the four networks. @@ -125,7 +126,7 @@ Let's start. This will create 24 new files (3 splits, 4 networks, X + y). All new files will be saved in `/Xy`. -8. Train the networks +7. Train the networks Finally, we can train each of the four networks in `src/syn_net/models/` separately: diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py index c36d86a8..8068b96c 100644 --- a/scripts/01-filter-building-blocks.py +++ b/scripts/01-filter-building-blocks.py @@ -10,6 +10,7 @@ BuildingBlockFilter, ReactionTemplateFileHandler, ) +from syn_net.utils.data_utils import ReactionSet RDLogger.DisableLog("rdApp.*") logger = logging.getLogger(__name__) @@ -32,9 +33,14 @@ def get_args(): help="Input file with reaction templates as SMARTS(No header, one per line).", ) parser.add_argument( - "--output-file", + "--output-bblock-file", type=str, - help="Output file for the filtered building-blocks file.", + help="Output file for the filtered building-blocks.", + ) + parser.add_argument( + "--output-rxns-file", + type=str, + help="Output file for the collection of reactions matched with building-blocks.", ) # Processing parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") @@ -64,7 +70,11 @@ def get_args(): # ... and save to disk bblocks_filtered = bbf.building_blocks_filtered - BuildingBlockFileHandler().save(args.output_file, bblocks_filtered) + BuildingBlockFileHandler().save(args.output_bblock_file, bblocks_filtered) + + # Save collection of reactions which have "available reactants" set (for convenience) + rxn_collection = ReactionSet(bbf.rxns) + rxn_collection.save(args.output_rxns_file) logger.info(f"Total number of building blocks {len(bblocks):d}") logger.info(f"Matched number of building blocks {len(bblocks_filtered):d}") From d546f54f2559c281a2211bd0ca772119040e5cde Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 12:20:17 -0400 Subject: [PATCH 230/302] bug fix? --- src/syn_net/data_generation/preprocessing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/data_generation/preprocessing.py b/src/syn_net/data_generation/preprocessing.py index 840de40c..311508e8 100644 --- a/src/syn_net/data_generation/preprocessing.py +++ b/src/syn_net/data_generation/preprocessing.py @@ -52,7 +52,7 @@ def _init_rxns_with_reactants(self): Info: This can take a while for lots of possible reactants.""" self.rxns = tqdm(self.rxns) if self.verbose else self.rxns if self.processes == 1: - [rxn.set_available_reactants(self.building_blocks) for rxn in self.rxns] + self.rxns = [rxn.set_available_reactants(self.building_blocks) for rxn in self.rxns] else: self._match_mp() From c087090fc145a7815eab0e8acab02f13b41853c4 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 17:58:38 -0400 Subject: [PATCH 231/302] delete print only script, use `Syntree()._print` --- scripts/read_st_data.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 scripts/read_st_data.py diff --git a/scripts/read_st_data.py b/scripts/read_st_data.py deleted file mode 100644 index 95807fbd..00000000 --- a/scripts/read_st_data.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Reads synthetic tree data and prints the first five trees. -""" -from syn_net.utils.data_utils import * - - -if __name__ == "__main__": - - st_set = SyntheticTreeSet() - path_to_data = '/pool001/whgao/data/synth_net/st_pis/st_data.json.gz' - - print('Reading data from ', path_to_data) - st_set.load(path_to_data) - data = st_set.sts - - for t in data[:5]: - t._print() - - print(len(data)) - print("Finish!") From 7975c70a8134084e1bb84750c8e871a2f810191b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 17:59:24 -0400 Subject: [PATCH 232/302] delete old code, use `SynTreeVisualizer()` instead --- scripts/sketch-synthetic-trees.py | 243 ------------------------------ 1 file changed, 243 deletions(-) delete mode 100644 scripts/sketch-synthetic-trees.py diff --git a/scripts/sketch-synthetic-trees.py b/scripts/sketch-synthetic-trees.py deleted file mode 100644 index 138aaaaa..00000000 --- a/scripts/sketch-synthetic-trees.py +++ /dev/null @@ -1,243 +0,0 @@ -""" -Sketches the synthetic trees in a specified file. -""" -from syn_net.utils.data_utils import * -import argparse -from typing import Tuple -from rdkit.Chem import MolFromSmiles -from rdkit.Chem.Draw import MolToImage -import networkx as nx -import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle - - -# define some color maps for plotting -edges_cmap = { - 0 : "tab:brown", # Add - 1 : "tab:pink", # Expand - 2 : "tab:gray", # Merge - #3 : "tab:olive", # End # not currently plotting -} -nodes_cmap = { - 0 : "tab:blue", # most recent mol - 1 : "tab:orange", # other root mol - 2 : "tab:green", # product -} - - -def get_states_and_steps(synthetic_tree : "SyntheticTree") -> Tuple[list, list]: - """ - Gets the different nodes of the input synthetic tree, and the "action type" - that was used to get to those nodes. - - Args: - synthetic_tree (SyntheticTree): - - Returns: - Tuple[list, list]: Contains lists of the states and steps (actions) from - the Synthetic Tree. - """ - states = [] - steps = [] - - target = synthetic_tree.root.smiles - most_recent_mol = None - other_root_mol = None - - for i, action in enumerate(st.actions): - - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - if action != 3: - r = synthetic_tree.reactions[i] - mol1 = r.child[0] - if len(r.child) == 2: - mol2 = r.child[1] - else: - mol2 = None - state = [mol1, mol2, r.parent] - else: - state = [most_recent_mol, other_root_mol, target] - - if action == 2: - most_recent_mol = r.parent - other_root_mol = None - - elif action == 1: - most_recent_mol = r.parent - - elif action == 0: - other_root_mol = most_recent_mol - most_recent_mol = r.parent - - states.append(state) - steps.append(action) - - return states, steps - -def draw_tree(states : list, steps : list, tree_name : str) -> None: - """ - Draws the synthetic tree based on the input list of states (reactant/product - nodes) and steps (actions). - - Args: - states (list): Molecular nodes (i.e. reactants and products). - steps (list): Action types (e.g. "Add" and "Merge"). - tree_name (str): Name of tree to use for file saving purposes. - """ - G = nx.Graph() - pos_dict = {} # sets the position of the nodes, for plotting below - edge_color_dict = {} # sets the color of the edges based on the action - node_color_dict = {} # sets the color of the box around the node during plotting - - node_idx =0 - prev_target_idx = None - merge_correction = 0.0 - for state_idx, state in enumerate(states): - - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - step = steps[state_idx] - if step == 3: - break - - skip_mrm = False - skip_orm = False - for smiles_idx, smiles in enumerate(state): - - if smiles is None and smiles_idx == 0: - skip_mrm = True # mrm == 'most recent mol' - continue - elif smiles is None and smiles_idx == 1: - skip_orm = True # orm == 'other root molecule' - continue - elif smiles is None and smiles_idx == 2: - continue - elif step == 1 and smiles_idx == 0: - merge_correction -= 0.5 - skip_mrm = True # mrm == 'most recent mol' - continue - - # draw the molecules (creates a PIL image) - img = MolToImage(mol=MolFromSmiles(smiles), fitImage=False) - G.add_node(str(node_idx), image=img) - node_color_dict[str(node_idx)] = nodes_cmap[smiles_idx] - if smiles_idx != 2: - pos_dict[str(node_idx)] = [state_idx + merge_correction, smiles_idx/2 + 0.01] - else: - pos_dict[str(node_idx)] = [state_idx + 0.5 + merge_correction, 0.01] # 0.01 important to not plot edge under axis label, even if later axis label is turned off (weird behavior) - if smiles_idx == 2: - if not skip_mrm: - G.add_edge(str(node_idx - 2 + int(skip_orm)), str(node_idx)) # connect most recent mol to target - edge_color_dict[(str(node_idx - 2 + int(skip_orm)), str(node_idx))] = edges_cmap[step] - if not skip_orm: - G.add_edge(str(node_idx - 1), str(node_idx)) # connect other root mol to target - edge_color_dict[(str(node_idx - 1), str(node_idx))] = edges_cmap[step] - node_idx += 1 - - if prev_target_idx and not step == 1: - mrm_idx = node_idx - 3 + int(skip_orm) - G.add_edge(str(prev_target_idx), str(mrm_idx)) # connect the previous target to the current most recent mol - edge_color_dict[(str(prev_target_idx), str(mrm_idx))] = edges_cmap[step] - elif prev_target_idx and step == 1: - new_target_idx = node_idx - 1 - G.add_edge(str(prev_target_idx), str(new_target_idx)) # connect the previous target to the current most recent mol - edge_color_dict[(str(prev_target_idx), str(new_target_idx))] = edges_cmap[step] - prev_target_idx = node_idx - 1 - - # sketch the tree - fig, ax = plt.subplots() - - nx.draw_networkx_edges( - G, - pos=pos_dict, - ax=ax, - arrows=True, - edgelist=[edge for edge in G.edges], - edge_color=[edge_color_dict[edge] for edge in G.edges], - arrowstyle="-", # suppresses arrowheads - width=2.0, - alpha=0.9, - min_source_margin=15, - min_target_margin=15, - ) - - # Transform from data coordinates (scaled between xlim and ylim) to display coordinates - tr_figure = ax.transData.transform - # Transform from display to figure coordinates - tr_axes = fig.transFigure.inverted().transform - - # Select the size of the image (relative to the X axis) - x = 0 - for positions in pos_dict.values(): - if positions[0] > x: - x = positions[0] - - _, _ = ax.set_xlim(0, x) - _, _ = ax.set_ylim(0, 0.6) - icon_size = 0.2 - icon_center = icon_size / 2.0 - - # add a legend for the edge colors - markers_edges = [plt.Line2D([0,0],[0,0],color=color, linewidth=4, marker='_', linestyle='') for color in edges_cmap.values()] - markers_nodes = [plt.Line2D([0,0],[0,0],color=color, linewidth=2, marker='s', linestyle='') for color in nodes_cmap.values()] - markers_labels = ["Add", "Reactant 1", "Expand", "Reactant 2", "Merge", "Product"] - markers =[markers_edges[0], markers_nodes[0], markers_edges[1], markers_nodes[1], markers_edges[2], markers_nodes[2]] - plt.legend(markers, markers_labels, loc='upper center', - bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True) - - # Add the respective image to each node - for n in G.nodes: - xf, yf = tr_figure(pos_dict[n]) - xa, ya = tr_axes((xf, yf)) - # get overlapped axes and plot icon - a = plt.axes([xa - icon_center, ya - icon_center, icon_size, icon_size]) - a.imshow(G.nodes[n]["image"]) - # add colored boxes around each node: - plt.gca().add_patch(Rectangle((0,0),295,295, linewidth=2, edgecolor=node_color_dict[n], facecolor="none")) - a.axis("off") - - ax.axis("off") - - # save the figure - plt.savefig(f"{tree_name}.png", dpi=100) - print(f"-- Tree saved in {tree_name}.png", flush=True) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument("--file", type=str, default='/pool001/rociomer/test-data/synth_net/st_hb_test-plot-tests.json.gz', - help="Path/filename to the synthetic trees.") - parser.add_argument("--saveto", type=str, default='/pool001/rociomer/test-data/synth_net/images/', - help="Path to save the sketched synthetic trees.") - parser.add_argument("--nsketches", type=int, default=-1, - help="How many trees to sketch. Default -1 means to sketch all.") - parser.add_argument("--actions", type=int, default=-1, - help="How many actions the tree must have in order to sketch it (useful for testing).") - args = parser.parse_args() - - st_set = SyntheticTreeSet() - st_set.load(args.file) - data = st_set.sts - - trees_sketched = 0 - for st_idx, st in enumerate(data): - if len(st.actions) <= args.actions: - # don't sketch trees with fewer than n = `args.actions` actions - continue - try: - print("* Getting states and steps...") - states, steps = get_states_and_steps(synthetic_tree=st) - - print("* Sketching tree...") - draw_tree(states=states, steps=steps, tree_name=f"{args.saveto}tree{st_idx}") - - trees_sketched += 1 - - except Exception as e: - print(e) - continue - - if not (args.nsketches == -1) and trees_sketched > args.nsketches: - break - - print("Done!") From 9b306d5e1585f30924f54b3bf14d9bc3d9947875 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:00:55 -0400 Subject: [PATCH 233/302] delete unused code --- scripts/predict.py | 172 --------------------------------------------- 1 file changed, 172 deletions(-) delete mode 100644 scripts/predict.py diff --git a/scripts/predict.py b/scripts/predict.py deleted file mode 100644 index e2885f97..00000000 --- a/scripts/predict.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -This file contains the code to decode synthetic trees using a greedy search at -every sampling step. -""" -import os -import pandas as pd -import numpy as np -from tqdm import tqdm -from rdkit import Chem -from rdkit import DataStructs -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import mol_fp, get_mol_embedding -from syn_net.utils.predict_utils import synthetic_tree_decoder, load_modules_from_checkpoint - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("-v", "--version", type=int, default=0, - help="Version") - parser.add_argument("--param", type=str, default='hb_fp_2_4096', - help="Name of directory with parameters in it.") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to decode.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # define model to use for molecular embedding - readout = AvgPooling() - model_type = 'gin_supervised_contextpred' - device = 'cuda:0' - mol_embedder = load_pretrained(model_type).to(device) - mol_embedder.eval() - - # load the purchasable building block embeddings - bb_emb = np.load('/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/enamine_us_emb.npy') - - # define path to the reaction templates and purchasable building blocks - path_to_reaction_file = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/reactions_' + args.rxn_template + '.json.gz') - path_to_building_blocks = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/enamine_us_matched.csv.gz') - - # define paths to pretrained modules - param_path = '/home/whgao/scGen/synth_net/synth_net/params/' + args.param + '/' - path_to_act = param_path + 'act.ckpt' - path_to_rt1 = param_path + 'rt1.ckpt' - path_to_rxn = param_path + 'rxn.ckpt' - path_to_rt2 = param_path + 'rt2.ckpt' - - np.random.seed(6) - - # load the purchasable building block SMILES to a dictionary - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - - # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet().load(path_to_reaction_file) - rxns = rxn_set.rxns - - # load the pre-trained modules - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=args.out_dim, - nbits=args.nbits, - ncpu=args.ncpu, - ) - - def decode_one_molecule(query_smi): - """ - Generate a synthetic tree from a given query SMILES. - - Args: - query_smi (str): SMILES for molecule to decode. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - if args.featurize == 'fp': - z_target = mol_fp(query_smi, args.radius, args.nbits) - elif args.featurize == 'gin': - z_target = get_mol_embedding(query_smi) - tree, action = synthetic_tree_decoder(z_target, - building_blocks, - bb_dict, - rxns, - mol_embedder, - act_net, - rt1_net, - rxn_net, - rt2_net, - bb_emb=bb_emb, - rxn_template=args.rxn_template, - n_bits=args.nbits, - max_step=15) - return tree, action - - - path_to_data = '/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/st_' + args.data +'.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet().load(path_to_data) - query_smis = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - query_smis = query_smis[:args.num] - - output_smis = [] - similaritys = [] - trees = [] - num_finish = 0 - num_unfinish = 0 - - print('Start to decode!') - for smi in tqdm(query_smis): - - try: - tree, action = decode_one_molecule(smi) - except Exception as e: - print(e) - action = -1 - tree = None - - if action != 3: - num_unfinish += 1 - output_smis.append(None) - similaritys.append(None) - trees.append(None) - else: - num_finish += 1 - output_smis.append(tree.root.smiles) - ms = [Chem.MolFromSmiles(sm) for sm in [smi, tree.root.smiles]] - fps = [Chem.RDKFingerprint(x) for x in ms] - similaritys.append(DataStructs.FingerprintSimilarity(fps[0],fps[1])) - trees.append(tree) - - print('Saving ......') - save_path = '../results/' + args.rxn_template + '_' + args.featurize + '/' - if not os.path.exists(save_path): - os.makedirs(save_path) - df = pd.DataFrame({'query SMILES': query_smis, 'decode SMILES': output_smis, 'similarity': similaritys}) - print("mean similarities", df['similarity'].mean(), df['similarity'].std()) - print("NAs", df.isna().sum()) - df.to_csv(save_path + 'decode_result_' + args.data + '.csv.gz', compression='gzip', index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(save_path + 'decoded_st_' + args.data + '.json.gz') - - print('Finish!') From 489e6ea850569e2b8fe86755608686691c7a884f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:12:49 -0400 Subject: [PATCH 234/302] delete all beam search related code :( --- scripts/_mp_predict_beam.py | 98 ----- scripts/predict-beam-fullTree.py | 176 --------- scripts/predict-beam-reactantOnly.py | 182 --------- src/syn_net/utils/predict_beam_utils.py | 468 ------------------------ 4 files changed, 924 deletions(-) delete mode 100644 scripts/_mp_predict_beam.py delete mode 100644 scripts/predict-beam-fullTree.py delete mode 100644 scripts/predict-beam-reactantOnly.py delete mode 100644 src/syn_net/utils/predict_beam_utils.py diff --git a/scripts/_mp_predict_beam.py b/scripts/_mp_predict_beam.py deleted file mode 100644 index 4beaedfe..00000000 --- a/scripts/_mp_predict_beam.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -This file contains a function to decode a single synthetic tree. -""" -import pandas as pd -import numpy as np -from syn_net.utils.data_utils import ReactionSet -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import tanimoto_similarity, load_modules_from_checkpoint, mol_fp -from syn_net.utils.predict_beam_utils import synthetic_tree_decoder - - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 300 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096' -ncpu = 16 - -# define model to use for molecular embedding -model_type = 'gin_supervised_contextpred' -device = 'cpu' -mol_embedder = load_pretrained(model_type).to(device) -mol_embedder.eval() - -# load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz' -path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz' - -# define paths to pretrained modules -param_path = f'/home/whgao/scGen/synth_net/synth_net/params/{param_dir}/' -path_to_act = f'{param_path}act.ckpt' -path_to_rt1 = f'{param_path}rt1.ckpt' -path_to_rxn = f'{param_path}rxn.ckpt' -path_to_rt2 = f'{param_path}rt2.ckpt' - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet().load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(smi): - """ - Generates the synthetic tree for the input moleular string. - - Args: - smi (str): Molecule (SMILES) to decode. - - Returns: - np.ndarray or None: State of the generated synthetic tree. - float: The best score. - SyntheticTree: The generated synthetic tree. - """ - emb = mol_fp(smi) - try: - tree, action = synthetic_tree_decoder(z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - beam_width=10, - rxn_template=rxn_template, - n_bits=nbits, - max_step=15) - except Exception as e: - print(e) - action = -1 - - if action != 3: - return None, 0, None - else: - scores = tanimoto_similarity(emb, tree.get_state()) - max_score_idx = np.where(scores == np.max(scores))[0][0] - return tree.get_state()[max_score_idx], np.max(scores), tree diff --git a/scripts/predict-beam-fullTree.py b/scripts/predict-beam-fullTree.py deleted file mode 100644 index c2356d2c..00000000 --- a/scripts/predict-beam-fullTree.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -This file contains the code to decode synthetic trees using beam search at every -sampling step after the action network (i.e. reactant 1, reaction, and reactant 2 -sampling). -""" -import os -import pandas as pd -import numpy as np -from tqdm import tqdm -from rdkit import Chem -from rdkit import DataStructs - -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet - -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import mol_fp, get_mol_embedding -from syn_net.utils.predict_beam_utils import synthetic_tree_decoder_fullbeam, load_modules_from_checkpoint - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=1024, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--beam_width", type=int, default=5, - help="Beam width to use for Reactant1 search") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to decode.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # define model to use for molecular embedding - readout = AvgPooling() - model_type = 'gin_supervised_contextpred' - device = 'cuda:0' - mol_embedder = load_pretrained(model_type).to(device) - mol_embedder.eval() - - # load the purchasable building block embeddings - bb_emb = np.load(f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/enamine_us_emb.npy') - - # define path to the reaction templates and purchasable building blocks - path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/reactions_{args.rxn_template}.json.gz' - path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/enamine_us_matched.csv.gz' - - # define paths to pretrained modules - param_path = f'/home/rociomer/SynthNet/pre-trained-models/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/' - path_to_act = f'{param_path}act.ckpt' - path_to_rt1 = f'{param_path}rt1.ckpt' - path_to_rxn = f'{param_path}rxn.ckpt' - path_to_rt2 = f'{param_path}rt2.ckpt' - - np.random.seed(6) - - # load the purchasable building block SMILES to a dictionary - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - - # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet().load(path_to_reaction_file) - rxns = rxn_set.rxns - - # load the pre-trained modules - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=args.out_dim, - nbits=args.nbits, - ncpu=args.ncpu, - ) - - def decode_one_molecule(query_smi): - """ - Generate a synthetic tree from a given query SMILES. - - Args: - query_smi (str): SMILES for molecule to decode. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - if args.featurize == 'fp': - z_target = mol_fp(query_smi, nBits=args.nbits) - elif args.featurize == 'gin': - z_target = get_mol_embedding(query_smi) - tree, action = synthetic_tree_decoder_fullbeam(z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - beam_width=args.beam_width, - rxn_template=args.rxn_template, - n_bits=args.nbits, - max_step=15) - return tree, action - - # load the purchasable building blocks to decode - path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet().load(path_to_data) - query_smis = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - query_smis = query_smis[:args.num] - - output_smis = [] - similaritys = [] - trees = [] - num_finish = 0 - num_unfinish = 0 - - print('Start to decode!') - for smi in tqdm(query_smis): - - try: - tree, action = decode_one_molecule(smi) - except Exception as e: - print(e) - action = 1 - tree = None - - if action != 3: - num_unfinish += 1 - output_smis.append(None) - similaritys.append(None) - trees.append(None) - else: - num_finish += 1 - output_smis.append(tree.root.smiles) - ms = [Chem.MolFromSmiles(sm) for sm in [smi, tree.root.smiles]] - fps = [Chem.RDKFingerprint(x) for x in ms] - similaritys.append(DataStructs.FingerprintSimilarity(fps[0],fps[1])) - trees.append(tree) - - print('Saving ......') - save_path = '../results/' + args.rxn_template + '_' + args.featurize + '/' - if not os.path.exists(save_path): - os.makedirs(save_path) - df = pd.DataFrame({'query SMILES': query_smis, 'decode SMILES': output_smis, 'similarity': similaritys}) - print("mean similarities", df['similarity'].mean(), df['similarity'].std()) - print("NAs", df.isna().sum()) - df.to_csv(f'{save_path}decode_result_{args.data}_bw_{args.beam_width}.csv.gz', - compression='gzip', - index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}decoded_st_bw_{args.beam_width}_{args.data}.json.gz') - - print('Finish!') diff --git a/scripts/predict-beam-reactantOnly.py b/scripts/predict-beam-reactantOnly.py deleted file mode 100644 index a3dfa78c..00000000 --- a/scripts/predict-beam-reactantOnly.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -This file contains the code to decode synthetic trees using beam search at the -first reactant sampling step (after the action network). -""" -import os -import pandas as pd -import numpy as np -from tqdm import tqdm -from rdkit import Chem -from rdkit import DataStructs -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from syn_net.models.mlp import MLP -from syn_net.utils.predict_utils import mol_fp, get_mol_embedding -from syn_net.utils.predict_beam_utils import synthetic_tree_decoder, load_modules_from_checkpoint - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=1024, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--beam_width", type=int, default=5, - help="Beam width to use for Reactant1 search") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to decode.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # define model to use for molecular embedding - readout = AvgPooling() - model_type = 'gin_supervised_contextpred' - device = 'cuda:0' - mol_embedder = load_pretrained(model_type).to(device) - mol_embedder.eval() - - # load the purchasable building block embeddings - bb_emb = np.load('/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/enamine_us_emb.npy') - - # define path to the reaction templates and purchasable building blocks - path_to_reaction_file = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/reactions_' + args.rxn_template + '.json.gz') - path_to_building_blocks = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/enamine_us_matched.csv.gz') - - # define paths to pretrained modules - param_path = (f"/home/rociomer/SynthNet/pre-trained-models/{args.rxn_template}" - f"_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/") - path_to_act = param_path + 'act.ckpt' - path_to_rt1 = param_path + 'rt1.ckpt' - path_to_rxn = param_path + 'rxn.ckpt' - path_to_rt2 = param_path + 'rt2.ckpt' - - np.random.seed(6) - - # load the purchasable building block SMILES to a dictionary - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - - # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet().load(path_to_reaction_file) - rxns = rxn_set.rxns - - # load the pre-trained modules - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=args.out_dim, - nbits=args.nbits, - ncpu=args.ncpu, - ) - - def decode_one_molecule(query_smi): - """ - Generate a synthetic tree from a given query SMILES. - - Args: - query_smi (str): SMILES for molecule to decode. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - if args.featurize == 'fp': - z_target = mol_fp(query_smi, args.radius, args.nbits) - elif args.featurize == 'gin': - z_target = get_mol_embedding(query_smi) - tree, action = synthetic_tree_decoder(z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - beam_width=args.beam_width, - rxn_template=args.rxn_template, - n_bits=args.nbits, - max_step=15) - return tree, action - - path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - query_smis = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - query_smis = query_smis[:args.num] - - output_smis = [] - similaritys = [] - trees = [] - num_finish = 0 - num_unfinish = 0 - - print('Start to decode!') - for smi in tqdm(query_smis): - - try: - tree, action = decode_one_molecule(smi) - except Exception as e: - print(e) - action = 1 - tree = None - - if action != 3: - num_unfinish += 1 - output_smis.append(None) - similaritys.append(None) - trees.append(None) - else: - num_finish += 1 - output_smis.append(tree.root.smiles) - ms = [Chem.MolFromSmiles(sm) for sm in [smi, tree.root.smiles]] - fps = [Chem.RDKFingerprint(x) for x in ms] - similaritys.append(DataStructs.FingerprintSimilarity(fps[0],fps[1])) - trees.append(tree) - - print('Saving ......') - save_path = f'../results/{args.rxn_template}_{args.featurize}/' - if not os.path.exists(save_path): - os.makedirs(save_path) - df = pd.DataFrame( - {'query SMILES' : query_smis, - 'decode SMILES': output_smis, - 'similarity' : similaritys} - ) - print("mean similarities", df['similarity'].mean(), df['similarity'].std()) - print("NAs", df.isna().sum()) - df.to_csv(f'{save_path}decode_result_{args.data}_robw_{str(args.beam_width)}.csv.gz', - compression='gzip', - index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}decoded_st_robw_{str(args.beam_width)}_{args.data}.json.gz') - - print('Finish!') diff --git a/src/syn_net/utils/predict_beam_utils.py b/src/syn_net/utils/predict_beam_utils.py deleted file mode 100644 index 2f24126e..00000000 --- a/src/syn_net/utils/predict_beam_utils.py +++ /dev/null @@ -1,468 +0,0 @@ -""" -This file contains various utils for decoding synthetic trees using beam search. -""" -import numpy as np -from rdkit import Chem -from syn_net.utils.data_utils import SyntheticTree -from sklearn.neighbors import BallTree, KDTree -from syn_net.utils.predict_utils import * - - -np.random.seed(6) - - -def softmax(x): - """ - Computes softmax values for each sets of scores in x. - - Args: - x (np.ndarray or list): Values to normalize. - Returns: - (np.ndarray): Softmaxed values. - """ - e_x = np.exp(x - np.max(x)) - return e_x / e_x.sum(axis=0) - -def nn_search(_e, _tree, _k=1): - """ - Conducts a nearest neighbor search to find the molecule from the tree most - simimilar to the input embedding. - - Args: - _e (np.ndarray): A specific point in the dataset. - _tree (sklearn.neighbors._kd_tree.KDTree, optional): A k-d tree. - _k (int, optional): Indicates how many nearest neighbors to get. - Defaults to 1. - - Returns: - float: The distance to the nearest neighbor. - int: The indices of the nearest neighbor. - """ - dist, ind = _tree.query(_e, k=_k) - return dist[0], ind[0] - -def synthetic_tree_decoder(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - beam_width, - rxn_template, - n_bits, - max_step=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining - molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - beam_width (int): The beam width to use for Reactant 1 search. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - max_step (int, optional): Maximum number of steps to include in the - synthetic tree - - Returns: - tree (SyntheticTree): The final synthetic tree. - act (int): The final action (to know if the tree was "properly" - terminated). - """ - # Initialization - tree = SyntheticTree() - kdtree = BallTree(bb_emb, metric=cosine_distance) - mol_recent = None - - # Start iteration - # try: - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - z_state = set_embedding(z_target, state, nbits=n_bits, mol_fp=mol_fp) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) - action_proba = action_proba.squeeze().detach().numpy() + 1e-10 - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - reactant1_net_input = torch.Tensor( - np.concatenate([z_state, one_hot_encoder(act, 4)], axis=1) - ) - z_mol1 = reactant1_net(reactant1_net_input) - z_mol1 = z_mol1.detach().numpy() - - # Select first molecule - if act == 3: - # End - nlls = [0.0] - break - elif act == 0: - # Add - # **don't try to sample more points than there are in the tree - # beam search for mol1 candidates - dist, ind = nn_search(z_mol1, _tree=kdtree, _k=min(len(bb_emb), beam_width)) - try: - mol1_probas = softmax(- 0.1 * dist) - mol1_nlls = -np.log(mol1_probas) - except: # exception for beam search of length 1 - mol1_nlls = [-np.log(0.5)] - mol1_list = [building_blocks[idx] for idx in ind] - nlls = mol1_nlls - else: - # Expand or Merge - mol1_list = [mol_recent] - nlls = [-np.log(0.5)] - - rxn_list = [] - rxn_id_list = [] - mol2_list = [] - act_list = [act] * beam_width - for mol1_idx, mol1 in enumerate(mol1_list): - - z_mol1 = mol_fp(mol1) - act = act_list[mol1_idx] - - # Select reaction - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() - - if act != 2: - reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] - - if reaction_mask is None: - if len(state) == 1: - act = 3 - nlls[mol1_idx] += -np.log(action_proba * reaction_mask)[act] # correct the NLL - act_list[mol1_idx] = act - rxn_list.append(None) - rxn_id_list.append(None) - mol2_list.append(None) - continue - else: - act_list[mol1_idx] = act - rxn_list.append(None) - rxn_id_list.append(None) - mol2_list.append(None) - continue - - rxn_id = np.argmax(reaction_proba * reaction_mask) - rxn = reaction_templates[rxn_id] - rxn_nll = -np.log(reaction_proba * reaction_mask)[rxn_id] - - rxn_list.append(rxn) - rxn_id_list.append(rxn_id) - nlls[mol1_idx] += rxn_nll - - if np.isinf(rxn_nll): - mol2_list.append(None) - continue - elif rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: - # Add or Expand - if rxn_template == 'hb': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 91)], axis=1))) - elif rxn_template == 'pis': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 4700)], axis=1))) - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] - temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) - dist, ind = nn_search(z_mol2, _tree=available_tree, _k=min(len(temp_emb), beam_width)) - try: - mol2_probas = softmax(-dist) - mol2_nll = -np.log(mol2_probas)[0] - except: - mol2_nll = 0.0 - mol2 = building_blocks[available[ind[0]]] - nlls[mol1_idx] += mol2_nll - else: - mol2 = None - - mol2_list.append(mol2) - - # Run reaction until get a valid (non-None) product - for i in range(0, len(nlls)): - best_idx = np.argsort(nlls)[i] - rxn = rxn_list[best_idx] - rxn_id = rxn_id_list[best_idx] - mol2 = mol2_list[best_idx] - act = act_list[best_idx] - try: - mol_product = rxn.run_reaction([mol1, mol2]) - except: - mol_product = None - else: - if mol_product is None: - continue - else: - break - - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(tree.get_state()) == 1: - act = 3 - break - else: - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act - -def set_embedding_fullbeam(z_target, state, _mol_embedding, nbits): - """ - Computes embeddings for all molecules in input state. - - Args: - z_target (np.ndarray): Embedding for the target molecule. - state (list): Contains molecules in the current state, if not the - initial state. - _mol_embedding (Callable): Function to use for computing the embeddings - of the first and second molecules in the state (e.g. Morgan fingerprint). - nbits (int): Number of bits to use for the embedding. - - Returns: - np.ndarray: Embedding consisting of the concatenation of the target - molecule with the current molecules (if available) in the input - state. - """ - if len(state) == 0: - z_target = np.expand_dims(z_target, axis=0) - return np.concatenate([np.zeros((1, 2 * nbits)), z_target], axis=1) - else: - e1 = _mol_embedding(state[0]) - e1 = np.expand_dims(e1, axis=0) - if len(state) == 1: - e2 = np.zeros((1, nbits)) - else: - e2 = _mol_embedding(state[1]) - e2 = np.expand_dims(e2, axis=0) - z_target = np.expand_dims(z_target, axis=0) - return np.concatenate([e1, e2, z_target], axis=1) - -def synthetic_tree_decoder_fullbeam(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - beam_width, - rxn_template, - n_bits, - max_step=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a beam search. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - beam_width (int): The beam width to use for Reactant 1 search. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - max_step (int, optional): Maximum number of steps to include in the synthetic tree - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - # Initialization - tree = SyntheticTree() - mol_recent = None - kdtree = KDTree(bb_emb, metric='euclidean') - - # Start iteration - # try: - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - z_state = set_embedding_fullbeam(z_target, state, mol_fp, nbits=n_bits) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) - action_proba = action_proba.squeeze().detach().numpy() - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - z_mol1 = reactant1_net(torch.Tensor(np.concatenate([z_state, one_hot_encoder(act, 4)], axis=1))) - z_mol1 = z_mol1.detach().numpy() - - # Select first molecule - if act == 3: - # End - mol1_nlls = [0.0] - break - elif act == 0: - # Add - # **don't try to sample more points than there are in the tree - # beam search for mol1 candidates - dist, ind = nn_search(z_mol1, _tree=kdtree, _k=min(len(bb_emb), beam_width)) - try: - mol1_probas = softmax(- 0.1 * dist) - mol1_nlls = -np.log(mol1_probas) - except: # exception for beam search of length 1 - mol1_nlls = [-np.log(0.5)] - mol1_list = [building_blocks[idx] for idx in ind] - else: - # Expand or Merge - mol1_list = [mol_recent] - mol1_nlls = [-np.log(0.5)] - - action_tuples = [] # list of action tuples created by beam search - act_list = [act] * beam_width - for mol1_idx, mol1 in enumerate(mol1_list): - - z_mol1 = mol_fp(mol1, nBits=n_bits) - act = act_list[mol1_idx] - - # Select reaction - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() - - if act != 2: - reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] - - if reaction_mask is None: - if len(state) == 1: - act = 3 - mol1_nlls[mol1_idx] += -np.log(action_proba * reaction_mask)[act] # correct the NLL - act_list[mol1_idx] = act - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([mol1_nlls[mol1_idx], act, mol1, None, None, None]) - continue - else: - act_list[mol1_idx] = act - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([mol1_nlls[mol1_idx], act, mol1, None, None, None]) - continue - - rxn_ids = np.argsort(-reaction_proba * reaction_mask)[:beam_width] - rxn_nlls = mol1_nlls[mol1_idx] - np.log(reaction_proba * reaction_mask) - - for rxn_id in rxn_ids: - rxn = reaction_templates[rxn_id] - rxn_nll = rxn_nlls[rxn_id] - - if np.isinf(rxn_nll): - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([rxn_nll, act, mol1, rxn, rxn_id, None]) - continue - elif rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([rxn_nll, act, mol1, rxn, rxn_id, mol2]) - else: - # Add or Expand - if rxn_template == 'hb': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 91)], axis=1))) - elif rxn_template == 'pis': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 4700)], axis=1))) - - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] - temp_emb = bb_emb[available] - available_tree = KDTree(temp_emb, metric='euclidean') - dist, ind = nn_search(z_mol2, _tree=available_tree, _k=min(len(temp_emb), beam_width)) - try: - mol2_probas = softmax(-dist) - mol2_nlls = rxn_nll - np.log(mol2_probas) - except: - mol2_nlls = [rxn_nll + 0.0] - mol2_list = [building_blocks[available[idc]] for idc in ind] - for mol2_idx, mol2 in enumerate(mol2_list): - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([mol2_nlls[mol2_idx], act, mol1, rxn, rxn_id, mol2]) - else: - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([rxn_nll, act, mol1, rxn, rxn_id, None]) - - # Run reaction until get a valid (non-None) product - for i in range(0, len(action_tuples)): - nlls = list(zip(*action_tuples))[0] - best_idx = np.argsort(nlls)[i] - act = action_tuples[best_idx][1] - mol1 = action_tuples[best_idx][2] - rxn = action_tuples[best_idx][3] - rxn_id = action_tuples[best_idx][4] - mol2 = action_tuples[best_idx][5] - try: - mol_product = rxn.run_reaction([mol1, mol2]) - except: - mol_product = None - else: - if mol_product is None: - continue - else: - break - - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(tree.get_state()) == 1: - act = 3 - break - else: - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act From 1efbae72f7d10a43cdeaf706427426053dcfeb29 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:13:40 -0400 Subject: [PATCH 235/302] rename fct --- scripts/predict_multireactant_mp.py | 4 ++-- src/syn_net/utils/predict_utils.py | 2 +- tests/test_Predict.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index 6afee6a0..b99c8d27 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -23,7 +23,7 @@ from syn_net.data_generation.preprocessing import BuildingBlockFileHandler from syn_net.models.chkpt_loader import load_modules_from_checkpoint from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet -from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_beam_search +from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search Path(DATA_RESULT_DIR).mkdir(exist_ok=True) from syn_net.MolEmbedder import MolEmbedder @@ -99,7 +99,7 @@ def func(smiles: str) -> Tuple[str, float, SyntheticTree]: """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) try: - smi, similarity, tree, action = synthetic_tree_decoder_beam_search( + smi, similarity, tree, action = synthetic_tree_decoder_greedy_search( z_target=emb, building_blocks=building_blocks, bb_dict=building_blocks_dict, diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 3fa6d1ce..8437b86c 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -337,7 +337,7 @@ def synthetic_tree_decoder( return tree, act -def synthetic_tree_decoder_beam_search( +def synthetic_tree_decoder_greedy_search( beam_width: int = 3, **kwargs ) -> Tuple[str, float, SyntheticTree, int]: """ diff --git a/tests/test_Predict.py b/tests/test_Predict.py index 4a0c2f0f..83854e66 100644 --- a/tests/test_Predict.py +++ b/tests/test_Predict.py @@ -8,7 +8,7 @@ import pandas as pd from syn_net.utils.predict_utils import ( - synthetic_tree_decoder_beam_search, + synthetic_tree_decoder_greedy_search, mol_fp, ) from syn_net.utils.data_utils import SyntheticTreeSet, ReactionSet @@ -83,7 +83,7 @@ def test_predict(self): trees = [] for smi in smis_query: emb = mol_fp(smi) - smi, similarity, tree, action = synthetic_tree_decoder_beam_search( + smi, similarity, tree, action = synthetic_tree_decoder_greedy_search( z_target=emb, building_blocks=building_blocks, bb_dict=bb_dict, From bcf18306c95084020ef1cf0aa001999fbfc9baad Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:13:47 -0400 Subject: [PATCH 236/302] add type hints --- src/syn_net/encoding/fingerprints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/encoding/fingerprints.py b/src/syn_net/encoding/fingerprints.py index ca60c4fb..0a437fb1 100644 --- a/src/syn_net/encoding/fingerprints.py +++ b/src/syn_net/encoding/fingerprints.py @@ -4,7 +4,7 @@ ## Morgan fingerprints -def mol_fp(smi, _radius=2, _nBits=4096): +def mol_fp(smi, _radius=2, _nBits=4096) -> np.ndarray: # dtype=int64 """ Computes the Morgan fingerprint for the input SMILES. @@ -27,7 +27,7 @@ def mol_fp(smi, _radius=2, _nBits=4096): ) # TODO: much slower compared to `DataStructs.ConvertToNumpyArray` (20x?) so deprecates -def fp_embedding(smi, _radius=2, _nBits=4096): +def fp_embedding(smi, _radius=2, _nBits=4096) -> list[float]: """ General function for building variable-size & -radius Morgan fingerprints. From 69b45f012f40e9c07ce0150631fb33758c97804e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:14:05 -0400 Subject: [PATCH 237/302] fix filepaths --- scripts/predict_multireactant_mp.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py index b99c8d27..43ad730a 100644 --- a/scripts/predict_multireactant_mp.py +++ b/scripts/predict_multireactant_mp.py @@ -1,6 +1,6 @@ """ Generate synthetic trees for a set of specified query molecules. Multiprocessing. -""" # TODO: Clean up +""" # TODO: Clean up + dont hardcode file paths import json import logging import multiprocessing as mp @@ -16,7 +16,6 @@ from syn_net.config import ( CHECKPOINTS_DIR, DATA_EMBEDDINGS_DIR, - DATA_PREPARED_DIR, DATA_PREPROCESS_DIR, DATA_RESULT_DIR, ) @@ -44,7 +43,7 @@ def _fetch_data_from_file(name: str) -> list[str]: def _fetch_data(name: str) -> list[str]: if args.data in ["train", "valid", "test"]: - file = Path(DATA_PREPARED_DIR) / f"synthetic-trees-{args.data}.json.gz" + file = Path(DATA_PREPROCESS_DIR) / "syntrees" / f"synthetic-trees-filtered-{args.data}.json.gz" logger.info(f"Reading data from {file}") sts = SyntheticTreeSet() sts.load(file) @@ -167,7 +166,6 @@ def get_args(): nbits = args.nbits out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits - building_blocks_id = "enamine_us-2021-smiles" param_dir = f"{args.rxn_template}_{args.featurize}_{args.radius}_{nbits}_{out_dim}" # Load data ... @@ -178,7 +176,7 @@ def get_args(): smiles_queries = smiles_queries[: args.num] # ... building blocks - file = Path(DATA_PREPROCESS_DIR) / f"{args.rxn_template}-{building_blocks_id}-matched.csv.gz" + file = Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / f"enamine-us-smiles.csv.gz" # TODO: Do not hardcode building_blocks = BuildingBlockFileHandler().load(file) building_blocks_dict = { block: i for i, block in enumerate(building_blocks) @@ -186,15 +184,12 @@ def get_args(): logger.info("...loading building blocks completed.") # ... reaction templates - file = ( - Path(DATA_PREPROCESS_DIR) - / f"reaction-sets_{args.rxn_template}_{building_blocks_id}.json.gz" - ) + file = (Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / "hb-enamine-us.json.gz") # TODO: Do not hardcode rxns = ReactionSet().load(file).rxns logger.info("...loading reaction collection completed.") # ... building block embedding - file = Path(DATA_EMBEDDINGS_DIR) / f"{args.rxn_template}-{building_blocks_id}-embeddings.npy" + file = Path(DATA_PREPROCESS_DIR) / "embeddings" / f"hb-enamine-embeddings.npy" # TODO: Do not hardcode bblocks_molembedder = MolEmbedder().load_precomputed(file).init_balltree(cosine_distance) bb_emb = bblocks_molembedder.get_embeddings() @@ -205,7 +200,7 @@ def get_args(): logger.info("Start loading models from checkpoints...") path = Path(CHECKPOINTS_DIR) / f"{param_dir}" paths = [ - find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) + find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) # TODO: Do not hardcode for model in "act rt1 rxn rt2".split() ] act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) From b31f069921b75fc580cc7bcb74c30c90914989f3 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:20:47 -0400 Subject: [PATCH 238/302] move code from `_mp_search_similar.py` to `search_similar.py` --- scripts/_mp_search_similar.py | 34 ---------------------------------- scripts/search_similar.py | 27 ++++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 35 deletions(-) delete mode 100644 scripts/_mp_search_similar.py diff --git a/scripts/_mp_search_similar.py b/scripts/_mp_search_similar.py deleted file mode 100644 index 41baaa74..00000000 --- a/scripts/_mp_search_similar.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -This function is used to identify the most similar molecule in the training set -to a given molecular fingerprint. -""" -import numpy as np -from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit import DataStructs -import pandas as pd -from syn_net.utils.data_utils import * - - -data_path = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' -st_set = SyntheticTreeSet() -st_set.load(data_path) -data = st_set.sts -data_train = [t.root.smiles for t in data] -fps_train = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_train] - - -def func(fp): - """ - Finds the most similar molecule in the training set to the input molecule - using the Tanimoto similarity. - - Args: - fp (np.ndarray): Morgan fingerprint to find similars to in the training set. - - Returns: - np.float: The maximum similarity found to the training set fingerprints. - np.ndarray: Fingerprint of the most similar training set molecule. - """ - dists = np.array([DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) - return dists.max(), dists.argmax() diff --git a/scripts/search_similar.py b/scripts/search_similar.py index ec0f2c02..4e991e03 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -8,7 +8,32 @@ from rdkit import Chem from rdkit.Chem import AllChem import multiprocessing as mp -from scripts._mp_search_similar import func + + +from rdkit import DataStructs + +data_path = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' +st_set = SyntheticTreeSet() +st_set.load(data_path) +data = st_set.sts +data_train = [t.root.smiles for t in data] +fps_train = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_train] + + +def func(fp): + """ + Finds the most similar molecule in the training set to the input molecule + using the Tanimoto similarity. + + Args: + fp (np.ndarray): Morgan fingerprint to find similars to in the training set. + + Returns: + np.float: The maximum similarity found to the training set fingerprints. + np.ndarray: Fingerprint of the most similar training set molecule. + """ + dists = np.array([DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) + return dists.max(), dists.argmax() if __name__ == '__main__': From 48cc53cbd3a9dc246f8f85367460f5bdf796f306 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:23:43 -0400 Subject: [PATCH 239/302] consolidate duplicate code --- scripts/search_similar.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index 4e991e03..eeaf5c2c 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -8,18 +8,8 @@ from rdkit import Chem from rdkit.Chem import AllChem import multiprocessing as mp - - from rdkit import DataStructs -data_path = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' -st_set = SyntheticTreeSet() -st_set.load(data_path) -data = st_set.sts -data_train = [t.root.smiles for t in data] -fps_train = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_train] - - def func(fp): """ Finds the most similar molecule in the training set to the input molecule @@ -40,23 +30,18 @@ def func(fp): ncpu = 64 - data_path = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - data_train = [t.root.smiles for t in data] + file = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' + syntree_collection = SyntheticTreeSet().load(file) + data_train = [st.root.smiles for st in syntree_collection] + fps_train = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_train] - data_path = '/pool001/whgao/data/synth_net/st_hb/st_test.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - data_test = [t.root.smiles for t in data] + file = '/pool001/whgao/data/synth_net/st_hb/st_test.json.gz' + syntree_collection = SyntheticTreeSet().load(file) + data_test = [st.root.smiles for st in syntree_collection] - data_path = '/pool001/whgao/data/synth_net/st_hb/st_valid.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - data_valid = [t.root.smiles for t in data] + file = '/pool001/whgao/data/synth_net/st_hb/st_valid.json.gz' + syntree_collection = SyntheticTreeSet().load(file) + data_valid = [st.root.smiles for st in syntree_collection] fps_valid = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_valid] fps_test = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_test] From f753a11091b8911aca85d179f11531e3217c7ec5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:41:02 -0400 Subject: [PATCH 240/302] refactor --- scripts/search_similar.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index eeaf5c2c..9bc9d2d9 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -25,6 +25,8 @@ def func(fp): dists = np.array([DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) return dists.max(), dists.argmax() +def _compute_fp_bitvector(smiles: list[str], radius: int=2, nbits: int=1024): + return [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) for smi in smiles] if __name__ == '__main__': @@ -33,27 +35,28 @@ def func(fp): file = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' syntree_collection = SyntheticTreeSet().load(file) data_train = [st.root.smiles for st in syntree_collection] - fps_train = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_train] + fps_train = _compute_fp_bitvector(data_train) file = '/pool001/whgao/data/synth_net/st_hb/st_test.json.gz' syntree_collection = SyntheticTreeSet().load(file) data_test = [st.root.smiles for st in syntree_collection] + fps_test = _compute_fp_bitvector(data_test) file = '/pool001/whgao/data/synth_net/st_hb/st_valid.json.gz' syntree_collection = SyntheticTreeSet().load(file) data_valid = [st.root.smiles for st in syntree_collection] - - fps_valid = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_valid] - fps_test = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_test] + fps_valid = _compute_fp_bitvector(data_valid) with mp.Pool(processes=ncpu) as pool: results = pool.map(func, fps_valid) + similaritys = [r[0] for r in results] indices = [data_train[r[1]] for r in results] df1 = pd.DataFrame({'smiles': data_valid, 'split': 'valid', 'most similar': indices, 'similarity': similaritys}) with mp.Pool(processes=ncpu) as pool: results = pool.map(func, fps_test) + similaritys = [r[0] for r in results] indices = [data_train[r[1]] for r in results] df2 = pd.DataFrame({'smiles': data_test, 'split': 'test', 'most similar': indices, 'similarity': similaritys}) From 9c1e725015913aae2ada767aba255cc04cc0fe2a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:47:49 -0400 Subject: [PATCH 241/302] shorten function --- scripts/search_similar.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index 9bc9d2d9..6d4b908d 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -10,20 +10,14 @@ import multiprocessing as mp from rdkit import DataStructs -def func(fp): +def func(fp: np.ndarray, fps_reference: np.ndarray): + """Finds most similar fingerprint in a reference set for `fp`. + Uses Tanimoto Similarity. """ - Finds the most similar molecule in the training set to the input molecule - using the Tanimoto similarity. - - Args: - fp (np.ndarray): Morgan fingerprint to find similars to in the training set. - - Returns: - np.float: The maximum similarity found to the training set fingerprints. - np.ndarray: Fingerprint of the most similar training set molecule. - """ - dists = np.array([DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) - return dists.max(), dists.argmax() + dists = np.array( + [DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) + similarity_score, idx = dists.max(), dists.argmax() + return similarity_score, idx def _compute_fp_bitvector(smiles: list[str], radius: int=2, nbits: int=1024): return [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) for smi in smiles] From 58133252eb586cf021f2b36b3e424b688d128d60 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:52:53 -0400 Subject: [PATCH 242/302] refactor save fct --- scripts/search_similar.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index 6d4b908d..2ff95ed9 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -22,6 +22,10 @@ def func(fp: np.ndarray, fps_reference: np.ndarray): def _compute_fp_bitvector(smiles: list[str], radius: int=2, nbits: int=1024): return [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) for smi in smiles] +def _save_df(file: str, df): + if file is None: return + df.to_csv(file, index=False) + if __name__ == '__main__': ncpu = 64 @@ -55,6 +59,8 @@ def _compute_fp_bitvector(smiles: list[str], radius: int=2, nbits: int=1024): indices = [data_train[r[1]] for r in results] df2 = pd.DataFrame({'smiles': data_test, 'split': 'test', 'most similar': indices, 'similarity': similaritys}) - df = pd.concat([df1, df2], axis=0, ignore_index=True) - df.to_csv('data_similarity.csv', index=False) + outfile = 'data_similarity.csv' + _save_df(outfile, pd.concat([df1, df2], axis=0, ignore_index=True)) + + print('Finish!') From c503234acefea6c03141861d19c800ed830dc25d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 18:55:53 -0400 Subject: [PATCH 243/302] adds parser --- scripts/search_similar.py | 47 +++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index 2ff95ed9..ed6325e4 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -9,6 +9,42 @@ from rdkit.Chem import AllChem import multiprocessing as mp from rdkit import DataStructs +import logging +from pathlib import Path + +logger = logging.getLogger(__file__) + +from syn_net.config import MAX_PROCESSES + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-dir", + type=str, + help="Directory with `*{train,valid,test}*.json.gz`-data of synthetic trees", + ) + parser.add_argument( + "--output-file", + type=str, + default=None, + help="Optional: File to save similarity-values for test,valid-synthetic trees.", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +def _match_dataset_filename(path: str, dataset_type: str) -> Path: # TODO: consolidate with code in script/05-* + """Helper to find the exact filename for {train,valid,test} file.""" + files = list(Path(path).glob(f"*{dataset_type}*.json.gz")) + if len(files) != 1: + raise ValueError(f"Can not find unique '{dataset_type} 'file, got {files}") + return files[0] def func(fp: np.ndarray, fps_reference: np.ndarray): """Finds most similar fingerprint in a reference set for `fp`. @@ -26,9 +62,12 @@ def _save_df(file: str, df): if file is None: return df.to_csv(file, index=False) -if __name__ == '__main__': +if __name__ == "__main__": + logger.info("Start.") - ncpu = 64 + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") file = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' syntree_collection = SyntheticTreeSet().load(file) @@ -45,14 +84,14 @@ def _save_df(file: str, df): data_valid = [st.root.smiles for st in syntree_collection] fps_valid = _compute_fp_bitvector(data_valid) - with mp.Pool(processes=ncpu) as pool: + with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, fps_valid) similaritys = [r[0] for r in results] indices = [data_train[r[1]] for r in results] df1 = pd.DataFrame({'smiles': data_valid, 'split': 'valid', 'most similar': indices, 'similarity': similaritys}) - with mp.Pool(processes=ncpu) as pool: + with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, fps_test) similaritys = [r[0] for r in results] From 4e36e056ffdb89d644becbc3023417153caf9b9d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 19:26:12 -0400 Subject: [PATCH 244/302] move to fcts --- scripts/search_similar.py | 70 +++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 32 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index ed6325e4..d291a941 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -1,10 +1,13 @@ """ Computes the fingerprint similarity of molecules in the validation and test set to molecules in the training set. -""" +""" # TODO: clean up, un-nest a couple of fcts +from functools import partial +import json +from typing import Tuple import numpy as np import pandas as pd -from syn_net.utils.data_utils import * +from syn_net.utils.data_utils import SyntheticTreeSet from rdkit import Chem from rdkit.Chem import AllChem import multiprocessing as mp @@ -46,7 +49,7 @@ def _match_dataset_filename(path: str, dataset_type: str) -> Path: # TODO: conso raise ValueError(f"Can not find unique '{dataset_type} 'file, got {files}") return files[0] -def func(fp: np.ndarray, fps_reference: np.ndarray): +def find_similar_fp(fp: np.ndarray, fps_reference: np.ndarray): """Finds most similar fingerprint in a reference set for `fp`. Uses Tanimoto Similarity. """ @@ -58,48 +61,51 @@ def func(fp: np.ndarray, fps_reference: np.ndarray): def _compute_fp_bitvector(smiles: list[str], radius: int=2, nbits: int=1024): return [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) for smi in smiles] +def get_smiles_and_fps(dataset: str) -> Tuple[list[str],list[np.ndarray]]: + file = _match_dataset_filename(args.input_dir,dataset) + syntree_collection = SyntheticTreeSet().load(file) + smiles = [st.root.smiles for st in syntree_collection] + fps = _compute_fp_bitvector(smiles) + return smiles, fps + def _save_df(file: str, df): if file is None: return df.to_csv(file, index=False) +def compute_most_similar_smiles(split: str, fps: np.ndarray, smiles: list[str]) -> pd.DataFrame: + with mp.Pool(processes=args.ncpu) as pool: + results = pool.map(func, fps) + + similarities, idx = np.asfarray(results).T + most_similiar_ref_smiles = np.asarray(smiles_train)[idx.astype(int)] # use numpy for slicin' + + df = pd.DataFrame( + {'smiles': smiles, + 'split': split, + 'most similar': most_similiar_ref_smiles, 'similarity': similarities}) + return df + if __name__ == "__main__": logger.info("Start.") # Parse input args args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + args.input_dir = "/home/ulmer/SynNet/data/pre-process/syntrees" - file = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' - syntree_collection = SyntheticTreeSet().load(file) - data_train = [st.root.smiles for st in syntree_collection] - fps_train = _compute_fp_bitvector(data_train) - - file = '/pool001/whgao/data/synth_net/st_hb/st_test.json.gz' - syntree_collection = SyntheticTreeSet().load(file) - data_test = [st.root.smiles for st in syntree_collection] - fps_test = _compute_fp_bitvector(data_test) - - file = '/pool001/whgao/data/synth_net/st_hb/st_valid.json.gz' - syntree_collection = SyntheticTreeSet().load(file) - data_valid = [st.root.smiles for st in syntree_collection] - fps_valid = _compute_fp_bitvector(data_valid) - - with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(func, fps_valid) - - similaritys = [r[0] for r in results] - indices = [data_train[r[1]] for r in results] - df1 = pd.DataFrame({'smiles': data_valid, 'split': 'valid', 'most similar': indices, 'similarity': similaritys}) - - with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(func, fps_test) + # Load data + smiles_train, fps_train = get_smiles_and_fps("train") + smiles_valid, fps_valid = get_smiles_and_fps("valid") + smiles_test, fps_test = get_smiles_and_fps("test") - similaritys = [r[0] for r in results] - indices = [data_train[r[1]] for r in results] - df2 = pd.DataFrame({'smiles': data_test, 'split': 'test', 'most similar': indices, 'similarity': similaritys}) + # Compute (mp) + func = partial(find_similar_fp,fps_reference=fps_train) + df_valid = compute_most_similar_smiles("valid",fps_valid,smiles_valid) + df_test = compute_most_similar_smiles("test",fps_test,smiles_test) + # Save outfile = 'data_similarity.csv' - _save_df(outfile, pd.concat([df1, df2], axis=0, ignore_index=True)) + _save_df(outfile, pd.concat([df_valid, df_test], axis=0, ignore_index=True)) - print('Finish!') + logger.info("Completed.") From 2fa0f9f2688f06638dc52c1a15755f379516bd62 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 19:26:31 -0400 Subject: [PATCH 245/302] format --- scripts/search_similar.py | 70 +++++++++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/scripts/search_similar.py b/scripts/search_similar.py index d291a941..f661397a 100644 --- a/scripts/search_similar.py +++ b/scripts/search_similar.py @@ -1,19 +1,20 @@ """ Computes the fingerprint similarity of molecules in the validation and test set to molecules in the training set. -""" # TODO: clean up, un-nest a couple of fcts -from functools import partial +""" # TODO: clean up, un-nest a couple of fcts import json +import logging +import multiprocessing as mp +from functools import partial +from pathlib import Path from typing import Tuple + import numpy as np import pandas as pd -from syn_net.utils.data_utils import SyntheticTreeSet -from rdkit import Chem +from rdkit import Chem, DataStructs from rdkit.Chem import AllChem -import multiprocessing as mp -from rdkit import DataStructs -import logging -from pathlib import Path + +from syn_net.utils.data_utils import SyntheticTreeSet logger = logging.getLogger(__file__) @@ -42,49 +43,69 @@ def get_args(): return parser.parse_args() -def _match_dataset_filename(path: str, dataset_type: str) -> Path: # TODO: consolidate with code in script/05-* +def _match_dataset_filename( + path: str, dataset_type: str +) -> Path: # TODO: consolidate with code in script/05-* """Helper to find the exact filename for {train,valid,test} file.""" files = list(Path(path).glob(f"*{dataset_type}*.json.gz")) if len(files) != 1: raise ValueError(f"Can not find unique '{dataset_type} 'file, got {files}") return files[0] + def find_similar_fp(fp: np.ndarray, fps_reference: np.ndarray): """Finds most similar fingerprint in a reference set for `fp`. Uses Tanimoto Similarity. """ dists = np.array( - [DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) + [ + DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) + for fp_ in fps_train + ] + ) similarity_score, idx = dists.max(), dists.argmax() return similarity_score, idx -def _compute_fp_bitvector(smiles: list[str], radius: int=2, nbits: int=1024): - return [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) for smi in smiles] -def get_smiles_and_fps(dataset: str) -> Tuple[list[str],list[np.ndarray]]: - file = _match_dataset_filename(args.input_dir,dataset) +def _compute_fp_bitvector(smiles: list[str], radius: int = 2, nbits: int = 1024): + return [ + AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) + for smi in smiles + ] + + +def get_smiles_and_fps(dataset: str) -> Tuple[list[str], list[np.ndarray]]: + file = _match_dataset_filename(args.input_dir, dataset) syntree_collection = SyntheticTreeSet().load(file) smiles = [st.root.smiles for st in syntree_collection] fps = _compute_fp_bitvector(smiles) return smiles, fps + def _save_df(file: str, df): - if file is None: return + if file is None: + return df.to_csv(file, index=False) + def compute_most_similar_smiles(split: str, fps: np.ndarray, smiles: list[str]) -> pd.DataFrame: with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, fps) similarities, idx = np.asfarray(results).T - most_similiar_ref_smiles = np.asarray(smiles_train)[idx.astype(int)] # use numpy for slicin' + most_similiar_ref_smiles = np.asarray(smiles_train)[idx.astype(int)] # use numpy for slicin' df = pd.DataFrame( - {'smiles': smiles, - 'split': split, - 'most similar': most_similiar_ref_smiles, 'similarity': similarities}) + { + "smiles": smiles, + "split": split, + "most similar": most_similiar_ref_smiles, + "similarity": similarities, + } + ) return df + if __name__ == "__main__": logger.info("Start.") @@ -99,13 +120,12 @@ def compute_most_similar_smiles(split: str, fps: np.ndarray, smiles: list[str]) smiles_test, fps_test = get_smiles_and_fps("test") # Compute (mp) - func = partial(find_similar_fp,fps_reference=fps_train) - df_valid = compute_most_similar_smiles("valid",fps_valid,smiles_valid) - df_test = compute_most_similar_smiles("test",fps_test,smiles_test) + func = partial(find_similar_fp, fps_reference=fps_train) + df_valid = compute_most_similar_smiles("valid", fps_valid, smiles_valid) + df_test = compute_most_similar_smiles("test", fps_test, smiles_test) # Save - outfile = 'data_similarity.csv' - _save_df(outfile, pd.concat([df_valid, df_test], axis=0, ignore_index=True)) - + outfile = "data_similarity.csv" + _save_df(outfile, pd.concat([df_valid, df_test], axis=0, ignore_index=True)) logger.info("Completed.") From da25071ca75d7453a4c72e5d2c7dbe197a953dca Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 29 Sep 2022 19:31:05 -0400 Subject: [PATCH 246/302] delete (hopefully) old code --- scripts/_mp_predict.py | 91 ------------------------------------------ scripts/predict_mp.py | 64 ----------------------------- 2 files changed, 155 deletions(-) delete mode 100644 scripts/_mp_predict.py delete mode 100644 scripts/predict_mp.py diff --git a/scripts/_mp_predict.py b/scripts/_mp_predict.py deleted file mode 100644 index 10884b6c..00000000 --- a/scripts/_mp_predict.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -This file contains a function to predict a single synthetic tree given a molecular SMILES. -""" -import pandas as pd -import numpy as np -from syn_net.utils.data_utils import ReactionSet -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity, load_modules_from_checkpoint, mol_fp - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096_256' -ncpu = 32 - -# define model to use for molecular embedding -model_type = 'gin_supervised_contextpred' -device = 'cpu' -mol_embedder = load_pretrained(model_type).to(device) -mol_embedder.eval() - -# load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = '/pool001/whgao/data/synth_net/st_' + rxn_template + '/reactions_' + rxn_template + '.json.gz' -path_to_building_blocks = '/pool001/whgao/data/synth_net/st_' + rxn_template + '/enamine_us_matched.csv.gz' - -# define paths to pretrained modules -param_path = '/home/whgao/scGen/synth_net/synth_net/params/' + param_dir + '/' -path_to_act = param_path + 'act.ckpt' -path_to_rt1 = param_path + 'rt1.ckpt' -path_to_rxn = param_path + 'rxn.ckpt' -path_to_rt2 = param_path + 'rt2.ckpt' - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet().load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(smi): - """ - Generates the synthetic tree for the input SMILES. - - Args: - smi (str): Molecular to reconstruct. - - Returns: - str: Final product SMILES. - float: Score of the best final product. - SyntheticTree: The generated synthetic tree. - """ - emb = mol_fp(smi) - try: - tree, action = synthetic_tree_decoder(emb, building_blocks, bb_dict, rxns, mol_embedder, act_net, rt1_net, rxn_net, rt2_net, bb_emb, rxn_template=rxn_template, n_bits=nbits, max_step=15) - except Exception as e: - print(e) - action = -1 - - # tree, action = synthetic_tree_decoder(emb, building_blocks, bb_dict, rxns, mol_embedder, act_net, rt1_net, rxn_net, rt2_net, max_step=15) - - # import ipdb; ipdb.set_trace(context=9) - # tree._print() - # print(action) - # print(np.max(oracle(tree.get_state()))) - # print() - - if action != 3: - return None, 0, None - else: - scores = tanimoto_similarity(emb, tree.get_state()) - max_score_idx = np.where(scores == np.max(scores))[0][0] - return tree.get_state()[max_score_idx], np.max(scores), tree diff --git a/scripts/predict_mp.py b/scripts/predict_mp.py deleted file mode 100644 index 649c78a5..00000000 --- a/scripts/predict_mp.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Generate synthetic trees for a set of specified query molecules. Multiprocessing. -""" -import multiprocessing as mp -import numpy as np -import pandas as pd -import scripts._mp_predict as predict -from syn_net.utils.data_utils import SyntheticTreeSet -from pathlib import Path - -from syn_net.config import DATA_RESULT_DIR - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to predict.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # load the query molecules (i.e. molecules to decode) - path_to_data = '/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/st_' + args.data +'.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet().load(path_to_data) - smis_query = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - smis_query = smis_query[:args.num] - - print('Start to decode!') - with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(predict.func, smis_query) - - smis_decoded = [r[0] for r in results] - similaritys = [r[1] for r in results] - trees = [r[2] for r in results] - - print("Finish decoding") - print(f"Recovery rate {args.data}: {np.sum(np.array(similaritys) == 1.0) / len(similaritys)}") - print(f"Average similarity {args.data}: {np.mean(np.array(similaritys))}") - - print('Saving ......') - out_dir = Path(DATA_RESULT_DIR) / f"{args.rxn_template}_{args.featurize}" - out_dir.mkdir(exist_ok=1,parent=1) - df = pd.DataFrame({'query SMILES': smis_query, 'decode SMILES': smis_decoded, 'similarity': similaritys}) - file = out_dir / f'decode_result_{args.data}.csv.gz' - df.to_csv(file, compression='gzip', index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - file = out_dir / f'decoded_st_{args.data}.json.gz' - synthetic_tree_set.save(file) - - print('Finish!') - - From 6c572f1a37253ba4870fb5558837d70decc7c91d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 11:54:26 -0400 Subject: [PATCH 247/302] rename fcts --- scripts/{predict_multireactant_mp.py => 20-predict-targets.py} | 0 scripts/{search_similar.py => 21-identify-similar-fps.py} | 0 scripts/{mrr.py => 22-compute-mrr.py} | 0 .../{evaluate_batch_recovery.py => 23-evaluate-predictions.py} | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename scripts/{predict_multireactant_mp.py => 20-predict-targets.py} (100%) rename scripts/{search_similar.py => 21-identify-similar-fps.py} (100%) rename scripts/{mrr.py => 22-compute-mrr.py} (100%) rename scripts/{evaluate_batch_recovery.py => 23-evaluate-predictions.py} (100%) diff --git a/scripts/predict_multireactant_mp.py b/scripts/20-predict-targets.py similarity index 100% rename from scripts/predict_multireactant_mp.py rename to scripts/20-predict-targets.py diff --git a/scripts/search_similar.py b/scripts/21-identify-similar-fps.py similarity index 100% rename from scripts/search_similar.py rename to scripts/21-identify-similar-fps.py diff --git a/scripts/mrr.py b/scripts/22-compute-mrr.py similarity index 100% rename from scripts/mrr.py rename to scripts/22-compute-mrr.py diff --git a/scripts/evaluate_batch_recovery.py b/scripts/23-evaluate-predictions.py similarity index 100% rename from scripts/evaluate_batch_recovery.py rename to scripts/23-evaluate-predictions.py From 98c04efa9f58212abe4fe07ae45388f030d06baf Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 12:51:56 -0400 Subject: [PATCH 248/302] use argparse for files, clean up --- scripts/22-compute-mrr.py | 63 +++++++++++---------------------------- 1 file changed, 17 insertions(+), 46 deletions(-) diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py index c77fd4a1..79333608 100644 --- a/scripts/22-compute-mrr.py +++ b/scripts/22-compute-mrr.py @@ -1,5 +1,5 @@ """ -This function is used to compute the mean reciprocal ranking for reactant 1 +This function is used to compute the mean reciprocal ranking for reactant 1 selection using the different distance metrics in the k-NN search. """ from syn_net.models.mlp import MLP, load_array @@ -7,25 +7,18 @@ import numpy as np from sklearn.neighbors import BallTree import torch -from syn_net.utils.predict_utils import cosine_distance, ce_distance - +from syn_net.encoding.distances import cosine_distance, ce_distance if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--param_dir", type=str, default='hb_fp_2_4096_256', - help="") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") + parser.add_argument("--ckpt-file", type=str,help="Checkpoint to load trained reactant 1 network.") + parser.add_argument("--embeddings-file", type=str,help="Pre-computed molecular embeddings for kNN search.") + parser.add_argument("--X-data-file", type=str, help="Featurized X data for network.") + parser.add_argument("--y-data-file", type=str, help="Featurized y data for network.") parser.add_argument("--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") parser.add_argument("--ncpu", type=int, default=8, help="Number of cpus") parser.add_argument("--batch_size", type=int, default=64, @@ -33,41 +26,16 @@ parser.add_argument("--device", type=str, default="cuda:0", help="") parser.add_argument("--distance", type=str, default="euclidean", - help="Choose from ['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine']") + choices=['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine'],help="Distance function for `BallTree`.") args = parser.parse_args() - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - main_dir = '/pool001/whgao/data/synth_net/' + args.rxn_template + '_' + args.featurize + '_' + str(args.radius) + '_' + str(args.nbits) + '_' + validation_option[12:] + '/' - path_to_rt1 = '/home/whgao/scGen/synth_net/synth_net/params/' + args.param_dir + '/' + 'rt1.ckpt' + path_to_rt1 = args.ckpt_file batch_size = args.batch_size ncpu = args.ncpu - # X = sparse.load_npz(main_dir + 'X_rt1_train.npz') - # y = sparse.load_npz(main_dir + 'y_rt1_train.npz') - # X = torch.Tensor(X.A) - # y = torch.Tensor(y.A) - # _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/100), replace=False) - # train_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - - # X = sparse.load_npz(main_dir + 'X_rt1_valid.npz') - # y = sparse.load_npz(main_dir + 'y_rt1_valid.npz') - # X = torch.Tensor(X.A) - # y = torch.Tensor(y.A) - # _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - # valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - - X = sparse.load_npz(main_dir + 'X_rt1_test.npz') - y = sparse.load_npz(main_dir + 'y_rt1_test.npz') + X = sparse.load_npz(args.X_data_file) + y = sparse.load_npz(args.y_data_file) X = torch.Tensor(X.A) y = torch.Tensor(y.A) _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) @@ -76,7 +44,7 @@ rt1_net = MLP.load_from_checkpoint(path_to_rt1, input_dim=int(3 * args.nbits), - output_dim=args.out_dim, + output_dim=d, hidden_dim=1200, num_layers=5, dropout=0.5, @@ -90,7 +58,8 @@ rt1_net.eval() rt1_net.to(args.device) - bb_emb_fp_256 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') + bb_emb_fp_256 = np.load(args.embeddings_file) + n, d = bb_emb_fp_256.shape # for kw_metric_ in ['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine']: kw_metric_ = args.distance @@ -109,12 +78,14 @@ X, y = X.to(args.device), y.to(args.device) y_hat = rt1_net(X) dist_true, ind_true = kdtree_fp_256.query(y.detach().cpu().numpy(), k=1) - dist, ind = kdtree_fp_256.query(y_hat.detach().cpu().numpy(), k=bb_emb_fp_256.shape[0]) + dist, ind = kdtree_fp_256.query(y_hat.detach().cpu().numpy(), k=n) ranks = ranks + [np.where(ind[i] == ind_true[i])[0][0] for i in range(len(ind_true))] ranks = np.array(ranks) rrs = 1 / (ranks + 1) - np.save('ranks_' + kw_metric_ + '.npy', ranks) + + np.save('ranks_' + kw_metric_ + '.npy', ranks) # TODO: do not hard code + print(f"Result using metric: {kw_metric_}") print(f"The mean reciprocal ranking is: {rrs.mean():.3f}") print(f"The Top-1 recovery rate is: {sum(ranks < 1) / len(ranks) :.3f}, {sum(ranks < 1)} / {len(ranks)}") From e110c3046d5a686dfbf346b2d6559640bee8990a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 12:54:48 -0400 Subject: [PATCH 249/302] clean up --- scripts/22-compute-mrr.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py index 79333608..2cf5b137 100644 --- a/scripts/22-compute-mrr.py +++ b/scripts/22-compute-mrr.py @@ -9,8 +9,7 @@ import torch from syn_net.encoding.distances import cosine_distance, ce_distance -if __name__ == '__main__': - +def get_args(): import argparse parser = argparse.ArgumentParser() parser.add_argument("--ckpt-file", type=str,help="Checkpoint to load trained reactant 1 network.") @@ -27,8 +26,23 @@ help="") parser.add_argument("--distance", type=str, default="euclidean", choices=['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine'],help="Distance function for `BallTree`.") - args = parser.parse_args() + return parser.parse_args() + +if __name__ == '__main__': + + args = get_args() + + + bb_emb_fp_256 = np.load(args.embeddings_file) + n, d = bb_emb_fp_256.shape + metric = args.distance + if metric == 'cross_entropy': + metric = ce_distance + elif metric == 'cosine': + metric = cosine_distance + + kdtree_fp_256 = BallTree(bb_emb_fp_256, metric=metric) path_to_rt1 = args.ckpt_file batch_size = args.batch_size @@ -58,20 +72,7 @@ rt1_net.eval() rt1_net.to(args.device) - bb_emb_fp_256 = np.load(args.embeddings_file) - n, d = bb_emb_fp_256.shape - - # for kw_metric_ in ['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine']: - kw_metric_ = args.distance - - if kw_metric_ == 'cross_entropy': - kw_metric = ce_distance - elif kw_metric_ == 'cosine': - kw_metric = cosine_distance - else: - kw_metric = kw_metric_ - kdtree_fp_256 = BallTree(bb_emb_fp_256, metric=kw_metric) ranks = [] for X, y in data_iter: @@ -84,9 +85,9 @@ ranks = np.array(ranks) rrs = 1 / (ranks + 1) - np.save('ranks_' + kw_metric_ + '.npy', ranks) # TODO: do not hard code + np.save('ranks_' + metric + '.npy', ranks) # TODO: do not hard code - print(f"Result using metric: {kw_metric_}") + print(f"Result using metric: {metric}") print(f"The mean reciprocal ranking is: {rrs.mean():.3f}") print(f"The Top-1 recovery rate is: {sum(ranks < 1) / len(ranks) :.3f}, {sum(ranks < 1)} / {len(ranks)}") print(f"The Top-3 recovery rate is: {sum(ranks < 3) / len(ranks) :.3f}, {sum(ranks < 3)} / {len(ranks)}") From f291c200b7e7cf970b2650b2124e9090052dcb07 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 12:55:08 -0400 Subject: [PATCH 250/302] format --- scripts/22-compute-mrr.py | 108 +++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 43 deletions(-) diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py index 2cf5b137..be2ab55a 100644 --- a/scripts/22-compute-mrr.py +++ b/scripts/22-compute-mrr.py @@ -2,44 +2,54 @@ This function is used to compute the mean reciprocal ranking for reactant 1 selection using the different distance metrics in the k-NN search. """ -from syn_net.models.mlp import MLP, load_array -from scipy import sparse import numpy as np -from sklearn.neighbors import BallTree import torch -from syn_net.encoding.distances import cosine_distance, ce_distance +from scipy import sparse +from sklearn.neighbors import BallTree + +from syn_net.encoding.distances import ce_distance, cosine_distance +from syn_net.models.mlp import MLP, load_array + def get_args(): import argparse + parser = argparse.ArgumentParser() - parser.add_argument("--ckpt-file", type=str,help="Checkpoint to load trained reactant 1 network.") - parser.add_argument("--embeddings-file", type=str,help="Pre-computed molecular embeddings for kNN search.") + parser.add_argument( + "--ckpt-file", type=str, help="Checkpoint to load trained reactant 1 network." + ) + parser.add_argument( + "--embeddings-file", type=str, help="Pre-computed molecular embeddings for kNN search." + ) parser.add_argument("--X-data-file", type=str, help="Featurized X data for network.") parser.add_argument("--y-data-file", type=str, help="Featurized y data for network.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--device", type=str, default="cuda:0", - help="") - parser.add_argument("--distance", type=str, default="euclidean", - choices=['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine'],help="Distance function for `BallTree`.") + parser.add_argument( + "--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint." + ) + parser.add_argument("--ncpu", type=int, default=8, help="Number of cpus") + parser.add_argument("--batch_size", type=int, default=64, help="Batch size") + parser.add_argument("--device", type=str, default="cuda:0", help="") + parser.add_argument( + "--distance", + type=str, + default="euclidean", + choices=["euclidean", "manhattan", "chebyshev", "cross_entropy", "cosine"], + help="Distance function for `BallTree`.", + ) return parser.parse_args() -if __name__ == '__main__': - args = get_args() +if __name__ == "__main__": + args = get_args() bb_emb_fp_256 = np.load(args.embeddings_file) n, d = bb_emb_fp_256.shape metric = args.distance - if metric == 'cross_entropy': + if metric == "cross_entropy": metric = ce_distance - elif metric == 'cosine': + elif metric == "cosine": metric = cosine_distance kdtree_fp_256 = BallTree(bb_emb_fp_256, metric=metric) @@ -52,28 +62,28 @@ def get_args(): y = sparse.load_npz(args.y_data_file) X = torch.Tensor(X.A) y = torch.Tensor(y.A) - _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) + _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0] / 10), replace=False) test_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) data_iter = test_data_iter - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(3 * args.nbits), - output_dim=d, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) + rt1_net = MLP.load_from_checkpoint( + path_to_rt1, + input_dim=int(3 * args.nbits), + output_dim=d, + hidden_dim=1200, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=1e-4, + ncpu=ncpu, + ) rt1_net.eval() rt1_net.to(args.device) - - ranks = [] for X, y in data_iter: X, y = X.to(args.device), y.to(args.device) @@ -85,14 +95,26 @@ def get_args(): ranks = np.array(ranks) rrs = 1 / (ranks + 1) - np.save('ranks_' + metric + '.npy', ranks) # TODO: do not hard code + np.save("ranks_" + metric + ".npy", ranks) # TODO: do not hard code print(f"Result using metric: {metric}") print(f"The mean reciprocal ranking is: {rrs.mean():.3f}") - print(f"The Top-1 recovery rate is: {sum(ranks < 1) / len(ranks) :.3f}, {sum(ranks < 1)} / {len(ranks)}") - print(f"The Top-3 recovery rate is: {sum(ranks < 3) / len(ranks) :.3f}, {sum(ranks < 3)} / {len(ranks)}") - print(f"The Top-5 recovery rate is: {sum(ranks < 5) / len(ranks) :.3f}, {sum(ranks < 5)} / {len(ranks)}") - print(f"The Top-10 recovery rate is: {sum(ranks < 10) / len(ranks) :.3f}, {sum(ranks < 10)} / {len(ranks)}") - print(f"The Top-15 recovery rate is: {sum(ranks < 15) / len(ranks) :.3f}, {sum(ranks < 15)} / {len(ranks)}") - print(f"The Top-30 recovery rate is: {sum(ranks < 30) / len(ranks) :.3f}, {sum(ranks < 30)} / {len(ranks)}") + print( + f"The Top-1 recovery rate is: {sum(ranks < 1) / len(ranks) :.3f}, {sum(ranks < 1)} / {len(ranks)}" + ) + print( + f"The Top-3 recovery rate is: {sum(ranks < 3) / len(ranks) :.3f}, {sum(ranks < 3)} / {len(ranks)}" + ) + print( + f"The Top-5 recovery rate is: {sum(ranks < 5) / len(ranks) :.3f}, {sum(ranks < 5)} / {len(ranks)}" + ) + print( + f"The Top-10 recovery rate is: {sum(ranks < 10) / len(ranks) :.3f}, {sum(ranks < 10)} / {len(ranks)}" + ) + print( + f"The Top-15 recovery rate is: {sum(ranks < 15) / len(ranks) :.3f}, {sum(ranks < 15)} / {len(ranks)}" + ) + print( + f"The Top-30 recovery rate is: {sum(ranks < 30) / len(ranks) :.3f}, {sum(ranks < 30)} / {len(ranks)}" + ) print() From ea47c2436f51318977753757e6d0920d6098785b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 14:50:57 -0400 Subject: [PATCH 251/302] clean clean clean --- scripts/23-evaluate-predictions.py | 89 +++++++++++++++--------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py index 1eff5127..c4573b4b 100644 --- a/scripts/23-evaluate-predictions.py +++ b/scripts/23-evaluate-predictions.py @@ -1,7 +1,5 @@ -""" -This function evaluates a batch of predictions by computing the (1) novelty, (2) validity, -(3) uniqueness, (4) Fréchet ChemNet distance, and (5) KL divergence for the final -root molecules which correspond to *unrecovered* molecules in all the generated trees. +"""Evaluate a batch of predictions on different metrics. +The predictions are generated in `20-predict-targets.py`. """ from tdc import Evaluator import pandas as pd @@ -14,60 +12,63 @@ validity = Evaluator(name = 'Validity') uniqueness = Evaluator(name = 'Uniqueness') +def get_args(): + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument( + "--input-file", type=str, help="Dataframe with target- and prediction smiles and similarities (*.csv.gz)." + ) + return parser.parse_args() + if __name__ == '__main__': - # load the final root molecules generated by a prediction run using a - # pre-trained model, which were all saved to different files - generated_st_files = glob.glob('../../results_mp/pis_fp/decode_result_train*.csv.gz') - - # lists in which to collect all the results - recovered_molecules = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) # molecules successfully recovered from query - unrecovered_molecules = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) # unsuccessfully recovered - recovered_novelty_all = [] - recovered_validity_decode_all = [] - recovered_uniqueness_decode_all = [] - recovered_fcd_distance_all = [] - recovered_kl_divergence_all = [] - unrecovered_novelty_all = [] - unrecovered_validity_decode_all = [] - unrecovered_uniqueness_decode_all = [] - unrecovered_fcd_distance_all = [] - unrecovered_kl_divergence_all = [] + args = get_args() - similarity = [] + files = [args.file] # TODO: not sure why the loop but let's keep it for now + + # Keep track of successfully and unsuccessfully recovered molecules in 2 df's + recovered = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) + unrecovered = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) + # load each file containing the predictions + similarity = [] n_recovered = 0 n_unrecovered = 0 n_total = 0 + for file in files: + print(f'File currently being evaluated: {file}') - # load each file containing the predictions - for generate_st_file in generated_st_files: - - print(f'File currently being evaluated: {generate_st_file}') - - result_df = pd.read_csv(generate_st_file, compression='gzip') + result_df = pd.read_csv(file) n_total += len(result_df['decode SMILES']) - # get the recovered and unrecovered molecules only (no NaNs) - unrecovered_molecules = pd.concat([unrecovered_molecules, result_df[result_df['similarity'] != 1.0].dropna()]) - recovered_molecules = pd.concat([recovered_molecules, result_df[result_df['similarity'] == 1.0].dropna()]) + # Split smiles, discard NaNs + is_recovered = result_df['similarity'] == 1.0 + unrecovered = pd.concat([unrecovered, result_df[~is_recovered].dropna()]) + recovered = pd.concat([recovered, result_df[is_recovered].dropna()]) - n_recovered += len(recovered_molecules['decode SMILES']) - n_unrecovered += len(unrecovered_molecules['decode SMILES']) - similarity += unrecovered_molecules['similarity'].tolist() + n_recovered += len(recovered) + n_unrecovered += len(unrecovered) + similarity += unrecovered['similarity'].tolist() # compute the following properties, using the TDC, for the succesfully recovered molecules - recovered_novelty_all = novelty(recovered_molecules['query SMILES'].tolist(), recovered_molecules['decode SMILES'].tolist()) - recovered_validity_decode_all = validity(recovered_molecules['decode SMILES'].tolist()) - recovered_uniqueness_decode_all = uniqueness(recovered_molecules['decode SMILES'].tolist()) - recovered_fcd_distance_all = fcd_distance(recovered_molecules['query SMILES'].tolist(), recovered_molecules['decode SMILES'].tolist()) - recovered_kl_divergence_all = kl_divergence(recovered_molecules['query SMILES'].tolist(), recovered_molecules['decode SMILES'].tolist()) + recovered_novelty_all = novelty( + recovered['query SMILES'].tolist(), + recovered['decode SMILES'].tolist(), + ) + recovered_validity_decode_all = validity(recovered['decode SMILES'].tolist()) + recovered_uniqueness_decode_all = uniqueness(recovered['decode SMILES'].tolist()) + recovered_fcd_distance_all = fcd_distance( + recovered['query SMILES'].tolist(), + recovered['decode SMILES'].tolist() + ) + recovered_kl_divergence_all = kl_divergence(recovered['query SMILES'].tolist(), recovered['decode SMILES'].tolist()) # compute the following properties, using the TDC, for the unrecovered molecules - unrecovered_novelty_all = novelty(unrecovered_molecules['query SMILES'].tolist(), unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_validity_decode_all = validity(unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_uniqueness_decode_all = uniqueness(unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_fcd_distance_all = fcd_distance(unrecovered_molecules['query SMILES'].tolist(), unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_kl_divergence_all = kl_divergence(unrecovered_molecules['query SMILES'].tolist(), unrecovered_molecules['decode SMILES'].tolist()) + unrecovered_novelty_all = novelty(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) + unrecovered_validity_decode_all = validity(unrecovered['decode SMILES'].tolist()) + unrecovered_uniqueness_decode_all = uniqueness(unrecovered['decode SMILES'].tolist()) + unrecovered_fcd_distance_all = fcd_distance(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) + unrecovered_kl_divergence_all = kl_divergence(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) print('N recovered, N unrecovered, N total (% recovered):', n_recovered, ',', n_unrecovered, ',', n_total, ', (', 100*n_recovered/n_total, '%)') n_finished = n_recovered + n_unrecovered From 073e40256dd9d275d7d8ff904fa2ba05a213f2e2 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 14:51:29 -0400 Subject: [PATCH 252/302] fix bug in old code --- scripts/23-evaluate-predictions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py index c4573b4b..fa8c8f6a 100644 --- a/scripts/23-evaluate-predictions.py +++ b/scripts/23-evaluate-predictions.py @@ -46,9 +46,9 @@ def get_args(): unrecovered = pd.concat([unrecovered, result_df[~is_recovered].dropna()]) recovered = pd.concat([recovered, result_df[is_recovered].dropna()]) - n_recovered += len(recovered) - n_unrecovered += len(unrecovered) - similarity += unrecovered['similarity'].tolist() + n_recovered += len(recovered) + n_unrecovered += len(unrecovered) + similarity += unrecovered['similarity'].tolist() # compute the following properties, using the TDC, for the succesfully recovered molecules recovered_novelty_all = novelty( From 4b2cc95d0755f282c0dac55ff55bb470fae68567 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:05:16 -0400 Subject: [PATCH 253/302] use fstrings --- scripts/23-evaluate-predictions.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py index fa8c8f6a..0214d666 100644 --- a/scripts/23-evaluate-predictions.py +++ b/scripts/23-evaluate-predictions.py @@ -3,7 +3,6 @@ """ from tdc import Evaluator import pandas as pd -import glob import numpy as np kl_divergence = Evaluator(name = 'KL_Divergence') @@ -70,12 +69,16 @@ def get_args(): unrecovered_fcd_distance_all = fcd_distance(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) unrecovered_kl_divergence_all = kl_divergence(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) - print('N recovered, N unrecovered, N total (% recovered):', n_recovered, ',', n_unrecovered, ',', n_total, ', (', 100*n_recovered/n_total, '%)') + # Print info + print(f'N total {n_total}') + print(f'N recovered {n_recovered} ({n_recovered/n_total:.2f})') + print(f'N unrecovered {n_unrecovered} ({n_recovered/n_total:.2f})') + n_finished = n_recovered + n_unrecovered n_unfinished = n_total - n_finished - print('N finished trees (%):', n_finished, '(', 100*n_finished/n_total,'%)') - print('N unfinished trees (NaN) (%):', n_unfinished, '(', 100*n_unfinished/n_total,'%)') - print('Average similarity (unrecovered only)', np.mean(similarity)) + print(f'N finished tree {n_finished} ({n_finished/n_total:.2f})') + print(f'N unfinished trees (NaN) {n_unfinished} ({n_unfinished/n_total:.2f})') + print(f'Average similarity (unrecovered only) {np.mean(similarity)}') print('Novelty, recovered:', recovered_novelty_all) print('Novelty, unrecovered:', unrecovered_novelty_all) From cbace8f71c7154a11a7b08bf69548629ba1570c7 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:05:37 -0400 Subject: [PATCH 254/302] format --- scripts/23-evaluate-predictions.py | 106 ++++++++++++++++------------- 1 file changed, 59 insertions(+), 47 deletions(-) diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py index 0214d666..9946eed1 100644 --- a/scripts/23-evaluate-predictions.py +++ b/scripts/23-evaluate-predictions.py @@ -1,33 +1,38 @@ """Evaluate a batch of predictions on different metrics. The predictions are generated in `20-predict-targets.py`. """ -from tdc import Evaluator -import pandas as pd import numpy as np +import pandas as pd +from tdc import Evaluator + +kl_divergence = Evaluator(name="KL_Divergence") +fcd_distance = Evaluator(name="FCD_Distance") +novelty = Evaluator(name="Novelty") +validity = Evaluator(name="Validity") +uniqueness = Evaluator(name="Uniqueness") -kl_divergence = Evaluator(name = 'KL_Divergence') -fcd_distance = Evaluator(name = 'FCD_Distance') -novelty = Evaluator(name = 'Novelty') -validity = Evaluator(name = 'Validity') -uniqueness = Evaluator(name = 'Uniqueness') def get_args(): import argparse + parser = argparse.ArgumentParser() parser.add_argument( - "--input-file", type=str, help="Dataframe with target- and prediction smiles and similarities (*.csv.gz)." + "--input-file", + type=str, + help="Dataframe with target- and prediction smiles and similarities (*.csv.gz).", ) return parser.parse_args() -if __name__ == '__main__': + +if __name__ == "__main__": args = get_args() - files = [args.file] # TODO: not sure why the loop but let's keep it for now + files = [args.file] # TODO: not sure why the loop but let's keep it for now # Keep track of successfully and unsuccessfully recovered molecules in 2 df's - recovered = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) - unrecovered = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) + recovered = pd.DataFrame({"query SMILES": [], "decode SMILES": [], "similarity": []}) + unrecovered = pd.DataFrame({"query SMILES": [], "decode SMILES": [], "similarity": []}) # load each file containing the predictions similarity = [] @@ -35,62 +40,69 @@ def get_args(): n_unrecovered = 0 n_total = 0 for file in files: - print(f'File currently being evaluated: {file}') + print(f"File currently being evaluated: {file}") result_df = pd.read_csv(file) - n_total += len(result_df['decode SMILES']) + n_total += len(result_df["decode SMILES"]) # Split smiles, discard NaNs - is_recovered = result_df['similarity'] == 1.0 + is_recovered = result_df["similarity"] == 1.0 unrecovered = pd.concat([unrecovered, result_df[~is_recovered].dropna()]) - recovered = pd.concat([recovered, result_df[is_recovered].dropna()]) + recovered = pd.concat([recovered, result_df[is_recovered].dropna()]) n_recovered += len(recovered) n_unrecovered += len(unrecovered) - similarity += unrecovered['similarity'].tolist() + similarity += unrecovered["similarity"].tolist() # compute the following properties, using the TDC, for the succesfully recovered molecules recovered_novelty_all = novelty( - recovered['query SMILES'].tolist(), - recovered['decode SMILES'].tolist(), - ) - recovered_validity_decode_all = validity(recovered['decode SMILES'].tolist()) - recovered_uniqueness_decode_all = uniqueness(recovered['decode SMILES'].tolist()) + recovered["query SMILES"].tolist(), + recovered["decode SMILES"].tolist(), + ) + recovered_validity_decode_all = validity(recovered["decode SMILES"].tolist()) + recovered_uniqueness_decode_all = uniqueness(recovered["decode SMILES"].tolist()) recovered_fcd_distance_all = fcd_distance( - recovered['query SMILES'].tolist(), - recovered['decode SMILES'].tolist() - ) - recovered_kl_divergence_all = kl_divergence(recovered['query SMILES'].tolist(), recovered['decode SMILES'].tolist()) + recovered["query SMILES"].tolist(), recovered["decode SMILES"].tolist() + ) + recovered_kl_divergence_all = kl_divergence( + recovered["query SMILES"].tolist(), recovered["decode SMILES"].tolist() + ) # compute the following properties, using the TDC, for the unrecovered molecules - unrecovered_novelty_all = novelty(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) - unrecovered_validity_decode_all = validity(unrecovered['decode SMILES'].tolist()) - unrecovered_uniqueness_decode_all = uniqueness(unrecovered['decode SMILES'].tolist()) - unrecovered_fcd_distance_all = fcd_distance(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) - unrecovered_kl_divergence_all = kl_divergence(unrecovered['query SMILES'].tolist(), unrecovered['decode SMILES'].tolist()) + unrecovered_novelty_all = novelty( + unrecovered["query SMILES"].tolist(), unrecovered["decode SMILES"].tolist() + ) + unrecovered_validity_decode_all = validity(unrecovered["decode SMILES"].tolist()) + unrecovered_uniqueness_decode_all = uniqueness(unrecovered["decode SMILES"].tolist()) + unrecovered_fcd_distance_all = fcd_distance( + unrecovered["query SMILES"].tolist(), unrecovered["decode SMILES"].tolist() + ) + unrecovered_kl_divergence_all = kl_divergence( + unrecovered["query SMILES"].tolist(), unrecovered["decode SMILES"].tolist() + ) # Print info - print(f'N total {n_total}') - print(f'N recovered {n_recovered} ({n_recovered/n_total:.2f})') - print(f'N unrecovered {n_unrecovered} ({n_recovered/n_total:.2f})') + print(f"N total {n_total}") + print(f"N recovered {n_recovered} ({n_recovered/n_total:.2f})") + print(f"N unrecovered {n_unrecovered} ({n_recovered/n_total:.2f})") n_finished = n_recovered + n_unrecovered n_unfinished = n_total - n_finished - print(f'N finished tree {n_finished} ({n_finished/n_total:.2f})') - print(f'N unfinished trees (NaN) {n_unfinished} ({n_unfinished/n_total:.2f})') - print(f'Average similarity (unrecovered only) {np.mean(similarity)}') + print(f"N finished tree {n_finished} ({n_finished/n_total:.2f})") + print(f"N unfinished trees (NaN) {n_unfinished} ({n_unfinished/n_total:.2f})") + print(f"Average similarity (unrecovered only) {np.mean(similarity)}") - print('Novelty, recovered:', recovered_novelty_all) - print('Novelty, unrecovered:', unrecovered_novelty_all) + print("Novelty, recovered:", recovered_novelty_all) + print("Novelty, unrecovered:", unrecovered_novelty_all) - print('Validity, decode molecules, recovered:', recovered_validity_decode_all) - print('Validity, decode molecules, unrecovered:', unrecovered_validity_decode_all) + print("Validity, decode molecules, recovered:", recovered_validity_decode_all) + print("Validity, decode molecules, unrecovered:", unrecovered_validity_decode_all) - print('Uniqueness, decode molecules, recovered:', recovered_uniqueness_decode_all) - print('Uniqueness, decode molecules, unrecovered:', unrecovered_uniqueness_decode_all) + print("Uniqueness, decode molecules, recovered:", recovered_uniqueness_decode_all) + print("Uniqueness, decode molecules, unrecovered:", unrecovered_uniqueness_decode_all) - print('FCD distance, recovered:', recovered_fcd_distance_all) - print('FCD distance, unrecovered:', unrecovered_fcd_distance_all) + print("FCD distance, recovered:", recovered_fcd_distance_all) + print("FCD distance, unrecovered:", unrecovered_fcd_distance_all) - print('KL divergence, recovered:', recovered_kl_divergence_all) - print('KL divergence, unrecovered:', unrecovered_kl_divergence_all) + print("KL divergence, recovered:", recovered_kl_divergence_all) + print("KL divergence, unrecovered:", unrecovered_kl_divergence_all) From 12cc6cf287c9359190633e101fd796a0a2c6359f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:06:48 -0400 Subject: [PATCH 255/302] delete (hopefully) old code --- scripts/evaluate_batch.py | 28 ---------------------------- 1 file changed, 28 deletions(-) delete mode 100644 scripts/evaluate_batch.py diff --git a/scripts/evaluate_batch.py b/scripts/evaluate_batch.py deleted file mode 100644 index 85283da2..00000000 --- a/scripts/evaluate_batch.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This function evaluates a batch of predictions by computing the (1) novelty, (2) validity, -(3) uniqueness, (4) Fréchet ChemNet distance, and (5) KL divergence for the final -root molecules which correspond to *unrecovered* molecules in all the generated trees. -""" -from tdc import Evaluator -import pandas as pd - -kl_divergence = Evaluator(name = 'KL_Divergence') -fcd_distance = Evaluator(name = 'FCD_Distance') -novelty = Evaluator(name = 'Novelty') -validity = Evaluator(name = 'Validity') -uniqueness = Evaluator(name = 'Uniqueness') - -if __name__ == '__main__': - # load the final root molecules generated by a prediction run using a pre-trained model - result_train = pd.read_csv('../results/decode_result_test_processed_property.csv.gz', compression='gzip') - - # get the unrecovered molecules only - # result_test_unrecover = result_train[result_train['recovered sa'] != -1][result_train['similarity'] != 1.0] - result_test_unrecover = result_train[result_train['recovered sa'] != -1] - - # compute the following properties, using the TDC - print(f"Novelty: {novelty(result_test_unrecover['query SMILES'].tolist(), result_test_unrecover['decode SMILES'].tolist())}") - print(f"Validity: {validity(result_test_unrecover['decode SMILES'].tolist())}") - print(f"Uniqueness: {uniqueness(result_test_unrecover['decode SMILES'].tolist())}") - print(f"FCD: {fcd_distance(result_test_unrecover['query SMILES'].tolist(), result_test_unrecover['decode SMILES'].tolist())}") - print(f"KL: {kl_divergence(result_test_unrecover['query SMILES'].tolist()[:10000], result_test_unrecover['decode SMILES'].tolist()[:10000])}") From f3e64aa0ca39a3307920d1d533db7a63bbb16bd6 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:18:51 -0400 Subject: [PATCH 256/302] rename: synthetic-trees -> syntrees --- INSTRUCTIONS.md | 2 +- scripts/{04-filter-synthetic-trees.py => 04-filter-syntrees.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename scripts/{04-filter-synthetic-trees.py => 04-filter-syntrees.py} (100%) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 348b748e..40eb606e 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -73,7 +73,7 @@ Let's start. ```bash # Filter - python scripts/04-filter-synthetic-trees.py \ + python scripts/04-filter-syntrees.py \ --input-file "data/pre-process/synthetic-trees.json.gz" \ --output-file "data/pre-process/synthetic-trees-filtered.json.gz" ``` diff --git a/scripts/04-filter-synthetic-trees.py b/scripts/04-filter-syntrees.py similarity index 100% rename from scripts/04-filter-synthetic-trees.py rename to scripts/04-filter-syntrees.py From 203310d0080585bf4fb1a0bd5ef80f90885d3854 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:20:44 -0400 Subject: [PATCH 257/302] format --- scripts/07-split-data-for-networks.py | 11 +- scripts/20-predict-targets.py | 18 ++- scripts/_mp_decode.py | 86 ++++++----- scripts/optimize_ga.py | 197 +++++++++++++++----------- 4 files changed, 179 insertions(+), 133 deletions(-) diff --git a/scripts/07-split-data-for-networks.py b/scripts/07-split-data-for-networks.py index ecad73b0..f429644c 100644 --- a/scripts/07-split-data-for-networks.py +++ b/scripts/07-split-data-for-networks.py @@ -1,15 +1,17 @@ """Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks """ +import json import logging from pathlib import Path -import json from syn_net.utils.prep_utils import split_data_into_Xy logger = logging.getLogger(__file__) + def get_args(): import argparse + parser = argparse.ArgumentParser() parser.add_argument( "--input-dir", @@ -18,6 +20,7 @@ def get_args(): ) return parser.parse_args() + if __name__ == "__main__": logger.info("Start.") @@ -27,8 +30,8 @@ def get_args(): # Split datasets for each MLP logger.info("Start splitting data.") - num_rxn = 91 # Auxiliary var for indexing TODO: Dont hardcode - out_dim = 256 # Auxiliary var for indexing TODO: Dont hardcode + num_rxn = 91 # Auxiliary var for indexing TODO: Dont hardcode + out_dim = 256 # Auxiliary var for indexing TODO: Dont hardcode input_dir = Path(args.input_dir) output_dir = input_dir / "Xy" for dataset_type in "train valid test".split(): @@ -40,6 +43,6 @@ def get_args(): output_dir=input_dir / "Xy", num_rxn=num_rxn, out_dim=out_dim, - ) + ) logger.info(f"Completed.") diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index 43ad730a..eed677df 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -43,7 +43,9 @@ def _fetch_data_from_file(name: str) -> list[str]: def _fetch_data(name: str) -> list[str]: if args.data in ["train", "valid", "test"]: - file = Path(DATA_PREPROCESS_DIR) / "syntrees" / f"synthetic-trees-filtered-{args.data}.json.gz" + file = ( + Path(DATA_PREPROCESS_DIR) / "syntrees" / f"synthetic-trees-filtered-{args.data}.json.gz" + ) logger.info(f"Reading data from {file}") sts = SyntheticTreeSet() sts.load(file) @@ -176,7 +178,9 @@ def get_args(): smiles_queries = smiles_queries[: args.num] # ... building blocks - file = Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / f"enamine-us-smiles.csv.gz" # TODO: Do not hardcode + file = ( + Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / f"enamine-us-smiles.csv.gz" + ) # TODO: Do not hardcode building_blocks = BuildingBlockFileHandler().load(file) building_blocks_dict = { block: i for i, block in enumerate(building_blocks) @@ -184,12 +188,16 @@ def get_args(): logger.info("...loading building blocks completed.") # ... reaction templates - file = (Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / "hb-enamine-us.json.gz") # TODO: Do not hardcode + file = ( + Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / "hb-enamine-us.json.gz" + ) # TODO: Do not hardcode rxns = ReactionSet().load(file).rxns logger.info("...loading reaction collection completed.") # ... building block embedding - file = Path(DATA_PREPROCESS_DIR) / "embeddings" / f"hb-enamine-embeddings.npy" # TODO: Do not hardcode + file = ( + Path(DATA_PREPROCESS_DIR) / "embeddings" / f"hb-enamine-embeddings.npy" + ) # TODO: Do not hardcode bblocks_molembedder = MolEmbedder().load_precomputed(file).init_balltree(cosine_distance) bb_emb = bblocks_molembedder.get_embeddings() @@ -200,7 +208,7 @@ def get_args(): logger.info("Start loading models from checkpoints...") path = Path(CHECKPOINTS_DIR) / f"{param_dir}" paths = [ - find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) # TODO: Do not hardcode + find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) # TODO: Do not hardcode for model in "act rt1 rxn rt2".split() ] act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py index e5e833c2..d8bdf150 100644 --- a/scripts/_mp_decode.py +++ b/scripts/_mp_decode.py @@ -1,48 +1,57 @@ """ This file contains a function to decode a single synthetic tree. +TODO: Ussed in `scripts/optimize_ga.py`, refactor. """ -import pandas as pd import numpy as np +import pandas as pd from dgllife.model import load_pretrained -from syn_net.utils.data_utils import ReactionSet -from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity, load_modules_from_checkpoint +from syn_net.utils.data_utils import ReactionSet +from syn_net.utils.predict_utils import ( + load_modules_from_checkpoint, + synthetic_tree_decoder, + tanimoto_similarity, +) # define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096_256' -ncpu = 16 +nbits = 4096 +out_dim = 256 +rxn_template = "hb" +featurize = "fp" +param_dir = "hb_fp_2_4096_256" +ncpu = 16 # define model to use for molecular embedding -model_type = 'gin_supervised_contextpred' -device = 'cpu' +model_type = "gin_supervised_contextpred" +device = "cpu" mol_embedder = load_pretrained(model_type).to(device) mol_embedder.eval() # load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') +bb_emb = np.load("/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy") # define path to the reaction templates and purchasable building blocks -path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz' -path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz' +path_to_reaction_file = ( + f"/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz" +) +path_to_building_blocks = ( + f"/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz" +) # define paths to pretrained modules -param_path = f'/home/whgao/synth_net/synth_net/params/{param_dir}/' -path_to_act = f'{param_path}act.ckpt' -path_to_rt1 = f'{param_path}rt1.ckpt' -path_to_rxn = f'{param_path}rxn.ckpt' -path_to_rt2 = f'{param_path}rt2.ckpt' +param_path = f"/home/whgao/synth_net/synth_net/params/{param_dir}/" +path_to_act = f"{param_path}act.ckpt" +path_to_rt1 = f"{param_path}rt1.ckpt" +path_to_rxn = f"{param_path}rxn.ckpt" +path_to_rt2 = f"{param_path}rt2.ckpt" # load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} +building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")["SMILES"].tolist() +bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} # load the reaction templates as a ReactionSet object rxn_set = ReactionSet().load(path_to_reaction_file) -rxns = rxn_set.rxns +rxns = rxn_set.rxns # load the pre-trained modules act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( @@ -57,6 +66,7 @@ ncpu=ncpu, ) + def func(emb): """ Generates the synthetic tree for the input molecular embedding. @@ -70,27 +80,27 @@ def func(emb): """ emb = emb.reshape((1, -1)) try: - tree, action = synthetic_tree_decoder(z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=nbits, - max_step=15) + tree, action = synthetic_tree_decoder( + z_target=emb, + building_blocks=building_blocks, + bb_dict=bb_dict, + reaction_templates=rxns, + mol_embedder=mol_embedder, + action_net=act_net, + reactant1_net=rt1_net, + rxn_net=rxn_net, + reactant2_net=rt2_net, + bb_emb=bb_emb, + rxn_template=rxn_template, + n_bits=nbits, + max_step=15, + ) except Exception as e: print(e) action = -1 if action != 3: return None, None else: - scores = np.array( - tanimoto_similarity(emb, [node.smiles for node in tree.chemicals]) - ) + scores = np.array(tanimoto_similarity(emb, [node.smiles for node in tree.chemicals])) max_score_idx = np.where(scores == np.max(scores))[0][0] return tree.chemicals[max_score_idx].smiles, tree diff --git a/scripts/optimize_ga.py b/scripts/optimize_ga.py index 0aad875a..4d3c6342 100644 --- a/scripts/optimize_ga.py +++ b/scripts/optimize_ga.py @@ -3,15 +3,17 @@ based on Therapeutic Data Commons (TDC) oracle functions. Uses a genetic algorithm to optimize embeddings before decoding. """ -from syn_net.utils.ga_utils import crossover, mutation +import json import multiprocessing as mp +import time + import numpy as np import pandas as pd -import time -import json +from tdc import Oracle + import scripts._mp_decode as decode +from syn_net.utils.ga_utils import crossover, mutation from syn_net.utils.predict_utils import mol_fp -from tdc import Oracle def dock_drd3(smi): @@ -25,16 +27,17 @@ def dock_drd3(smi): float: Predicted docking score against the DRD3 target. """ # define the oracle function from the TDC - _drd3 = Oracle(name = 'drd3_docking') + _drd3 = Oracle(name="drd3_docking") if smi is None: return 0.0 else: try: - return - _drd3(smi) + return -_drd3(smi) except: return 0.0 + def dock_7l11(smi): """ Returns the docking score for the 7L11 target. @@ -46,12 +49,12 @@ def dock_7l11(smi): float: Predicted docking score against the 7L11 target. """ # define the oracle function from the TDC - _7l11 = Oracle(name = '7l11_docking') + _7l11 = Oracle(name="7l11_docking") if smi is None: return 0.0 else: try: - return - _7l11(smi) + return -_7l11(smi) except: return 0.0 @@ -78,35 +81,35 @@ def fitness(embs, _pool, obj): embeddings. """ results = _pool.map(decode.func, embs) - smiles = [r[0] for r in results] - trees = [r[1] for r in results] + smiles = [r[0] for r in results] + trees = [r[1] for r in results] - if obj == 'qed': + if obj == "qed": # define the oracle function from the TDC - qed = Oracle(name = 'QED') + qed = Oracle(name="QED") scores = [qed(smi) for smi in smiles] - elif obj == 'logp': + elif obj == "logp": # define the oracle function from the TDC - logp = Oracle(name = 'LogP') + logp = Oracle(name="LogP") scores = [logp(smi) for smi in smiles] - elif obj == 'jnk': + elif obj == "jnk": # define the oracle function from the TDC - jnk = Oracle(name = 'JNK3') + jnk = Oracle(name="JNK3") scores = [jnk(smi) if smi is not None else 0.0 for smi in smiles] - elif obj == 'gsk': + elif obj == "gsk": # define the oracle function from the TDC - gsk = Oracle(name = 'GSK3B') + gsk = Oracle(name="GSK3B") scores = [gsk(smi) if smi is not None else 0.0 for smi in smiles] - elif obj == 'drd2': + elif obj == "drd2": # define the oracle function from the TDC - drd2 = Oracle(name = 'DRD2') + drd2 = Oracle(name="DRD2") scores = [drd2(smi) if smi is not None else 0.0 for smi in smiles] - elif obj == '7l11': + elif obj == "7l11": scores = [dock_7l11(smi) for smi in smiles] - elif obj == 'drd3': + elif obj == "drd3": scores = [dock_drd3(smi) for smi in smiles] else: - raise ValueError('Objective function not implemneted') + raise ValueError("Objective function not implemneted") return scores, smiles, trees @@ -122,10 +125,11 @@ def distribution_schedule(n, total): Returns: str: Describes a type of probability distribution. """ - if n < 4 * total/5: - return 'linear' + if n < 4 * total / 5: + return "linear" else: - return 'softmax_linear' + return "softmax_linear" + def num_mut_per_ele_scheduler(n, total): """ @@ -145,6 +149,7 @@ def num_mut_per_ele_scheduler(n, total): # return 512 return 24 + def mut_probability_scheduler(n, total): """ Determines the probability of mutating a vector, based on the number of elapsed @@ -157,38 +162,56 @@ def mut_probability_scheduler(n, total): Returns: float: The probability of mutation. """ - if n < total/2: + if n < total / 2: return 0.5 else: return 0.5 -if __name__ == '__main__': + +if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input_file", type=str, default=None, - help="A file contains the starting mating pool.") - parser.add_argument("--objective", type=str, default="qed", - help="Objective function to optimize") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--num_population", type=int, default=100, - help="Number of parents sets to keep.") - parser.add_argument("--num_offspring", type=int, default=300, - help="Number of offsprings to generate each iteration.") - parser.add_argument("--num_gen", type=int, default=30, - help="Number of generations to proceed.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--mut_probability", type=float, default=0.5, - help="Probability to mutate for one offspring.") - parser.add_argument("--num_mut_per_ele", type=int, default=1, - help="Number of bits to mutate in one fingerprint.") - parser.add_argument('--restart', action='store_true') - parser.add_argument("--seed", type=int, default=1, - help="Random seed.") + parser.add_argument( + "-i", + "--input_file", + type=str, + default=None, + help="A file contains the starting mating pool.", + ) + parser.add_argument( + "--objective", type=str, default="qed", help="Objective function to optimize" + ) + parser.add_argument("--radius", type=int, default=2, help="Radius for Morgan fingerprint.") + parser.add_argument( + "--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint." + ) + parser.add_argument( + "--num_population", type=int, default=100, help="Number of parents sets to keep." + ) + parser.add_argument( + "--num_offspring", + type=int, + default=300, + help="Number of offsprings to generate each iteration.", + ) + parser.add_argument("--num_gen", type=int, default=30, help="Number of generations to proceed.") + parser.add_argument("--ncpu", type=int, default=16, help="Number of cpus") + parser.add_argument( + "--mut_probability", + type=float, + default=0.5, + help="Probability to mutate for one offspring.", + ) + parser.add_argument( + "--num_mut_per_ele", + type=int, + default=1, + help="Number of bits to mutate in one fingerprint.", + ) + parser.add_argument("--restart", action="store_true") + parser.add_argument("--seed", type=int, default=1, help="Random seed.") args = parser.parse_args() np.random.seed(args.seed) @@ -202,21 +225,17 @@ def mut_probability_scheduler(n, total): print(f"Starting with {args.num_population} fps with {args.nbits} bits") else: starting_smiles = pd.read_csv(args.input_file).sample(args.num_population) - starting_smiles = starting_smiles['smiles'].tolist() - population = np.array( - [mol_fp(smi, args.radius, args.nbits) for smi in starting_smiles] - ) + starting_smiles = starting_smiles["smiles"].tolist() + population = np.array([mol_fp(smi, args.radius, args.nbits) for smi in starting_smiles]) print(f"Starting with {len(starting_smiles)} fps from {args.input_file}") with mp.Pool(processes=args.ncpu) as pool: - scores, mols, trees = fitness(embs=population, - _pool=pool, - obj=args.objective) - scores = np.array(scores) - score_x = np.argsort(scores) + scores, mols, trees = fitness(embs=population, _pool=pool, obj=args.objective) + scores = np.array(scores) + score_x = np.argsort(scores) population = population[score_x[::-1]] - mols = [mols[i] for i in score_x[::-1]] - scores = scores[score_x[::-1]] + mols = [mols[i] for i in score_x[::-1]] + scores = scores[score_x[::-1]] print(f"Initial: {scores.mean():.3f} +/- {scores.std():.3f}") print(f"Scores: {scores}") print(f"Top-3 Smiles: {mols[:3]}") @@ -227,16 +246,18 @@ def mut_probability_scheduler(n, total): t = time.time() - dist_ = distribution_schedule(n, args.num_gen) + dist_ = distribution_schedule(n, args.num_gen) num_mut_per_ele_ = num_mut_per_ele_scheduler(n, args.num_gen) mut_probability_ = mut_probability_scheduler(n, args.num_gen) - offspring = crossover(parents=population, - offspring_size=args.num_offspring, - distribution=dist_) - offspring = mutation(offspring_crossover=offspring, - num_mut_per_ele=num_mut_per_ele_, - mut_probability=mut_probability_) + offspring = crossover( + parents=population, offspring_size=args.num_offspring, distribution=dist_ + ) + offspring = mutation( + offspring_crossover=offspring, + num_mut_per_ele=num_mut_per_ele_, + mut_probability=mut_probability_, + ) new_population = np.unique(np.concatenate([population, offspring], axis=0), axis=0) with mp.Pool(processes=args.ncpu) as pool: new_scores, new_mols, trees = fitness(new_population, pool, args.objective) @@ -272,28 +293,32 @@ def mut_probability_scheduler(n, total): if len(recent_scores) > 10: del recent_scores[0] - np.save('population_' + args.objective + '_' + str(n+1) + '.npy', population) - - data = {'objective': args.objective, - 'top1' : np.mean(scores[:1]), - 'top10' : np.mean(scores[:10]), - 'top100' : np.mean(scores[:100]), - 'smiles' : mols, - 'scores' : scores.tolist()} - with open('opt_' + args.objective + '.json', 'w') as f: + np.save("population_" + args.objective + "_" + str(n + 1) + ".npy", population) + + data = { + "objective": args.objective, + "top1": np.mean(scores[:1]), + "top10": np.mean(scores[:10]), + "top100": np.mean(scores[:100]), + "smiles": mols, + "scores": scores.tolist(), + } + with open("opt_" + args.objective + ".json", "w") as f: json.dump(data, f) if n > 30 and recent_scores[-1] - recent_scores[0] < 0.01: print("Early Stop!") break - data = {'objective': args.objective, - 'top1' : np.mean(scores[:1]), - 'top10' : np.mean(scores[:10]), - 'top100' : np.mean(scores[:100]), - 'smiles' : mols, - 'scores' : scores.tolist()} - with open('opt_' + args.objective + '.json', 'w') as f: + data = { + "objective": args.objective, + "top1": np.mean(scores[:1]), + "top10": np.mean(scores[:10]), + "top100": np.mean(scores[:100]), + "smiles": mols, + "scores": scores.tolist(), + } + with open("opt_" + args.objective + ".json", "w") as f: json.dump(data, f) - np.save('population_' + args.objective + '.npy', population) + np.save("population_" + args.objective + ".npy", population) From c0a526e64cd33656efd6e5fddacc206608d76e00 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:21:25 -0400 Subject: [PATCH 258/302] format --- src/syn_net/config.py | 3 - .../data_generation/check_all_template.py | 46 +++++----- src/syn_net/encoding/fingerprints.py | 2 +- src/syn_net/utils/ga_utils.py | 92 ++++++++++--------- 4 files changed, 71 insertions(+), 72 deletions(-) diff --git a/src/syn_net/config.py b/src/syn_net/config.py index e6b65dc2..88f26d35 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -17,9 +17,6 @@ DATA_PREPROCESS_DIR = "data/pre-process" DATA_EMBEDDINGS_DIR = "data/pre-process/embeddings" -# Prepared data -DATA_PREPARED_DIR = "data/prepared" - # Prepared data DATA_FEATURIZED_DIR = "data/featurized" diff --git a/src/syn_net/data_generation/check_all_template.py b/src/syn_net/data_generation/check_all_template.py index 03d21e37..5542701a 100644 --- a/src/syn_net/data_generation/check_all_template.py +++ b/src/syn_net/data_generation/check_all_template.py @@ -3,13 +3,12 @@ templates. Originally written by Jake. Wenhao edited. """ import rdkit.Chem as Chem -from rdkit.Chem import AllChem -from rdkit.Chem import rdChemReactions from rdkit import RDLogger +from rdkit.Chem import AllChem, rdChemReactions def split_rxn_parts(rxn): - ''' + """ Given SMILES reaction, splits into reactants, agents, and products Args: @@ -17,11 +16,11 @@ def split_rxn_parts(rxn): Returns: list: Contains sets of reactants, agents, and products as RDKit molecules. - ''' - rxn_parts = rxn.strip().split('>') - rxn_reactants = set(rxn_parts[0].split('.')) - rxn_agents = None if not rxn_parts[1] else set(rxn_parts[1].split('.')) - rxn_products = set(rxn_parts[2].split('.')) + """ + rxn_parts = rxn.strip().split(">") + rxn_reactants = set(rxn_parts[0].split(".")) + rxn_agents = None if not rxn_parts[1] else set(rxn_parts[1].split(".")) + rxn_products = set(rxn_parts[2].split(".")) reactants, agents, products = set(), set(), set() @@ -42,7 +41,7 @@ def split_rxn_parts(rxn): def rxn_template(rxn_smiles, templates): - ''' + """ Given a reaction, checks whether it matches any templates. Args: @@ -51,7 +50,7 @@ def rxn_template(rxn_smiles, templates): Returns: str: Matching template name. If no templates matched, returns None. - ''' + """ rxn_parts = split_rxn_parts(rxn_smiles) reactants, agents, products = rxn_parts[0], rxn_parts[1], rxn_parts[2] temp_match = None @@ -92,7 +91,7 @@ def rxn_template(rxn_smiles, templates): def route_templates(route, templates): - ''' + """ Given synthesis route, checks whether all reaction steps are in template list Args: @@ -102,7 +101,7 @@ def route_templates(route, templates): Returns: List of matching template names (as strings). If no templates matched, returns empty list. - ''' + """ synth_route = [] tree_match = True for rxn_step in route: @@ -116,32 +115,33 @@ def route_templates(route, templates): return synth_route -if __name__ == '__main__': + +if __name__ == "__main__": disable_RDLogger = True # disables RDKit warnings if disable_RDLogger: - RDLogger.DisableLog('rdApp.*') + RDLogger.DisableLog("rdApp.*") - rxn_set_path = '/path/to/rxn_set.txt' + rxn_set_path = "/path/to/rxn_set.txt" - rxn_set = open(rxn_set_path, 'r') + rxn_set = open(rxn_set_path, "r") templates = {} for rxn in rxn_set: - rxn_name = rxn.split('|')[0] - template = rxn.split('|')[1].strip() + rxn_name = rxn.split("|")[0] + template = rxn.split("|")[1].strip() rdkit_rxn = AllChem.ReactionFromSmarts(template) rdChemReactions.ChemicalReaction.Initialize(rdkit_rxn) templates[rdkit_rxn] = rxn_name - rxn_smiles = 'ClCC1CO1.NC(=O)Cc1ccc(O)cc1>>NC(=O)Cc1ccc(OCC2CO2)cc1' + rxn_smiles = "ClCC1CO1.NC(=O)Cc1ccc(O)cc1>>NC(=O)Cc1ccc(OCC2CO2)cc1" print(rxn_smiles) print(rxn_template(rxn_smiles, templates)) - print('------------------------------------------------------') + print("------------------------------------------------------") synthesis_route = [ - 'C(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1>>CC(CCc1ccccc1)NCC(O)c1ccc(O)c(C(N)=O)c1', - 'CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1>>CC(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1', - 'CC(CCc1ccccc1)NCc1ccccc1.NC(=O)c1cc(C(=O)CBr)ccc1O>>CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1' + "C(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1>>CC(CCc1ccccc1)NCC(O)c1ccc(O)c(C(N)=O)c1", + "CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1>>CC(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1", + "CC(CCc1ccccc1)NCc1ccccc1.NC(=O)c1cc(C(=O)CBr)ccc1O>>CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1", ] print(synthesis_route) print(route_templates(synthesis_route, templates)) diff --git a/src/syn_net/encoding/fingerprints.py b/src/syn_net/encoding/fingerprints.py index 0a437fb1..66501af4 100644 --- a/src/syn_net/encoding/fingerprints.py +++ b/src/syn_net/encoding/fingerprints.py @@ -4,7 +4,7 @@ ## Morgan fingerprints -def mol_fp(smi, _radius=2, _nBits=4096) -> np.ndarray: # dtype=int64 +def mol_fp(smi, _radius=2, _nBits=4096) -> np.ndarray: # dtype=int64 """ Computes the Morgan fingerprint for the input SMILES. diff --git a/src/syn_net/utils/ga_utils.py b/src/syn_net/utils/ga_utils.py index 7e6cc11d..9ee2f4c4 100644 --- a/src/syn_net/utils/ga_utils.py +++ b/src/syn_net/utils/ga_utils.py @@ -5,7 +5,7 @@ import scipy -def crossover(parents, offspring_size, distribution='even'): +def crossover(parents, offspring_size, distribution="even"): """ A function that samples an offspring set through a crossover from a mating pool. @@ -24,54 +24,57 @@ def crossover(parents, offspring_size, distribution='even'): Returns: offspring (numpy.ndarray): An array which represents the offspring pool. """ - fp_length = parents.shape[1] - offspring = np.zeros((offspring_size, fp_length)) + fp_length = parents.shape[1] + offspring = np.zeros((offspring_size, fp_length)) inherit_num = np.ceil( - np.random.normal(loc=fp_length/2, scale=fp_length/10, size=(offspring_size, )) + np.random.normal(loc=fp_length / 2, scale=fp_length / 10, size=(offspring_size,)) ) inherit_num = np.where( - inherit_num >= int(fp_length/5) * np.ones((offspring_size, )), - inherit_num, int(fp_length/5) * np.ones((offspring_size, )) + inherit_num >= int(fp_length / 5) * np.ones((offspring_size,)), + inherit_num, + int(fp_length / 5) * np.ones((offspring_size,)), ) inherit_num = np.where( - int(fp_length*4/5) * np.ones((offspring_size, )) <= inherit_num, - int(fp_length*4/5) * np.ones((offspring_size, )), - inherit_num + int(fp_length * 4 / 5) * np.ones((offspring_size,)) <= inherit_num, + int(fp_length * 4 / 5) * np.ones((offspring_size,)), + inherit_num, ) for k in range(offspring_size): - parent1_idx = list(set(np.random.choice(fp_length, size=int(inherit_num[k]), replace=False))) + parent1_idx = list( + set(np.random.choice(fp_length, size=int(inherit_num[k]), replace=False)) + ) parent2_idx = list(set(range(fp_length)).difference(set(parent1_idx))) - if distribution == 'even': - parent_set = parents[np.random.choice(parents.shape[0], - size=2, - replace=False)] - elif distribution == 'linear': - p_ = np.arange(parents.shape[0])[::-1] + 10 - parent_set = parents[np.random.choice(parents.shape[0], - size=2, - replace=False, - p=p_/np.sum(p_))] - elif distribution == 'softmax_linear': - p_ = np.arange(parents.shape[0])[::-1] + 10 - parent_set = parents[np.random.choice(parents.shape[0], - size=2, - replace=False, - p=scipy.special.softmax(p_))] + if distribution == "even": + parent_set = parents[np.random.choice(parents.shape[0], size=2, replace=False)] + elif distribution == "linear": + p_ = np.arange(parents.shape[0])[::-1] + 10 + parent_set = parents[ + np.random.choice(parents.shape[0], size=2, replace=False, p=p_ / np.sum(p_)) + ] + elif distribution == "softmax_linear": + p_ = np.arange(parents.shape[0])[::-1] + 10 + parent_set = parents[ + np.random.choice( + parents.shape[0], size=2, replace=False, p=scipy.special.softmax(p_) + ) + ] offspring[k, parent1_idx] = parent_set[0][parent1_idx] offspring[k, parent2_idx] = parent_set[1][parent2_idx] return offspring + def fitness_sum(element): """ Test fitness function. """ return np.sum(element) + def mutation(offspring_crossover, num_mut_per_ele=1, mut_probability=0.5): """ A function that samples an offspring set through a crossover from a mating @@ -87,44 +90,43 @@ def mutation(offspring_crossover, num_mut_per_ele=1, mut_probability=0.5): offspring_crossover (numpy.ndarray): An array represents the offspring pool after mutation. """ - b_dict = {1:0, 0:1} + b_dict = {1: 0, 0: 1} fp_length = offspring_crossover.shape[1] mut_proba = np.random.random(offspring_crossover.shape[0]) for idx in range(offspring_crossover.shape[0]): # The random value to be added to the gene. if mut_proba[idx] <= mut_probability: - position = np.random.choice(fp_length, - size=int(num_mut_per_ele), - replace=False) - tmp = np.array([b_dict[int(_)] for _ in offspring_crossover[idx, position]]) + position = np.random.choice(fp_length, size=int(num_mut_per_ele), replace=False) + tmp = np.array([b_dict[int(_)] for _ in offspring_crossover[idx, position]]) offspring_crossover[idx, position] = tmp else: pass return offspring_crossover -if __name__ == '__main__': - num_parents = 10 - fp_size = 128 +if __name__ == "__main__": + + num_parents = 10 + fp_size = 128 offspring_size = 30 - ngen = 100 - population = np.ceil(np.random.random(size=(num_parents, fp_size)) * 2 - 1) + ngen = 100 + population = np.ceil(np.random.random(size=(num_parents, fp_size)) * 2 - 1) - print(f'Starting with {num_parents} fps with {fp_size} bits') + print(f"Starting with {num_parents} fps with {fp_size} bits") scores = np.array([fitness_sum(_) for _ in population]) - print(f'Initial: {scores.mean():.3f} +/- {scores.std():.3f}') - print(f'Scores: {scores}') + print(f"Initial: {scores.mean():.3f} +/- {scores.std():.3f}") + print(f"Scores: {scores}") for n in range(ngen): - offspring = crossover(population, offspring_size) - offspring = mutation(offspring, num_mut_per_ele=4, mut_probability=0.5) + offspring = crossover(population, offspring_size) + offspring = mutation(offspring, num_mut_per_ele=4, mut_probability=0.5) new_population = np.concatenate([population, offspring], axis=0) - new_scores = np.array(scores.tolist() + [fitness_sum(_) for _ in offspring]) - scores = [] + new_scores = np.array(scores.tolist() + [fitness_sum(_) for _ in offspring]) + scores = [] for parent_idx in range(num_parents): max_score_idx = np.where(new_scores == np.max(new_scores))[0][0] @@ -133,5 +135,5 @@ def mutation(offspring_crossover, num_mut_per_ele=1, mut_probability=0.5): new_scores[max_score_idx] = -999999 scores = np.array(scores) - print(f'Generation {ngen}: {scores.mean()} +/- {scores.std()}') - print(f'Scores: {scores}') + print(f"Generation {ngen}: {scores.mean()} +/- {scores.std()}") + print(f"Scores: {scores}") From 1328ec0287913c8ab49bb69a40666e8b7331a934 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 15:30:12 -0400 Subject: [PATCH 259/302] clean up `config.py` --- scripts/20-predict-targets.py | 7 +------ src/syn_net/config.py | 10 +--------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index eed677df..2ddaeb98 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -13,12 +13,7 @@ import numpy as np import pandas as pd -from syn_net.config import ( - CHECKPOINTS_DIR, - DATA_EMBEDDINGS_DIR, - DATA_PREPROCESS_DIR, - DATA_RESULT_DIR, -) +from syn_net.config import CHECKPOINTS_DIR, DATA_PREPROCESS_DIR, DATA_RESULT_DIR from syn_net.data_generation.preprocessing import BuildingBlockFileHandler from syn_net.models.chkpt_loader import load_modules_from_checkpoint from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet diff --git a/src/syn_net/config.py b/src/syn_net/config.py index 88f26d35..977f72b0 100644 --- a/src/syn_net/config.py +++ b/src/syn_net/config.py @@ -4,18 +4,10 @@ # Multiprocessing MAX_PROCESSES = min(32, multiprocessing.cpu_count()) - 1 -# TODO: Remove these paths bit by bit (not used except for decoing as of now) -# Paths -DATA_DIR = "data" -ASSETS_DIR = "data/assets" - -# -BUILDING_BLOCKS_RAW_DIR = f"{ASSETS_DIR}/building-blocks" -REACTION_TEMPLATE_DIR = f"{ASSETS_DIR}/reaction-templates" +# TODO: Remove these paths bit by bit # Pre-processed data DATA_PREPROCESS_DIR = "data/pre-process" -DATA_EMBEDDINGS_DIR = "data/pre-process/embeddings" # Prepared data DATA_FEATURIZED_DIR = "data/featurized" From a4a15d3e654ac9da7591fd1e7848ade42e31b708 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 30 Sep 2022 16:53:00 -0400 Subject: [PATCH 260/302] align code --- src/syn_net/models/act.py | 17 ++++++++--------- src/syn_net/models/common.py | 1 + src/syn_net/models/rt1.py | 15 +++++++-------- src/syn_net/models/rt2.py | 20 +++++++------------- src/syn_net/models/rxn.py | 17 +++++++---------- 5 files changed, 30 insertions(+), 40 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index adc6ce35..7c1e4435 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -3,6 +3,7 @@ """ import logging from pathlib import Path +import json import pytorch_lightning as pl from pytorch_lightning import loggers as pl_loggers @@ -17,20 +18,19 @@ MODEL_ID = Path(__file__).stem if __name__ == "__main__": + logger.info("Start.") + # Parse input args args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - validation_option = VALIDATION_OPTS[args.out_dim] - - # Get ID for the data to know what we're working with and find right files. - id = ( - f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" - ) + pl.seed_everything(0) + # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, task="classification", batch_size=args.batch_size, @@ -50,7 +50,6 @@ ) logger.info(f"Set up dataloaders.") - pl.seed_everything(0) INPUT_DIMS = { "fp": int(3 * args.nbits), "gin": int(2 * args.nbits + args.out_dim), diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 212ca647..2d323f0a 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -20,6 +20,7 @@ def get_args(): import argparse parser = argparse.ArgumentParser() + parser.add_argument("--data-dir",type=str,help="Directory with X,y data.") parser.add_argument( "-f", "--featurize", type=str, default="fp", help="Choose from ['fp', 'gin']" ) diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index f512abd2..5b45016a 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -28,20 +28,19 @@ def _fetch_molembedder(): if __name__ == "__main__": + logger.info("Start.") + # Parse input args args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - validation_option = VALIDATION_OPTS[args.out_dim] - - # Get ID for the data to know what we're working with and find right files. - id = ( - f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" - ) + pl.seed_everything(0) + # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, batch_size=args.batch_size, num_workers=args.ncpu, diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index b9df4a7a..15d8e93d 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -28,20 +28,19 @@ def _fetch_molembedder(): if __name__ == "__main__": + logger.info("Start.") + # Parse input args args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - validation_option = VALIDATION_OPTS[args.out_dim] - - # Get ID for the data to know what we're working with and find right files. - id = ( - f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" - ) + pl.seed_everything(0) + # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, batch_size=args.batch_size, num_workers=args.ncpu, @@ -58,11 +57,6 @@ def _fetch_molembedder(): shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") - - # Fetch Molembedder and init BallTree - molembedder = _fetch_molembedder() - - pl.seed_everything(0) INPUT_DIMS = { "fp": { "hb": int(4 * args.nbits + 91), diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 7bde26ee..6207a7a8 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -17,20 +17,19 @@ MODEL_ID = Path(__file__).stem if __name__ == "__main__": + logger.info("Start.") + # Parse input args args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - validation_option = VALIDATION_OPTS[args.out_dim] - - # Get ID for the data to know what we're working with and find right files. - id = ( - f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/" - ) + pl.seed_everything(0) + # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, task="classification", batch_size=args.batch_size, @@ -49,8 +48,6 @@ shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") - - pl.seed_everything(0) param_path = ( Path(CHECKPOINTS_DIR) / f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/" From 9e506907470f5e6c6039c17ed43a8e29fbc93425 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 4 Oct 2022 07:01:49 -0400 Subject: [PATCH 261/302] adds `CSVLogger` --- src/syn_net/models/act.py | 8 ++------ src/syn_net/models/rt1.py | 8 +++----- src/syn_net/models/rt2.py | 10 +++++----- src/syn_net/models/rxn.py | 11 +++++------ 4 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 7c1e4435..74f7b631 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -74,15 +74,11 @@ ) # Set up Trainer - save_dir = Path( - "results/logs/" - + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" - + f"/{MODEL_ID}" - ) + save_dir = Path("results/logs/") / MODEL_ID save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - csv_logger = pl_loggers.CSVLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") logger.info(f"Log dir set to: {tb_logger.log_dir}") checkpoint_callback = ModelCheckpoint( diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 5b45016a..527f08f3 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -87,14 +87,12 @@ def _fetch_molembedder(): ) # Set up Trainer - save_dir = Path( - "results/logs/" - + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" - + f"/{MODEL_ID}" - ) + save_dir = Path("results/logs/") / MODEL_ID save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") checkpoint_callback = ModelCheckpoint( monitor="val_loss", diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 15d8e93d..6800f64a 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -87,13 +87,13 @@ def _fetch_molembedder(): ) # Set up Trainer - save_dir = Path( - "results/logs/" - + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" - + f"/{MODEL_ID}" - ) + save_dir = Path("results/logs/") / MODEL_ID save_dir.mkdir(exist_ok=True, parents=True) + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") checkpoint_callback = ModelCheckpoint( diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 6207a7a8..29cbc175 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -119,14 +119,13 @@ ) # Set up Trainer - # Set up Trainer - save_dir = Path( - "results/logs/" - + f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}" - + f"/{MODEL_ID}" - ) + save_dir = Path("results/logs/") / MODEL_ID save_dir.mkdir(exist_ok=True, parents=True) + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") checkpoint_callback = ModelCheckpoint( From bbd67fe0c6a22723265b61a15b39deee6c18de8b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 4 Oct 2022 07:05:01 -0400 Subject: [PATCH 262/302] fix imports + fstring --- src/syn_net/models/act.py | 20 ++++++++++---------- src/syn_net/models/rt1.py | 15 ++++++++------- src/syn_net/models/rt2.py | 15 ++++++++------- src/syn_net/models/rxn.py | 15 ++++++++------- 4 files changed, 34 insertions(+), 31 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 74f7b631..e15eac25 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -1,18 +1,18 @@ """ Action network. """ +import json import logging from pathlib import Path -import json import pytorch_lightning as pl from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar -from syn_net.config import DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader -from syn_net.models.mlp import MLP, load_array +from syn_net.models.common import get_args, xy_to_dataloader +from syn_net.models.mlp import MLP logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem @@ -29,9 +29,9 @@ # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", - y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, task="classification", batch_size=args.batch_size, num_workers=args.ncpu, @@ -40,9 +40,9 @@ dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, task="classification", batch_size=args.batch_size, num_workers=args.ncpu, diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 527f08f3..c806e9a0 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -1,6 +1,7 @@ """ Reactant1 network (for predicting 1st reactant). """ +import json import logging from pathlib import Path @@ -9,8 +10,8 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader +from syn_net.config import DATA_EMBEDDINGS_DIR +from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP, cosine_distance from syn_net.MolEmbedder import MolEmbedder @@ -39,8 +40,8 @@ def _fetch_molembedder(): # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", - y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, batch_size=args.batch_size, num_workers=args.ncpu, @@ -49,8 +50,8 @@ def _fetch_molembedder(): dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, batch_size=args.batch_size, num_workers=args.ncpu, @@ -91,7 +92,7 @@ def _fetch_molembedder(): save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") logger.info(f"Log dir set to: {tb_logger.log_dir}") checkpoint_callback = ModelCheckpoint( diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 6800f64a..072f6a01 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -1,6 +1,7 @@ """ Reactant2 network (for predicting 2nd reactant). """ +import json import logging from pathlib import Path @@ -9,8 +10,8 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from syn_net.config import DATA_EMBEDDINGS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader +from syn_net.config import DATA_EMBEDDINGS_DIR +from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP, cosine_distance from syn_net.MolEmbedder import MolEmbedder @@ -39,8 +40,8 @@ def _fetch_molembedder(): # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", - y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, batch_size=args.batch_size, num_workers=args.ncpu, @@ -49,8 +50,8 @@ def _fetch_molembedder(): dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, batch_size=args.batch_size, num_workers=args.ncpu, @@ -91,7 +92,7 @@ def _fetch_molembedder(): save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") logger.info(f"Log dir set to: {tb_logger.log_dir}") tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 29cbc175..57727b00 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -1,6 +1,7 @@ """ Reaction network. """ +import json import logging from pathlib import Path @@ -9,8 +10,8 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from syn_net.config import CHECKPOINTS_DIR, DATA_FEATURIZED_DIR -from syn_net.models.common import VALIDATION_OPTS, get_args, xy_to_dataloader +from syn_net.config import CHECKPOINTS_DIR +from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP logger = logging.getLogger(__name__) @@ -28,8 +29,8 @@ # Set up dataloaders dataset = "train" train_dataloader = xy_to_dataloader( - X_file=Path(args.data_dir) / "X_{MODEL_ID}_{dataset}.npz", - y_file=Path(args.data_dir) / "y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, task="classification", batch_size=args.batch_size, @@ -39,8 +40,8 @@ dataset = "valid" valid_dataloader = xy_to_dataloader( - X_file=Path(DATA_FEATURIZED_DIR) / f"{id}/X_{MODEL_ID}_{dataset}.npz", - y_file=Path(DATA_FEATURIZED_DIR) / f"{id}/y_{MODEL_ID}_{dataset}.npz", + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", n=None if not args.debug else 1000, task="classification", batch_size=args.batch_size, @@ -123,7 +124,7 @@ save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") logger.info(f"Log dir set to: {tb_logger.log_dir}") tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") From 54574138a729865b4e4f4f9d43ba82b5c4bf0b44 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 4 Oct 2022 08:05:58 -0400 Subject: [PATCH 263/302] wip fix: do not apply softmax for classification tasks `nn.functional.cross_entropy` takes raw logits, so we do not need to apply softmax in the last layer as well --- src/syn_net/models/act.py | 2 +- src/syn_net/models/rxn.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index e15eac25..267504f0 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -64,7 +64,7 @@ num_layers=5, dropout=0.5, num_dropout_layers=1, - task="classification", + task="classification-w/o-softmax", loss="cross_entropy", valid_loss="accuracy", optimizer="adam", diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 57727b00..54600caa 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -93,7 +93,7 @@ num_layers=5, dropout=0.5, num_dropout_layers=1, - task="classification", + task="classification-w/o-softmax", loss="cross_entropy", valid_loss="accuracy", optimizer="adam", From 59f325c8c44ede2cd2252270396636e8ac8efbb5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 4 Oct 2022 08:14:22 -0400 Subject: [PATCH 264/302] align params and train callbacks for networks --- src/syn_net/models/act.py | 13 ++++++------- src/syn_net/models/common.py | 4 +++- src/syn_net/models/rt1.py | 21 +++++++++------------ src/syn_net/models/rt2.py | 24 ++++++++++++------------ src/syn_net/models/rxn.py | 10 +++++----- 5 files changed, 35 insertions(+), 37 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 267504f0..9de0e27d 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -1,5 +1,4 @@ -""" -Action network. +"""Action network. """ import json import logging @@ -68,7 +67,7 @@ loss="cross_entropy", valid_loss="accuracy", optimizer="adam", - learning_rate=1e-4, + learning_rate=3e-4, val_freq=10, ncpu=args.ncpu, ) @@ -78,7 +77,7 @@ save_dir.mkdir(exist_ok=True, parents=True) tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="",version="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") logger.info(f"Log dir set to: {tb_logger.log_dir}") checkpoint_callback = ModelCheckpoint( @@ -88,14 +87,14 @@ save_weights_only=False, ) earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) - max_epochs = args.epoch if not args.debug else 20 + max_epochs = args.epoch if not args.debug else 1000 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), - callbacks=[checkpoint_callback], + callbacks=[checkpoint_callback, tqdm_callback], logger=[tb_logger, csv_logger], fast_dev_run=args.fast_dev_run, ) diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 2d323f0a..7fb1265a 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -20,7 +20,9 @@ def get_args(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("--data-dir",type=str,help="Directory with X,y data.") + parser.add_argument( + "--data-dir", type=str, default="data/featurized/Xy", help="Directory with X,y data." + ) parser.add_argument( "-f", "--featurize", type=str, default="fp", help="Choose from ['fp', 'gin']" ) diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index c806e9a0..880c6c08 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -1,5 +1,4 @@ -""" -Reactant1 network (for predicting 1st reactant). +"""Reactant1 network (for predicting 1st reactant). """ import json import logging @@ -10,7 +9,6 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from syn_net.config import DATA_EMBEDDINGS_DIR from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP, cosine_distance from syn_net.MolEmbedder import MolEmbedder @@ -20,8 +18,7 @@ def _fetch_molembedder(): - knn_embedding_id = validation_option[12:] - file = Path(DATA_EMBEDDINGS_DIR) / f"hb-enamine_us-2021-smiles-{knn_embedding_id}.npy" + file = args.mol_embedder_file logger.info(f"Try to load precomputed MolEmbedder from {file}.") molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) logger.info(f"Loaded MolEmbedder from {file}.") @@ -42,7 +39,7 @@ def _fetch_molembedder(): train_dataloader = xy_to_dataloader( X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + n=None if not args.debug else 128, batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -52,17 +49,17 @@ def _fetch_molembedder(): valid_dataloader = xy_to_dataloader( X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + n=None if not args.debug else 128, batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, ) + logger.info(f"Set up dataloaders.") # Fetch Molembedder and init BallTree - molembedder = _fetch_molembedder() + molembedder = None # _fetch_molembedder() - pl.seed_everything(0) INPUT_DIMS = { "fp": int(3 * args.nbits), "gin": int(2 * args.nbits + args.out_dim), @@ -81,7 +78,7 @@ def _fetch_molembedder(): loss="mse", valid_loss="mse", optimizer="adam", - learning_rate=1e-4, + learning_rate=3e-4, val_freq=10, molembedder=molembedder, ncpu=args.ncpu, @@ -103,14 +100,14 @@ def _fetch_molembedder(): ) earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) - max_epochs = args.epoch if not args.debug else 2 + max_epochs = args.epoch if not args.debug else 300 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], - logger=[tb_logger], + logger=[tb_logger, csv_logger], ) logger.info(f"Start training") diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 072f6a01..30f5140f 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -1,5 +1,4 @@ -""" -Reactant2 network (for predicting 2nd reactant). +"""Reactant2 network (for predicting 2nd reactant). """ import json import logging @@ -10,7 +9,6 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from syn_net.config import DATA_EMBEDDINGS_DIR from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP, cosine_distance from syn_net.MolEmbedder import MolEmbedder @@ -20,8 +18,7 @@ def _fetch_molembedder(): - knn_embedding_id = validation_option[12:] - file = Path(DATA_EMBEDDINGS_DIR) / f"hb-enamine_us-2021-smiles-{knn_embedding_id}.npy" + file = args.mol_embedder_file logger.info(f"Try to load precomputed MolEmbedder from {file}.") molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) logger.info(f"Loaded MolEmbedder from {file}.") @@ -42,7 +39,7 @@ def _fetch_molembedder(): train_dataloader = xy_to_dataloader( X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + n=None if not args.debug else 128, batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, @@ -52,12 +49,17 @@ def _fetch_molembedder(): valid_dataloader = xy_to_dataloader( X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + n=None if not args.debug else 128, batch_size=args.batch_size, num_workers=args.ncpu, shuffle=True if dataset == "train" else False, ) + logger.info(f"Set up dataloaders.") + + # Fetch Molembedder and init BallTree + molembedder = None # _fetch_molembedder() + INPUT_DIMS = { "fp": { "hb": int(4 * args.nbits + 91), @@ -81,7 +83,7 @@ def _fetch_molembedder(): loss="mse", valid_loss="mse", optimizer="adam", - learning_rate=1e-4, + learning_rate=3e-4, val_freq=10, molembedder=molembedder, ncpu=args.ncpu, @@ -95,8 +97,6 @@ def _fetch_molembedder(): csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") logger.info(f"Log dir set to: {tb_logger.log_dir}") - tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") - checkpoint_callback = ModelCheckpoint( monitor="val_loss", dirpath=tb_logger.log_dir, @@ -105,14 +105,14 @@ def _fetch_molembedder(): ) earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) - max_epochs = args.epoch if not args.debug else 2 + max_epochs = args.epoch if not args.debug else 300 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], - logger=[tb_logger], + logger=[tb_logger, csv_logger], ) logger.info(f"Start training") diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 54600caa..25528e72 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -31,7 +31,7 @@ train_dataloader = xy_to_dataloader( X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + n=None if not args.debug else 128, task="classification", batch_size=args.batch_size, num_workers=args.ncpu, @@ -42,7 +42,7 @@ valid_dataloader = xy_to_dataloader( X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", - n=None if not args.debug else 1000, + n=None if not args.debug else 128, task="classification", batch_size=args.batch_size, num_workers=args.ncpu, @@ -97,7 +97,7 @@ loss="cross_entropy", valid_loss="accuracy", optimizer="adam", - learning_rate=1e-4, + learning_rate=3e-4, val_freq=10, ncpu=args.ncpu, ) @@ -137,14 +137,14 @@ ) earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) - max_epochs = args.epoch if not args.debug else 2 + max_epochs = args.epoch if not args.debug else 300 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), callbacks=[checkpoint_callback], - logger=[tb_logger], + logger=[tb_logger, csv_logger], ) logger.info(f"Start training") From bc583d009c53ba0bd5406b21ec5a1007d1fe744a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 02:39:58 -0400 Subject: [PATCH 265/302] update --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f1e406b5..9b6cec10 100644 --- a/.gitignore +++ b/.gitignore @@ -187,4 +187,5 @@ temp.py .notes/ .aliases figures/ -*.html \ No newline at end of file +*.html +*.data*/ \ No newline at end of file From 2a491012a7e0551e66fdb5ee57bc1a3c6d82eb81 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 05:11:29 -0400 Subject: [PATCH 266/302] use TQDMProgressBar --- src/syn_net/models/act.py | 4 ++-- src/syn_net/models/rt1.py | 10 ++++++---- src/syn_net/models/rt2.py | 10 ++++++---- src/syn_net/models/rxn.py | 10 ++++++---- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 9de0e27d..02a54414 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -86,10 +86,10 @@ filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) - earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) - max_epochs = args.epoch if not args.debug else 1000 + max_epochs = args.epoch if not args.debug else 100 # Create trainer trainer = pl.Trainer( gpus=[0], diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 880c6c08..6eeed0b0 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -8,6 +8,7 @@ from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP, cosine_distance @@ -98,16 +99,17 @@ def _fetch_molembedder(): filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) - earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) - max_epochs = args.epoch if not args.debug else 300 + max_epochs = args.epoch if not args.debug else 100 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), - callbacks=[checkpoint_callback], + callbacks=[checkpoint_callback, tqdm_callback], logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, ) logger.info(f"Start training") diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 30f5140f..91b55c5d 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -8,6 +8,7 @@ from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar from syn_net.models.common import get_args, xy_to_dataloader from syn_net.models.mlp import MLP, cosine_distance @@ -103,16 +104,17 @@ def _fetch_molembedder(): filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) - earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) - max_epochs = args.epoch if not args.debug else 300 + max_epochs = args.epoch if not args.debug else 100 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), - callbacks=[checkpoint_callback], + callbacks=[checkpoint_callback, tqdm_callback], logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, ) logger.info(f"Start training") diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 25528e72..a5771b67 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -9,6 +9,7 @@ from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar from syn_net.config import CHECKPOINTS_DIR from syn_net.models.common import get_args, xy_to_dataloader @@ -135,16 +136,17 @@ filename="ckpts.{epoch}-{val_loss:.2f}", save_weights_only=False, ) - earlystop_callback = EarlyStopping(monitor="val_loss", patience=10) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) - max_epochs = args.epoch if not args.debug else 300 + max_epochs = args.epoch if not args.debug else 100 # Create trainer trainer = pl.Trainer( gpus=[0], max_epochs=max_epochs, - progress_bar_refresh_rate=int(len(train_dataloader) * 0.05), - callbacks=[checkpoint_callback], + callbacks=[checkpoint_callback, tqdm_callback], logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, ) logger.info(f"Start training") From db6882c689ebd700c0d4471650c68ecd569bd7b0 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 08:42:44 -0400 Subject: [PATCH 267/302] clean up --- src/syn_net/models/mlp.py | 95 ++++++++++++++++++--------------------- 1 file changed, 44 insertions(+), 51 deletions(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 539c71f5..de6f3265 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -2,7 +2,6 @@ Multi-layer perceptron (MLP) class. """ import logging -import time import numpy as np import pytorch_lightning as pl @@ -18,19 +17,19 @@ class MLP(pl.LightningModule): def __init__( self, - input_dim=3072, - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - val_freq=10, - ncpu=16, + input_dim: int, + output_dim: int, + hidden_dim: int, + num_layers: int, + dropout: float, + num_dropout_layers: int = 1, + task: str = "classification", + loss: str = "cross_entropy", + valid_loss: str = "accuracy", + optimizer: str = "adam", + learning_rate: float = 1e-4, + val_freq: int = 10, + ncpu: int = 16, molembedder: MolEmbedder = None, ): super().__init__() @@ -78,44 +77,42 @@ def training_step(self, batch, batch_idx): elif self.loss == "huber": loss = F.huber_loss(y_hat, y) else: - raise ValueError("Not specified loss function: % s" % self.loss) + raise ValueError("Unsupported loss function '%s'" % self.loss) self.log(f"train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) return loss def validation_step(self, batch, batch_idx): """The complete validation loop.""" - if self.trainer.current_epoch % self.val_freq == 0: - x, y = batch - y_hat = self.layers(x) - if self.valid_loss == "cross_entropy": - loss = F.cross_entropy(y_hat, y.long()) - elif self.valid_loss == "accuracy": - y_hat = torch.argmax(y_hat, axis=1) - accuracy = (y_hat == y).sum() / len(y) - loss = 1 - accuracy - elif self.valid_loss[:11] == "nn_accuracy": - # NOTE: Very slow! - # Performing the knn-search can easily take a couple of minutes, - # even for small datasets. - kdtree = self.molembedder.kdtree - y = nn_search_list(y.detach().cpu().numpy(), None, kdtree) - y_hat = nn_search_list(y_hat.detach().cpu().numpy(), None, kdtree) - - accuracy = (y_hat == y).sum() / len(y) - loss = 1 - accuracy - elif self.valid_loss == "mse": - loss = F.mse_loss(y_hat, y) - elif self.valid_loss == "l1": - loss = F.l1_loss(y_hat, y) - elif self.valid_loss == "huber": - loss = F.huber_loss(y_hat, y) - else: - raise ValueError( - "Not specified validation loss function for '%s'" % self.valid_loss - ) - self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + if self.trainer.current_epoch % self.val_freq != 0: + return None + + x, y = batch + y_hat = self.layers(x) + if self.valid_loss == "cross_entropy": + loss = F.cross_entropy(y_hat, y.long()) + elif self.valid_loss == "accuracy": + y_hat = torch.argmax(y_hat, axis=1) + accuracy = (y_hat == y).sum() / len(y) + loss = 1 - accuracy + elif self.valid_loss[:11] == "nn_accuracy": + # NOTE: Very slow! + # Performing the knn-search can easily take a couple of minutes, + # even for small datasets. + kdtree = self.molembedder.kdtree + y = nn_search_list(y.detach().cpu().numpy(), kdtree) + y_hat = nn_search_list(y_hat.detach().cpu().numpy(), kdtree) + + accuracy = (y_hat == y).sum() / len(y) + loss = 1 - accuracy + elif self.valid_loss == "mse": + loss = F.mse_loss(y_hat, y) + elif self.valid_loss == "l1": + loss = F.l1_loss(y_hat, y) + elif self.valid_loss == "huber": + loss = F.huber_loss(y_hat, y) else: - pass + raise ValueError("Unsupported loss function '%s'" % self.valid_loss) + self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) def configure_optimizers(self): """Define Optimerzers and LR schedulers.""" @@ -131,11 +128,7 @@ def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train, num_workers=ncpu) -def cosine_distance(v1, v2, eps=1e-15): - return 1 - np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps) - - -def nn_search_list(y, out_feat, kdtree): +def nn_search_list(y, kdtree): y = np.atleast_2d(y) # (n_samples, n_features) ind = kdtree.query(y, k=1, return_distance=False) # (n_samples, 1) return ind From 05e7bd645c4514b5eab156de39147431ac14b649 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 09:08:18 -0400 Subject: [PATCH 268/302] clean up: softmax explicitly during inference only --- src/syn_net/models/act.py | 2 +- src/syn_net/models/mlp.py | 7 ++++--- src/syn_net/models/rxn.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/syn_net/models/act.py b/src/syn_net/models/act.py index 02a54414..ffd56a4b 100644 --- a/src/syn_net/models/act.py +++ b/src/syn_net/models/act.py @@ -63,7 +63,7 @@ num_layers=5, dropout=0.5, num_dropout_layers=1, - task="classification-w/o-softmax", + task="classification", loss="cross_entropy", valid_loss="accuracy", optimizer="adam", diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index de6f3265..48da1c1e 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -55,14 +55,15 @@ def __init__( modules.append(nn.Dropout(dropout)) modules.append(nn.Linear(hidden_dim, output_dim)) - if task == "classification": - modules.append(nn.Softmax(dim=1)) self.layers = nn.Sequential(*modules) def forward(self, x): """Forward step for inference only.""" - return self.layers(x) + y_hat = self.layers(x) + if self.task == "classification": # during training, `cross_entropy` loss expexts raw logits + y_hat = F.softmax(y_hat,dim=-1) + return y_hat def training_step(self, batch, batch_idx): """The complete training loop.""" diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index a5771b67..9dffbd39 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -94,7 +94,7 @@ num_layers=5, dropout=0.5, num_dropout_layers=1, - task="classification-w/o-softmax", + task="classification", loss="cross_entropy", valid_loss="accuracy", optimizer="adam", From 29ebd8e484b31ed9a4bec8671a6882ed6826a841 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 09:18:24 -0400 Subject: [PATCH 269/302] clean up resuming training from ckpt --- src/syn_net/models/rxn.py | 58 +++++++++++++-------------------------- 1 file changed, 19 insertions(+), 39 deletions(-) diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 9dffbd39..d6ca1426 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -50,11 +50,6 @@ shuffle=True if dataset == "train" else False, ) logger.info(f"Set up dataloaders.") - param_path = ( - Path(CHECKPOINTS_DIR) - / f"{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/" - ) - path_to_rxn = f"{param_path}rxn.ckpt" INPUT_DIMS = { "fp": { @@ -86,39 +81,23 @@ } output_dim = OUTPUT_DIMS[args.rxn_template] - if not args.restart: - mlp = MLP( - input_dim=input_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=3e-4, - val_freq=10, - ncpu=args.ncpu, - ) - else: # load from checkpt -> only for fp, not gin - # TODO: Use `ckpt_path`, c.f. https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html#pytorch_lightning.trainer.trainer.Trainer.fit - mlp = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=input_dim, - output_dim=output_dim, - hidden_dim=hidden_dim, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=args.ncpu, - ) + path_to_rxn = "placeholder-path-for-checkpoint-for-resuming-training" + ckpt_path = path_to_rxn if args.restart else None # TODO: Unify for all networks + mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=3e-4, + val_freq=10, + ncpu=args.ncpu, + ) # Set up Trainer save_dir = Path("results/logs/") / MODEL_ID @@ -147,8 +126,9 @@ callbacks=[checkpoint_callback, tqdm_callback], logger=[tb_logger, csv_logger], fast_dev_run=args.fast_dev_run, + ) logger.info(f"Start training") - trainer.fit(mlp, train_dataloader, valid_dataloader) + trainer.fit(mlp, train_dataloader, valid_dataloader,ckpt_path=ckpt_path) logger.info(f"Training completed.") From 2e39916ce51d090796b7f95c5cda3e34ee784680 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 09:56:44 -0400 Subject: [PATCH 270/302] clean up --- src/syn_net/models/common.py | 13 ++++--------- src/syn_net/models/rxn.py | 3 +-- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/syn_net/models/common.py b/src/syn_net/models/common.py index 7fb1265a..8cd2848a 100644 --- a/src/syn_net/models/common.py +++ b/src/syn_net/models/common.py @@ -1,20 +1,12 @@ """Common methods and params shared by all models. """ -# Helper to select validation func based on output dim from typing import Union import numpy as np import torch from scipy import sparse -VALIDATION_OPTS = { - 300: "nn_accuracy_gin", - 4096: "nn_accuracy_fp_4096", - 256: "nn_accuracy_fp_256", - 200: "nn_accuracy_rdkit2d", -} - def get_args(): import argparse @@ -38,7 +30,10 @@ def get_args(): parser.add_argument("--batch_size", type=int, default=64, help="Batch size") parser.add_argument("--epoch", type=int, default=2000, help="Maximum number of epoches.") parser.add_argument( - "--restart", type=bool, default=False, help="Indicates whether to restart training." + "--ckpt-file", + type=str, + default=None, + help="Checkpoint file. If provided, load and resume training.", ) parser.add_argument("-v", "--version", type=int, default=1, help="Version") parser.add_argument("--debug", default=False, action="store_true") diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index d6ca1426..32a74ace 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -81,8 +81,7 @@ } output_dim = OUTPUT_DIMS[args.rxn_template] - path_to_rxn = "placeholder-path-for-checkpoint-for-resuming-training" - ckpt_path = path_to_rxn if args.restart else None # TODO: Unify for all networks + ckpt_path = args.ckpt_file # TODO: Unify for all networks mlp = MLP( input_dim=input_dim, output_dim=output_dim, From b7ef8683233d72743c0bbb6f12a7384a361f4121 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 5 Oct 2022 13:25:30 -0400 Subject: [PATCH 271/302] use hparams for attrs to fix loading from ckpts --- src/syn_net/models/mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 48da1c1e..509d56e1 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -61,7 +61,7 @@ def __init__( def forward(self, x): """Forward step for inference only.""" y_hat = self.layers(x) - if self.task == "classification": # during training, `cross_entropy` loss expexts raw logits + if self.hparams.task == "classification": # during training, `cross_entropy` loss expexts raw logits y_hat = F.softmax(y_hat,dim=-1) return y_hat From a0f432e8a52d3406540f7c3a3b768faec59e4763 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:48:29 -0400 Subject: [PATCH 272/302] clean up and remove hard-coded paths --- scripts/20-predict-targets.py | 122 +++++++++++++++------------------- 1 file changed, 54 insertions(+), 68 deletions(-) diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index 2ddaeb98..8647e798 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -13,9 +13,9 @@ import numpy as np import pandas as pd -from syn_net.config import CHECKPOINTS_DIR, DATA_PREPROCESS_DIR, DATA_RESULT_DIR +from syn_net.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -from syn_net.models.chkpt_loader import load_modules_from_checkpoint +from syn_net.models.chkpt_loader import load_mlp_from_ckpt from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search @@ -74,31 +74,24 @@ def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils def _load_pretrained_model(path_to_checkpoints: list[Path]): """Wrapper to load modules from checkpoint.""" # Define paths to pretrained models. - path_to_act, path_to_rt1, path_to_rxn, path_to_rt2 = path_to_checkpoints + act_path, rt1_path, rxn_path, rt2_path = path_to_checkpoints # Load the pre-trained models. - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=args.ncpu, - ) + act_net = load_mlp_from_ckpt(act_path) + rt1_net = load_mlp_from_ckpt(rt1_path) + rxn_net = load_mlp_from_ckpt(rxn_path) + rt2_net = load_mlp_from_ckpt(rt2_path) return act_net, rt1_net, rxn_net, rt2_net -def func(smiles: str) -> Tuple[str, float, SyntheticTree]: +def wrapper_decoder(smiles: str) -> Tuple[str, float, SyntheticTree]: """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) try: smi, similarity, tree, action = synthetic_tree_decoder_greedy_search( z_target=emb, - building_blocks=building_blocks, - bb_dict=building_blocks_dict, + building_blocks=bblocks, + bb_dict=bblocks_dict, reaction_templates=rxns, mol_embedder=bblocks_molembedder.kdtree, # TODO: fix this, currently misused action_net=act_net, @@ -106,8 +99,8 @@ def func(smiles: str) -> Tuple[str, float, SyntheticTree]: rxn_net=rxn_net, reactant2_net=rt2_net, bb_emb=bb_emb, - rxn_template=args.rxn_template, - n_bits=nbits, + rxn_template="hb", # TODO: Do not hard code + n_bits=4096, # TODO: Do not hard code beam_width=3, max_step=15, ) @@ -127,44 +120,46 @@ def get_args(): import argparse parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) parser.add_argument( - "-f", "--featurize", type=str, default="fp", help="Choose from ['fp', 'gin']" + "--rxns-collection-file", + type=str, + help="Input file for the collection of reactions matched with building-blocks.", + ) + parser.add_argument( + "--embeddings-knn-file", + type=str, + help="Input file for the pre-computed embeddings (*.npy).", ) - parser.add_argument("--radius", type=int, default=2, help="Radius for Morgan Fingerprint") parser.add_argument( - "-b", "--nbits", type=int, default=4096, help="Number of Bits for Morgan Fingerprint" + "--ckpt-dir", type=str, help="Directory with checkpoints for {act,rt1,rxn,rt2}-model." ) parser.add_argument( - "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']" + "--output-dir", type=str, default=DATA_RESULT_DIR, help="Directory to save output." ) - parser.add_argument("--ncpu", type=int, default=1, help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=-1, help="Number of molecules to predict.") + # Parameters + parser.add_argument("--num", type=int, default=-1, help="Number of molecules to predict.") parser.add_argument( - "-d", "--data", type=str, default="test", help="Choose from ['train', 'valid', 'test', 'chembl'] or provide a file with one SMILES per line.", ) - parser.add_argument( - "-o", - "--outputembedding", - type=str, - default="fp_256", - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']", - ) - parser.add_argument("--output-dir", type=str, default=None, help="Directory to save output.") + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") return parser.parse_args() if __name__ == "__main__": args = get_args() - logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - - nbits = args.nbits - out_dim = args.outputembedding.split("_")[-1] # <=> morgan fingerprint with 256 bits - param_dir = f"{args.rxn_template}_{args.featurize}_{args.radius}_{nbits}_{out_dim}" + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") # Load data ... logger.info("Start loading data...") # ... query molecules (i.e. molecules to decode) @@ -173,56 +168,45 @@ def get_args(): smiles_queries = smiles_queries[: args.num] # ... building blocks - file = ( - Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / f"enamine-us-smiles.csv.gz" - ) # TODO: Do not hardcode - building_blocks = BuildingBlockFileHandler().load(file) - building_blocks_dict = { - block: i for i, block in enumerate(building_blocks) + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + bblocks_dict = { + block: i for i, block in enumerate(bblocks) } # dict is used as lookup table for 2nd reactant during inference - logger.info("...loading building blocks completed.") + logger.info(f"Successfully read {args.building_blocks_file}.") # ... reaction templates - file = ( - Path(DATA_PREPROCESS_DIR) / "building-blocks-rxns" / "hb-enamine-us.json.gz" - ) # TODO: Do not hardcode - rxns = ReactionSet().load(file).rxns - logger.info("...loading reaction collection completed.") + rxns = ReactionSet().load(args.output_rxns_collection_file).rxns + logger.info(f"Successfully read {args.output_rxns_collection_file}.") # ... building block embedding - file = ( - Path(DATA_PREPROCESS_DIR) / "embeddings" / f"hb-enamine-embeddings.npy" - ) # TODO: Do not hardcode - bblocks_molembedder = MolEmbedder().load_precomputed(file).init_balltree(cosine_distance) + bblocks_molembedder = ( + MolEmbedder().load_precomputed(args.embeddings_knn_file).init_balltree(cosine_distance) + ) bb_emb = bblocks_molembedder.get_embeddings() - - logger.info("...loading building block embeddings completed.") + logger.info(f"Successfully read {args.embeddings_knn_file} and initialized BallTree.") logger.info("...loading data completed.") # ... models logger.info("Start loading models from checkpoints...") - path = Path(CHECKPOINTS_DIR) / f"{param_dir}" - paths = [ - find_best_model_ckpt("results/logs/hb_fp_2_4096/" + model) # TODO: Do not hardcode - for model in "act rt1 rxn rt2".split() - ] + path = Path(args.ckpt_dir) + paths = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) logger.info("...loading models completed.") # Decode queries, i.e. the target molecules. logger.info(f"Start to decode {len(smiles_queries)} target molecules.") if args.ncpu == 1: - results = [func(smi) for smi in smiles_queries] + results = [wrapper_decoder(smi) for smi in smiles_queries] else: with mp.Pool(processes=args.ncpu) as pool: logger.info(f"Starting MP with ncpu={args.ncpu}") - results = pool.map(func, smiles_queries) + results = pool.map(wrapper_decoder, smiles_queries) logger.info("Finished decoding.") # Print some results from the prediction - smis_decoded = [r[0] for r in results] - similarities = [r[1] for r in results] - trees = [r[2] for r in results] + smis_decoded = [smi for smi, _, tree in results if tree is not None] + similarities = [sim for _, sim, tree in results if tree is not None] + trees = [tree for _, _, tree in results if tree is not None] recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) avg_similarity = np.mean(similarities) @@ -233,7 +217,9 @@ def get_args(): logger.info(f" {avg_similarity=}") # Save to local dir - output_dir = DATA_RESULT_DIR if args.output_dir is None else args.output_dir + # 1. Dataframe with targets, decoded, smilarities + # 2. Synthetic trees of the decoded SMILES + output_dir = Path(args.output_dir) logger.info(f"Saving results to {output_dir} ...") df = pd.DataFrame( {"query SMILES": smiles_queries, "decode SMILES": smis_decoded, "similarity": similarities} From ddfc0a06a4c7ebed2134945aca4822cf379aed67 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:53:51 -0400 Subject: [PATCH 273/302] update instructions --- INSTRUCTIONS.md | 3 +-- scripts/01-filter-building-blocks.py | 4 ++-- scripts/02-compute-embeddings.py | 5 ----- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 40eb606e..3e477099 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -37,7 +37,7 @@ Let's start. --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ --output-bblock-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ - --output-rxns-file "data/pre-process/building-blocks-rxns/rxns-hb-enamine-us.json.gz" --verbose + --output-rxns-collection-file "data/pre-process/building-blocks-rxns/rxns-hb-enamine-us.json.gz" --verbose ``` > :bulb: All following steps use this matched building blocks <-> reaction template data. You have to specify the correct files for every script to that it can load the right data. It can save some time to store these as environment variables. @@ -50,7 +50,6 @@ Let's start. ```bash python scripts/02-compute-embeddings.py \ --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ - --rxn-templates-file "data/assets/reaction-templates/hb.txt" --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" ``` diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py index 8068b96c..7fe73cc1 100644 --- a/scripts/01-filter-building-blocks.py +++ b/scripts/01-filter-building-blocks.py @@ -38,7 +38,7 @@ def get_args(): help="Output file for the filtered building-blocks.", ) parser.add_argument( - "--output-rxns-file", + "--output-rxns-collection-file", type=str, help="Output file for the collection of reactions matched with building-blocks.", ) @@ -74,7 +74,7 @@ def get_args(): # Save collection of reactions which have "available reactants" set (for convenience) rxn_collection = ReactionSet(bbf.rxns) - rxn_collection.save(args.output_rxns_file) + rxn_collection.save(args.output_rxns_collection_file) logger.info(f"Total number of building blocks {len(bblocks):d}") logger.info(f"Matched number of building blocks {len(bblocks_filtered):d}") diff --git a/scripts/02-compute-embeddings.py b/scripts/02-compute-embeddings.py index e274800e..740a6e3a 100644 --- a/scripts/02-compute-embeddings.py +++ b/scripts/02-compute-embeddings.py @@ -35,11 +35,6 @@ def get_args(): type=str, help="Input file with SMILES strings (First row `SMILES`, then one per line).", ) - parser.add_argument( - "--rxn-templates-file", - type=str, - help="Input file with reaction templates as SMARTS (No header, one per line).", - ) parser.add_argument( "--output-file", type=str, From dc0a6e0a04bee64de34882c0454b3091221a8229 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:54:02 -0400 Subject: [PATCH 274/302] clean up loading from ckpt --- scripts/_mp_decode.py | 30 +-- src/syn_net/models/chkpt_loader.py | 295 +---------------------------- 2 files changed, 14 insertions(+), 311 deletions(-) diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py index d8bdf150..9fb3cd6b 100644 --- a/scripts/_mp_decode.py +++ b/scripts/_mp_decode.py @@ -7,11 +7,8 @@ from dgllife.model import load_pretrained from syn_net.utils.data_utils import ReactionSet -from syn_net.utils.predict_utils import ( - load_modules_from_checkpoint, - synthetic_tree_decoder, - tanimoto_similarity, -) +from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity +from syn_net.models.chkpt_loader import load_mlp_from_ckpt # define some constants (here, for the Hartenfeller-Button test set) nbits = 4096 @@ -40,10 +37,10 @@ # define paths to pretrained modules param_path = f"/home/whgao/synth_net/synth_net/params/{param_dir}/" -path_to_act = f"{param_path}act.ckpt" -path_to_rt1 = f"{param_path}rt1.ckpt" -path_to_rxn = f"{param_path}rxn.ckpt" -path_to_rt2 = f"{param_path}rt2.ckpt" +act_path = f"{param_path}act.ckpt" +rt1_path = f"{param_path}rt1.ckpt" +rxn_path = f"{param_path}rxn.ckpt" +rt2_path = f"{param_path}rt2.ckpt" # load the purchasable building block SMILES to a dictionary building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")["SMILES"].tolist() @@ -54,17 +51,10 @@ rxns = rxn_set.rxns # load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) +act_net = load_mlp_from_ckpt(act_path) +rt1_net = load_mlp_from_ckpt(rt1_path) +rxn_net = load_mlp_from_ckpt(rxn_path) +rt2_net = load_mlp_from_ckpt(rt2_path) def func(emb): diff --git a/src/syn_net/models/chkpt_loader.py b/src/syn_net/models/chkpt_loader.py index eea75b76..79784b70 100644 --- a/src/syn_net/models/chkpt_loader.py +++ b/src/syn_net/models/chkpt_loader.py @@ -1,294 +1,7 @@ -from typing import List, Tuple - -import pytorch_lightning as pl - from syn_net.models.mlp import MLP -def load_modules_from_checkpoint( - path_to_act: str, - path_to_rt1: str, - path_to_rxn: str, - path_to_rt2: str, - featurize: str, - rxn_template: str, - out_dim: int, - nbits: int, - ncpu: int, -) -> List[pl.LightningModule]: - - if rxn_template == "unittest": - - act_net = MLP.load_from_checkpoint( - path_to_act, - input_dim=int(3 * nbits), - output_dim=4, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt1_net = MLP.load_from_checkpoint( - path_to_rt1, - input_dim=int(3 * nbits), - output_dim=out_dim, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rxn_net = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(4 * nbits), - output_dim=3, - hidden_dim=100, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt2_net = MLP.load_from_checkpoint( - path_to_rt2, - input_dim=int(4 * nbits + 3), - output_dim=out_dim, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - elif featurize == "fp": - - act_net = MLP.load_from_checkpoint( - path_to_act, - input_dim=int(3 * nbits), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt1_net = MLP.load_from_checkpoint( - path_to_rt1, - input_dim=int(3 * nbits), - output_dim=int(out_dim), - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - if rxn_template == "hb": - - rxn_net = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(4 * nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt2_net = MLP.load_from_checkpoint( - path_to_rt2, - input_dim=int(4 * nbits + 91), - output_dim=int(out_dim), - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - elif rxn_template == "pis": - - rxn_net = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(4 * nbits), - output_dim=4700, - hidden_dim=4500, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt2_net = MLP.load_from_checkpoint( - path_to_rt2, - input_dim=int(4 * nbits + 4700), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - elif featurize == "gin": - - act_net = MLP.load_from_checkpoint( - path_to_act, - input_dim=int(2 * nbits + out_dim), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt1_net = MLP.load_from_checkpoint( - path_to_rt1, - input_dim=int(2 * nbits + out_dim), - output_dim=out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - if rxn_template == "hb": - - rxn_net = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(3 * nbits + out_dim), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt2_net = MLP.load_from_checkpoint( - path_to_rt2, - input_dim=int(3 * nbits + out_dim + 91), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - elif rxn_template == "pis": - - rxn_net = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(3 * nbits + out_dim), - output_dim=4700, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="classification", - loss="cross_entropy", - valid_loss="accuracy", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - rt2_net = MLP.load_from_checkpoint( - path_to_rt2, - input_dim=int(3 * nbits + out_dim + 4700), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - - act_net.eval() - rt1_net.eval() - rxn_net.eval() - rt2_net.eval() - - return act_net, rt1_net, rxn_net, rt2_net +def load_mlp_from_ckpt(ckpt_file: str): + """Load a model from a checkpoint for inference.""" + model = MLP.load_from_checkpoint(ckpt_file) + return model.eval() From b1f5b4d6875bd4b5da00db25251781bb0664456d Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:57:01 -0400 Subject: [PATCH 275/302] move fct into `mlp.py` --- scripts/20-predict-targets.py | 2 +- scripts/_mp_decode.py | 2 +- src/syn_net/models/chkpt_loader.py | 7 ------- src/syn_net/models/mlp.py | 10 ++++++++-- 4 files changed, 10 insertions(+), 11 deletions(-) delete mode 100644 src/syn_net/models/chkpt_loader.py diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index 8647e798..823b036c 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -15,7 +15,7 @@ from syn_net.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -from syn_net.models.chkpt_loader import load_mlp_from_ckpt +from syn_net.models.mlp import load_mlp_from_ckpt from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py index 9fb3cd6b..ae82b327 100644 --- a/scripts/_mp_decode.py +++ b/scripts/_mp_decode.py @@ -8,7 +8,7 @@ from syn_net.utils.data_utils import ReactionSet from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity -from syn_net.models.chkpt_loader import load_mlp_from_ckpt +from syn_net.models.mlp import load_mlp_from_ckpt # define some constants (here, for the Hartenfeller-Button test set) nbits = 4096 diff --git a/src/syn_net/models/chkpt_loader.py b/src/syn_net/models/chkpt_loader.py deleted file mode 100644 index 79784b70..00000000 --- a/src/syn_net/models/chkpt_loader.py +++ /dev/null @@ -1,7 +0,0 @@ -from syn_net.models.mlp import MLP - - -def load_mlp_from_ckpt(ckpt_file: str): - """Load a model from a checkpoint for inference.""" - model = MLP.load_from_checkpoint(ckpt_file) - return model.eval() diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index 509d56e1..fee4aa1d 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -61,8 +61,10 @@ def __init__( def forward(self, x): """Forward step for inference only.""" y_hat = self.layers(x) - if self.hparams.task == "classification": # during training, `cross_entropy` loss expexts raw logits - y_hat = F.softmax(y_hat,dim=-1) + if ( + self.hparams.task == "classification" + ): # during training, `cross_entropy` loss expects raw logits + y_hat = F.softmax(y_hat, dim=-1) return y_hat def training_step(self, batch, batch_idx): @@ -134,6 +136,10 @@ def nn_search_list(y, kdtree): ind = kdtree.query(y, k=1, return_distance=False) # (n_samples, 1) return ind +def load_mlp_from_ckpt(ckpt_file: str): + """Load a model from a checkpoint for inference.""" + model = MLP.load_from_checkpoint(ckpt_file) + return model.eval() if __name__ == "__main__": pass From a9a240c067aabce7ee3d388cec20c4ccd26c060e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:29:11 -0400 Subject: [PATCH 276/302] avoid python loop for ~20x speedup --- scripts/21-identify-similar-fps.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/scripts/21-identify-similar-fps.py b/scripts/21-identify-similar-fps.py index f661397a..6aeb6977 100644 --- a/scripts/21-identify-similar-fps.py +++ b/scripts/21-identify-similar-fps.py @@ -57,12 +57,7 @@ def find_similar_fp(fp: np.ndarray, fps_reference: np.ndarray): """Finds most similar fingerprint in a reference set for `fp`. Uses Tanimoto Similarity. """ - dists = np.array( - [ - DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) - for fp_ in fps_train - ] - ) + dists = np.asarray(DataStructs.BulkTanimotoSimilarity(fp, fps_reference)) similarity_score, idx = dists.max(), dists.argmax() return similarity_score, idx From 11a00a08cb1e5ca5d1e6e8c2e8c3d763491620e5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Fri, 7 Oct 2022 07:20:10 -0400 Subject: [PATCH 277/302] clean up code --- scripts/21-identify-similar-fps.py | 49 +++++++++++++++++------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/scripts/21-identify-similar-fps.py b/scripts/21-identify-similar-fps.py index 6aeb6977..3d55a8bf 100644 --- a/scripts/21-identify-similar-fps.py +++ b/scripts/21-identify-similar-fps.py @@ -1,6 +1,4 @@ -""" -Computes the fingerprint similarity of molecules in the validation and test set to -molecules in the training set. +"""Computes the fingerprint similarity of molecules in {valid,test}-set to molecules in the training set. """ # TODO: clean up, un-nest a couple of fcts import json import logging @@ -35,7 +33,7 @@ def get_args(): "--output-file", type=str, default=None, - help="Optional: File to save similarity-values for test,valid-synthetic trees.", + help="File to save similarity-values for test,valid-synthetic trees. (*csv.gz)", ) # Processing parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") @@ -77,24 +75,28 @@ def get_smiles_and_fps(dataset: str) -> Tuple[list[str], list[np.ndarray]]: return smiles, fps -def _save_df(file: str, df): - if file is None: - return - df.to_csv(file, index=False) +def compute_most_similar_smiles( + split: str, + fps: np.ndarray, + smiles: list[str], + /, + fps_reference: np.ndarray, + smiles_reference: list[str], +) -> pd.DataFrame: - -def compute_most_similar_smiles(split: str, fps: np.ndarray, smiles: list[str]) -> pd.DataFrame: + func = partial(find_similar_fp, fps_reference=fps_reference) with mp.Pool(processes=args.ncpu) as pool: results = pool.map(func, fps) - similarities, idx = np.asfarray(results).T - most_similiar_ref_smiles = np.asarray(smiles_train)[idx.astype(int)] # use numpy for slicin' + similarities, idx = zip(*results) + most_similiar_ref_smiles = np.asarray(smiles_reference)[np.asarray(idx, dtype=int)] + # ^ Use numpy for slicing... df = pd.DataFrame( { - "smiles": smiles, "split": split, - "most similar": most_similiar_ref_smiles, + "smiles": smiles, + "most_similar_smiles": most_similiar_ref_smiles, "similarity": similarities, } ) @@ -107,20 +109,25 @@ def compute_most_similar_smiles(split: str, fps: np.ndarray, smiles: list[str]) # Parse input args args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - args.input_dir = "/home/ulmer/SynNet/data/pre-process/syntrees" - # Load data smiles_train, fps_train = get_smiles_and_fps("train") smiles_valid, fps_valid = get_smiles_and_fps("valid") smiles_test, fps_test = get_smiles_and_fps("test") # Compute (mp) - func = partial(find_similar_fp, fps_reference=fps_train) - df_valid = compute_most_similar_smiles("valid", fps_valid, smiles_valid) - df_test = compute_most_similar_smiles("test", fps_test, smiles_test) + logger.info("Start computing most similar smiles...") + df_valid = compute_most_similar_smiles( + "valid", fps_valid, smiles_valid, fps_reference=fps_train, smiles_reference=smiles_train + ) + df_test = compute_most_similar_smiles( + "test", fps_test, smiles_test, fps_reference=fps_train, smiles_reference=smiles_train + ) + logger.info("Computed most similar smiles for {valid,test}-set.") # Save - outfile = "data_similarity.csv" - _save_df(outfile, pd.concat([df_valid, df_test], axis=0, ignore_index=True)) + Path(args.output_file).parent.mkdir(parents=True, exist_ok=True) + df = pd.concat([df_valid, df_test], axis=0, ignore_index=True) + df.to_csv(args.output_file, index=False, compression="gzip") + logger.info(f"Successfully saved output to {args.output_file}.") logger.info("Completed.") From dcae32c062bde179337199a44bd589e869b0d624 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Sat, 8 Oct 2022 08:18:28 -0400 Subject: [PATCH 278/302] deprecate `load_array()` --- scripts/22-compute-mrr.py | 17 ++++++++++------- src/syn_net/models/mlp.py | 5 ----- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py index be2ab55a..2f4c5621 100644 --- a/scripts/22-compute-mrr.py +++ b/scripts/22-compute-mrr.py @@ -6,6 +6,7 @@ import torch from scipy import sparse from sklearn.neighbors import BallTree +from syn_net.models.common import xy_to_dataloader from syn_net.encoding.distances import ce_distance, cosine_distance from syn_net.models.mlp import MLP, load_array @@ -58,13 +59,15 @@ def get_args(): batch_size = args.batch_size ncpu = args.ncpu - X = sparse.load_npz(args.X_data_file) - y = sparse.load_npz(args.y_data_file) - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0] / 10), replace=False) - test_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - data_iter = test_data_iter + # Load data + dataloader = xy_to_dataloader( + X_file = args.X_data_file, + y_file = args.y_data_file, + n=None if not args.debug else 128, + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=False, + ) rt1_net = MLP.load_from_checkpoint( path_to_rt1, diff --git a/src/syn_net/models/mlp.py b/src/syn_net/models/mlp.py index fee4aa1d..d29ba3dc 100644 --- a/src/syn_net/models/mlp.py +++ b/src/syn_net/models/mlp.py @@ -126,11 +126,6 @@ def configure_optimizers(self): return optimizer -def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): - dataset = torch.utils.data.TensorDataset(*data_arrays) - return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train, num_workers=ncpu) - - def nn_search_list(y, kdtree): y = np.atleast_2d(y) # (n_samples, n_features) ind = kdtree.query(y, k=1, return_distance=False) # (n_samples, 1) From ffe21dc08e1661a10fa0e509c15b02a206f1b179 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Sat, 8 Oct 2022 09:03:41 -0400 Subject: [PATCH 279/302] clean up --- scripts/22-compute-mrr.py | 109 ++++++++++++++++---------------------- 1 file changed, 45 insertions(+), 64 deletions(-) diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py index 2f4c5621..1ca6a53b 100644 --- a/scripts/22-compute-mrr.py +++ b/scripts/22-compute-mrr.py @@ -1,15 +1,19 @@ -""" -This function is used to compute the mean reciprocal ranking for reactant 1 +"""Compute the mean reciprocal ranking for reactant 1 selection using the different distance metrics in the k-NN search. """ +import json +import logging + import numpy as np -import torch -from scipy import sparse -from sklearn.neighbors import BallTree -from syn_net.models.common import xy_to_dataloader +from tqdm import tqdm +from syn_net.config import MAX_PROCESSES from syn_net.encoding.distances import ce_distance, cosine_distance -from syn_net.models.mlp import MLP, load_array +from syn_net.models.common import xy_to_dataloader +from syn_net.models.mlp import load_mlp_from_ckpt +from syn_net.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__name__) def get_args(): @@ -27,7 +31,7 @@ def get_args(): parser.add_argument( "--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint." ) - parser.add_argument("--ncpu", type=int, default=8, help="Number of cpus") + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") parser.add_argument("--batch_size", type=int, default=64, help="Batch size") parser.add_argument("--device", type=str, default="cuda:0", help="") parser.add_argument( @@ -37,87 +41,64 @@ def get_args(): choices=["euclidean", "manhattan", "chebyshev", "cross_entropy", "cosine"], help="Distance function for `BallTree`.", ) + parser.add_argument("--debug", default=False, action="store_true") return parser.parse_args() if __name__ == "__main__": + logger.info("Start.") + # Parse input args args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") - bb_emb_fp_256 = np.load(args.embeddings_file) - n, d = bb_emb_fp_256.shape - - metric = args.distance - if metric == "cross_entropy": + # Init BallTree for kNN-search + if args.distance == "cross_entropy": metric = ce_distance - elif metric == "cosine": + elif args.distance == "cosine": metric = cosine_distance + else: + metric = args.distance - kdtree_fp_256 = BallTree(bb_emb_fp_256, metric=metric) - - path_to_rt1 = args.ckpt_file - batch_size = args.batch_size - ncpu = args.ncpu + # Recall default: Morgan fingerprint with radius=2, nbits=256 + mol_embedder = MolEmbedder().load_precomputed(args.embeddings_file) + mol_embedder.init_balltree(metric=metric) + n, d = mol_embedder.embeddings.shape # Load data dataloader = xy_to_dataloader( - X_file = args.X_data_file, - y_file = args.y_data_file, + X_file=args.X_data_file, + y_file=args.y_data_file, n=None if not args.debug else 128, batch_size=args.batch_size, num_workers=args.ncpu, shuffle=False, ) - rt1_net = MLP.load_from_checkpoint( - path_to_rt1, - input_dim=int(3 * args.nbits), - output_dim=d, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task="regression", - loss="mse", - valid_loss="mse", - optimizer="adam", - learning_rate=1e-4, - ncpu=ncpu, - ) - rt1_net.eval() + # Load MLP + rt1_net = load_mlp_from_ckpt(args.ckpt_file) rt1_net.to(args.device) ranks = [] - for X, y in data_iter: + for X, y in tqdm(dataloader): X, y = X.to(args.device), y.to(args.device) - y_hat = rt1_net(X) - dist_true, ind_true = kdtree_fp_256.query(y.detach().cpu().numpy(), k=1) - dist, ind = kdtree_fp_256.query(y_hat.detach().cpu().numpy(), k=n) - ranks = ranks + [np.where(ind[i] == ind_true[i])[0][0] for i in range(len(ind_true))] + y_hat = rt1_net(X) # (batch_size,nbits) - ranks = np.array(ranks) - rrs = 1 / (ranks + 1) + ind_true = mol_embedder.kdtree.query(y.detach().cpu().numpy(), k=1, return_distance=False) + ind = mol_embedder.kdtree.query(y_hat.detach().cpu().numpy(), k=n, return_distance=False) - np.save("ranks_" + metric + ".npy", ranks) # TODO: do not hard code + irows, icols = np.nonzero(ind == ind_true) # irows = range(batch_size), icols = ranks + ranks.append(icols) + + ranks = np.asarray(ranks, dtype=int).flatten() # (nSamples,) + rrs = 1 / (ranks + 1) # +1 for offset 0-based indexing + + # np.save("ranks_" + metric + ".npy", ranks) # TODO: do not hard code print(f"Result using metric: {metric}") print(f"The mean reciprocal ranking is: {rrs.mean():.3f}") - print( - f"The Top-1 recovery rate is: {sum(ranks < 1) / len(ranks) :.3f}, {sum(ranks < 1)} / {len(ranks)}" - ) - print( - f"The Top-3 recovery rate is: {sum(ranks < 3) / len(ranks) :.3f}, {sum(ranks < 3)} / {len(ranks)}" - ) - print( - f"The Top-5 recovery rate is: {sum(ranks < 5) / len(ranks) :.3f}, {sum(ranks < 5)} / {len(ranks)}" - ) - print( - f"The Top-10 recovery rate is: {sum(ranks < 10) / len(ranks) :.3f}, {sum(ranks < 10)} / {len(ranks)}" - ) - print( - f"The Top-15 recovery rate is: {sum(ranks < 15) / len(ranks) :.3f}, {sum(ranks < 15)} / {len(ranks)}" - ) - print( - f"The Top-30 recovery rate is: {sum(ranks < 30) / len(ranks) :.3f}, {sum(ranks < 30)} / {len(ranks)}" - ) - print() + TOP_N_RANKS = (1, 3, 5, 10, 15, 30) + for i in TOP_N_RANKS: + n_recovered = sum(ranks < i) + n = len(ranks) + print(f"The Top-{i:<2d} recovery rate is: {n_recovered/n:.3f} ({n_recovered}/{n})") From 459181790ce0ff007c69ed2223035a565dba5a98 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Sun, 9 Oct 2022 06:16:01 -0400 Subject: [PATCH 280/302] clean up some paths --- scripts/20-predict-targets.py | 68 ++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index 823b036c..0753331c 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -7,20 +7,18 @@ from pathlib import Path from typing import Tuple, Union -from syn_net.encoding.distances import cosine_distance - -logger = logging.getLogger(__name__) import numpy as np import pandas as pd from syn_net.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES from syn_net.data_generation.preprocessing import BuildingBlockFileHandler +from syn_net.encoding.distances import cosine_distance from syn_net.models.mlp import load_mlp_from_ckpt +from syn_net.MolEmbedder import MolEmbedder from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search -Path(DATA_RESULT_DIR).mkdir(exist_ok=True) -from syn_net.MolEmbedder import MolEmbedder +logger = logging.getLogger(__name__) def _fetch_data_chembl(name: str) -> list[str]: @@ -42,14 +40,13 @@ def _fetch_data(name: str) -> list[str]: Path(DATA_PREPROCESS_DIR) / "syntrees" / f"synthetic-trees-filtered-{args.data}.json.gz" ) logger.info(f"Reading data from {file}") - sts = SyntheticTreeSet() - sts.load(file) - smis_query = [st.root.smiles for st in sts.sts] + syntree_collection = SyntheticTreeSet().load(file) + smiles = [syntree.root.smiles for syntree in syntree_collection] elif args.data in ["chembl"]: - smis_query = _fetch_data_chembl(name) + smiles = _fetch_data_chembl(name) else: # Hopefully got a filename instead - smis_query = _fetch_data_from_file(name) - return smis_query + smiles = _fetch_data_from_file(name) + return smiles def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils.py @@ -61,7 +58,7 @@ def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils """ ckpts = Path(path).rglob("*.ckpt") best_model_ckpt = None - lowest_loss = 10_000 + lowest_loss = 10_000 # ~ math.inf for file in ckpts: stem = file.stem val_loss = float(stem.split("val_loss=")[-1]) @@ -157,26 +154,28 @@ def get_args(): if __name__ == "__main__": - args = get_args() + logger.info("Start.") + # Parse input args + args = get_args() logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + # Load data ... logger.info("Start loading data...") # ... query molecules (i.e. molecules to decode) - smiles_queries = _fetch_data(args.data) + targets = _fetch_data(args.data) if args.num > 0: # Select only n queries - smiles_queries = smiles_queries[: args.num] + targets = targets[: args.num] # ... building blocks bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) - bblocks_dict = { - block: i for i, block in enumerate(bblocks) - } # dict is used as lookup table for 2nd reactant during inference + # A dict is used as lookup table for 2nd reactant during inference: + bblocks_dict = {block: i for i, block in enumerate(bblocks)} logger.info(f"Successfully read {args.building_blocks_file}.") # ... reaction templates - rxns = ReactionSet().load(args.output_rxns_collection_file).rxns - logger.info(f"Successfully read {args.output_rxns_collection_file}.") + rxns = ReactionSet().load(args.rxns_collection_file).rxns + logger.info(f"Successfully read {args.rxns_collection_file}.") # ... building block embedding bblocks_molembedder = ( @@ -194,25 +193,28 @@ def get_args(): logger.info("...loading models completed.") # Decode queries, i.e. the target molecules. - logger.info(f"Start to decode {len(smiles_queries)} target molecules.") + logger.info(f"Start to decode {len(targets)} target molecules.") if args.ncpu == 1: - results = [wrapper_decoder(smi) for smi in smiles_queries] + results = [wrapper_decoder(smi) for smi in targets] else: with mp.Pool(processes=args.ncpu) as pool: logger.info(f"Starting MP with ncpu={args.ncpu}") - results = pool.map(wrapper_decoder, smiles_queries) + results = pool.map(wrapper_decoder, targets) logger.info("Finished decoding.") # Print some results from the prediction - smis_decoded = [smi for smi, _, tree in results if tree is not None] - similarities = [sim for _, sim, tree in results if tree is not None] - trees = [tree for _, _, tree in results if tree is not None] + # Note: If a syntree cannot be decoded within `max_depth` steps (15), + # we will count it as unsuccessful. The similarity will be 0. + decoded = [smi for smi, _, _ in results ] + similarities = [sim for _, sim, _ in results ] + trees = [tree for _, _, tree in results ] recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) avg_similarity = np.mean(similarities) + n_successful = sum([syntree is not None for syntree in trees]) logger.info(f"For {args.data}:") - logger.info(f" Total number of attempted reconstructions: {len(smiles_queries)}") - logger.info(f" Total number of successful reconstructions: {len(smis_decoded)}") + logger.info(f" Total number of attempted reconstructions: {len(targets)}") + logger.info(f" Total number of successful reconstructions: {n_successful}") logger.info(f" {recovery_rate=}") logger.info(f" {avg_similarity=}") @@ -220,13 +222,13 @@ def get_args(): # 1. Dataframe with targets, decoded, smilarities # 2. Synthetic trees of the decoded SMILES output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) logger.info(f"Saving results to {output_dir} ...") - df = pd.DataFrame( - {"query SMILES": smiles_queries, "decode SMILES": smis_decoded, "similarity": similarities} - ) - df.to_csv(f"{output_dir}/decode_result_{args.data}.csv.gz", compression="gzip", index=False) + + df = pd.DataFrame({"targets": targets, "decoded": decoded, "similarity": similarities}) + df.to_csv(f"{output_dir}/decoded_results.csv.gz", compression="gzip", index=False) synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f"{output_dir}/decoded_st_{args.data}.json.gz") + synthetic_tree_set.save(f"{output_dir}/decoded_syntrees.json.gz") logger.info("Completed.") From 16d22696d82b11f5c5a0a990dafb10b19687c992 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 10:18:57 -0400 Subject: [PATCH 281/302] squeeze instead of reshape to satisfy black --- src/syn_net/utils/prep_utils.py | 48 +++++---------------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 7f1214c1..791e564b 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -83,20 +83,8 @@ def split_data_into_Xy( # Delete all data where tree was ended (i.e. tree expansion did not trigger reaction) # TODO: Look into simpler slicing with boolean indices, perhabs consider CSR for row slicing - states = sparse.csc_matrix( - states.A[ - (steps[:, 0].A != 3).reshape( - -1, - ) - ] - ) - steps = sparse.csc_matrix( - steps.A[ - (steps[:, 0].A != 3).reshape( - -1, - ) - ] - ) + states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).squeeze()]) + steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).squeeze()]) # ... reaction data # X: [state, z_reactant_1] @@ -107,20 +95,8 @@ def split_data_into_Xy( sparse.save_npz(output_dir / f"y_rxn_{dataset_type}.npz", y) logger.info(f' saved data for "Reaction" to {output_dir}') - states = sparse.csc_matrix( - states.A[ - (steps[:, 0].A != 2).reshape( - -1, - ) - ] - ) - steps = sparse.csc_matrix( - steps.A[ - (steps[:, 0].A != 2).reshape( - -1, - ) - ] - ) + states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).squeeze()]) + steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).squeeze()]) enc = OneHotEncoder(handle_unknown="ignore") enc.fit([[i] for i in range(num_rxn)]) @@ -140,20 +116,8 @@ def split_data_into_Xy( sparse.save_npz(output_dir / f"y_rt2_{dataset_type}.npz", y) logger.info(f' saved data for "Reactant 2" to {output_dir}') - states = sparse.csc_matrix( - states.A[ - (steps[:, 0].A != 1).reshape( - -1, - ) - ] - ) - steps = sparse.csc_matrix( - steps.A[ - (steps[:, 0].A != 1).reshape( - -1, - ) - ] - ) + states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).squeeze()]) + steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).squeeze()]) # ... reactant 1 data # X: [z_state] From 83dc9ce64b57889ca5451fc182f6c59ba121e6c8 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 14:28:16 -0400 Subject: [PATCH 282/302] speedup + more comments for splitting data --- src/syn_net/utils/prep_utils.py | 58 ++++++++++++++++----------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/syn_net/utils/prep_utils.py b/src/syn_net/utils/prep_utils.py index 791e564b..d9298b87 100644 --- a/src/syn_net/utils/prep_utils.py +++ b/src/syn_net/utils/prep_utils.py @@ -67,8 +67,8 @@ def split_data_into_Xy( output_dir.mkdir(exist_ok=True, parents=True) # Load data # TODO: separate functionality? - states = sparse.load_npz(states_file) - steps = sparse.load_npz(steps_file) + states = sparse.load_npz(states_file) # (n,3*4096) + steps = sparse.load_npz(steps_file) # (n,1+256+91+256+4096) # Extract data for each network... @@ -81,49 +81,49 @@ def split_data_into_Xy( sparse.save_npz(output_dir / f"y_act_{dataset_type}.npz", y) logger.info(f' saved data for "Action" to {output_dir}') - # Delete all data where tree was ended (i.e. tree expansion did not trigger reaction) - # TODO: Look into simpler slicing with boolean indices, perhabs consider CSR for row slicing - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).squeeze()]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).squeeze()]) - # ... reaction data # X: [state, z_reactant_1] # y: [reaction_id] (int) - X = sparse.hstack([states, steps[:, (2 * out_dim + 2) :]]) - y = steps[:, out_dim + 1] + # but: delete all steps where we *end* syntrees, as that will not be followed by a reaction + actions = steps[:, 0].A # (n,1) as array to allow boolean + isActionEnd = (actions == 3).squeeze() # (n,) + states = states[~isActionEnd] + steps = steps[~isActionEnd] + X = sparse.hstack([states, steps[:, (2 * out_dim + 2) :]]) # (n,4*4096) + y = steps[:, out_dim + 1] # (n,1) sparse.save_npz(output_dir / f"X_rxn_{dataset_type}.npz", X) sparse.save_npz(output_dir / f"y_rxn_{dataset_type}.npz", y) logger.info(f' saved data for "Reaction" to {output_dir}') - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).squeeze()]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).squeeze()]) - - enc = OneHotEncoder(handle_unknown="ignore") - enc.fit([[i] for i in range(num_rxn)]) - # ... reactant 2 data - # X: [z_state, z_reactant_1, reaction_id] - # y: [z'_reactant_2] - X = sparse.hstack( - [ - states, - steps[:, (2 * out_dim + 2) :], - sparse.csc_matrix(enc.transform(steps[:, out_dim + 1].A.reshape((-1, 1))).toarray()), - ] - ) - y = steps[:, (out_dim + 2) : (2 * out_dim + 2)] + # X: [state,z_mol1,OneHotEnc(rxn_id)] + # y: [z_mol2] + # but: delete all steps where we *merge* syntrees, as in that case we already have reactant1+2 + actions = steps[:, 0].A # (n',1) as array to allow boolean + isActionMerge = (actions == 2).squeeze() # (n',) + steps = steps[~isActionMerge] + states = states[~isActionMerge] + z_mol1 = steps[:, (2 * out_dim + 2) :] + rxn_ids = steps[:, (1 + out_dim)] + z_rxn_id = OneHotEncoder().fit(np.arange(num_rxn)[:, None]).transform(rxn_ids.A) + X = sparse.hstack((states, z_mol1, z_rxn_id)) # (n,3*4096+4096+91) + y = steps[:, (2 + out_dim) : (2 * out_dim + 2)] sparse.save_npz(output_dir / f"X_rt2_{dataset_type}.npz", X) sparse.save_npz(output_dir / f"y_rt2_{dataset_type}.npz", y) logger.info(f' saved data for "Reactant 2" to {output_dir}') - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).squeeze()]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).squeeze()]) - # ... reactant 1 data # X: [z_state] # y: [z'_reactant_1] + # but: delete all steps where we expand syntrees, as in that case we already have a reactant1 + actions = steps[:, 0].A # (n',1) as array to allow boolean + isActionExpand = (actions == 1).squeeze() # (n',) + steps = steps[~isActionExpand] + states = states[~isActionExpand] + zprime_mol1 = steps[:, 1 : (out_dim + 1)] + X = states - y = steps[:, 1 : (out_dim + 1)] + y = zprime_mol1 sparse.save_npz(output_dir / f"X_rt1_{dataset_type}.npz", X) sparse.save_npz(output_dir / f"y_rt1_{dataset_type}.npz", y) logger.info(f' saved data for "Reactant 1" to {output_dir}') From c3f3c7f7fdaf5c6d33a8a546e55e913a1203024a Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 14:28:40 -0400 Subject: [PATCH 283/302] fix imports --- src/syn_net/models/rt1.py | 3 ++- src/syn_net/models/rt2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/syn_net/models/rt1.py b/src/syn_net/models/rt1.py index 6eeed0b0..690847c7 100644 --- a/src/syn_net/models/rt1.py +++ b/src/syn_net/models/rt1.py @@ -11,7 +11,8 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from syn_net.models.common import get_args, xy_to_dataloader -from syn_net.models.mlp import MLP, cosine_distance +from syn_net.models.mlp import MLP +from syn_net.encoding.distances import cosine_distance from syn_net.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) diff --git a/src/syn_net/models/rt2.py b/src/syn_net/models/rt2.py index 91b55c5d..95e849b6 100644 --- a/src/syn_net/models/rt2.py +++ b/src/syn_net/models/rt2.py @@ -11,7 +11,8 @@ from pytorch_lightning.callbacks.progress import TQDMProgressBar from syn_net.models.common import get_args, xy_to_dataloader -from syn_net.models.mlp import MLP, cosine_distance +from syn_net.models.mlp import MLP +from syn_net.encoding.distances import cosine_distance from syn_net.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) From ff8fd1dfa6f40f8c2f9f372990bf284249e8381f Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 14:39:12 -0400 Subject: [PATCH 284/302] disimprove evaluation script --- scripts/23-evaluate-predictions.py | 91 +++++++++++++----------------- 1 file changed, 38 insertions(+), 53 deletions(-) diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py index 9946eed1..5cd25b57 100644 --- a/scripts/23-evaluate-predictions.py +++ b/scripts/23-evaluate-predictions.py @@ -1,15 +1,16 @@ """Evaluate a batch of predictions on different metrics. The predictions are generated in `20-predict-targets.py`. """ +import json +import logging + import numpy as np import pandas as pd from tdc import Evaluator -kl_divergence = Evaluator(name="KL_Divergence") -fcd_distance = Evaluator(name="FCD_Distance") -novelty = Evaluator(name="Novelty") -validity = Evaluator(name="Validity") -uniqueness = Evaluator(name="Uniqueness") +from syn_net.config import MAX_PROCESSES + +logger = logging.getLogger(__name__) def get_args(): @@ -22,28 +23,35 @@ def get_args(): type=str, help="Dataframe with target- and prediction smiles and similarities (*.csv.gz).", ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") return parser.parse_args() if __name__ == "__main__": - args = get_args() + logger.info("Start.") - files = [args.file] # TODO: not sure why the loop but let's keep it for now + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") # Keep track of successfully and unsuccessfully recovered molecules in 2 df's - recovered = pd.DataFrame({"query SMILES": [], "decode SMILES": [], "similarity": []}) - unrecovered = pd.DataFrame({"query SMILES": [], "decode SMILES": [], "similarity": []}) + # NOTE: column names must match input dataframe... + recovered = pd.DataFrame({"targets": [], "decoded": [], "similarity": []}) + unrecovered = pd.DataFrame({"targets": [], "decoded": [], "similarity": []}) # load each file containing the predictions similarity = [] n_recovered = 0 n_unrecovered = 0 n_total = 0 + files = [args.input_file] # TODO: not sure why the loop but let's keep it for now for file in files: - print(f"File currently being evaluated: {file}") + print(f"Evaluating file: {file}") result_df = pd.read_csv(file) - n_total += len(result_df["decode SMILES"]) + n_total += len(result_df["decoded"]) # Split smiles, discard NaNs is_recovered = result_df["similarity"] == 1.0 @@ -54,34 +62,7 @@ def get_args(): n_unrecovered += len(unrecovered) similarity += unrecovered["similarity"].tolist() - # compute the following properties, using the TDC, for the succesfully recovered molecules - recovered_novelty_all = novelty( - recovered["query SMILES"].tolist(), - recovered["decode SMILES"].tolist(), - ) - recovered_validity_decode_all = validity(recovered["decode SMILES"].tolist()) - recovered_uniqueness_decode_all = uniqueness(recovered["decode SMILES"].tolist()) - recovered_fcd_distance_all = fcd_distance( - recovered["query SMILES"].tolist(), recovered["decode SMILES"].tolist() - ) - recovered_kl_divergence_all = kl_divergence( - recovered["query SMILES"].tolist(), recovered["decode SMILES"].tolist() - ) - - # compute the following properties, using the TDC, for the unrecovered molecules - unrecovered_novelty_all = novelty( - unrecovered["query SMILES"].tolist(), unrecovered["decode SMILES"].tolist() - ) - unrecovered_validity_decode_all = validity(unrecovered["decode SMILES"].tolist()) - unrecovered_uniqueness_decode_all = uniqueness(unrecovered["decode SMILES"].tolist()) - unrecovered_fcd_distance_all = fcd_distance( - unrecovered["query SMILES"].tolist(), unrecovered["decode SMILES"].tolist() - ) - unrecovered_kl_divergence_all = kl_divergence( - unrecovered["query SMILES"].tolist(), unrecovered["decode SMILES"].tolist() - ) - - # Print info + # Print general info print(f"N total {n_total}") print(f"N recovered {n_recovered} ({n_recovered/n_total:.2f})") print(f"N unrecovered {n_unrecovered} ({n_recovered/n_total:.2f})") @@ -92,17 +73,21 @@ def get_args(): print(f"N unfinished trees (NaN) {n_unfinished} ({n_unfinished/n_total:.2f})") print(f"Average similarity (unrecovered only) {np.mean(similarity)}") - print("Novelty, recovered:", recovered_novelty_all) - print("Novelty, unrecovered:", unrecovered_novelty_all) - - print("Validity, decode molecules, recovered:", recovered_validity_decode_all) - print("Validity, decode molecules, unrecovered:", unrecovered_validity_decode_all) - - print("Uniqueness, decode molecules, recovered:", recovered_uniqueness_decode_all) - print("Uniqueness, decode molecules, unrecovered:", unrecovered_uniqueness_decode_all) - - print("FCD distance, recovered:", recovered_fcd_distance_all) - print("FCD distance, unrecovered:", unrecovered_fcd_distance_all) - - print("KL divergence, recovered:", recovered_kl_divergence_all) - print("KL divergence, unrecovered:", unrecovered_kl_divergence_all) + # Evaluate on TDC evaluators + for metric in "KL_divergence FCD_Distance Novelty Validity Uniqueness".split(): + evaluator = Evaluator(name=metric) + try: + score_recovered = evaluator(recovered["targets"], recovered["decoded"]) + score_unrecovered = evaluator(unrecovered["targets"], unrecovered["decoded"]) + except TypeError: + # Some evaluators only take 1 input args, try that. + score_recovered = evaluator(recovered["decoded"]) + score_unrecovered = evaluator(unrecovered["decoded"]) + except Exception as e: + logger.error(f"{e.__class__.__name__}: {str(e)}") + logger.error(e) + score_recovered, score_unrecovered = np.nan, np.nan + + print(f"Evaluation metric for {evaluator.name}:") + print(f" Recovered score: {score_recovered:.2f}") + print(f" Unrecovered score: {score_unrecovered:.2f}") From 9b705ac67b02d7bcd38275858d35fa33067f1eb1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 14:40:17 -0400 Subject: [PATCH 285/302] allow to resume training from ckpt --- src/syn_net/models/rxn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/syn_net/models/rxn.py b/src/syn_net/models/rxn.py index 32a74ace..0ed6601c 100644 --- a/src/syn_net/models/rxn.py +++ b/src/syn_net/models/rxn.py @@ -81,7 +81,7 @@ } output_dim = OUTPUT_DIMS[args.rxn_template] - ckpt_path = args.ckpt_file # TODO: Unify for all networks + ckpt_path = args.ckpt_file # TODO: Unify for all networks mlp = MLP( input_dim=input_dim, output_dim=output_dim, @@ -125,9 +125,8 @@ callbacks=[checkpoint_callback, tqdm_callback], logger=[tb_logger, csv_logger], fast_dev_run=args.fast_dev_run, - ) logger.info(f"Start training") - trainer.fit(mlp, train_dataloader, valid_dataloader,ckpt_path=ckpt_path) + trainer.fit(mlp, train_dataloader, valid_dataloader, ckpt_path=ckpt_path) logger.info(f"Training completed.") From a6a55c813673da56c7e50fcffebace3725b264e1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 14:42:31 -0400 Subject: [PATCH 286/302] add a TODO --- scripts/_mp_decode.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py index ae82b327..2c1824c5 100644 --- a/scripts/_mp_decode.py +++ b/scripts/_mp_decode.py @@ -1,10 +1,9 @@ """ This file contains a function to decode a single synthetic tree. -TODO: Ussed in `scripts/optimize_ga.py`, refactor. +TODO: Used in `scripts/optimize_ga.py`, refactor. """ import numpy as np import pandas as pd -from dgllife.model import load_pretrained from syn_net.utils.data_utils import ReactionSet from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity @@ -18,11 +17,23 @@ param_dir = "hb_fp_2_4096_256" ncpu = 16 -# define model to use for molecular embedding -model_type = "gin_supervised_contextpred" -device = "cpu" -mol_embedder = load_pretrained(model_type).to(device) -mol_embedder.eval() +def _fetch_gin_molembedder(): + from dgllife.model import load_pretrained + # define model to use for molecular embedding + model_type = "gin_supervised_contextpred" + device = "cpu" + mol_embedder = load_pretrained(model_type).to(device) + return mol_embedder.eval() + +def _fetch_molembedder(featurize:str): + """Fetch molembedder.""" + if featurize=="fp": + return None # not in use + else: + raise NotImplementedError + return _fetch_gin_molembedder() + +mol_embedder = _fetch_molembedder(featurize) # load the purchasable building block embeddings bb_emb = np.load("/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy") From a15d3d5379fe8859f34bf78427210581c9f55cbb Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 14:45:09 -0400 Subject: [PATCH 287/302] silence syntree generator by default --- src/syn_net/data_generation/syntrees.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/syn_net/data_generation/syntrees.py b/src/syn_net/data_generation/syntrees.py index bf28a416..c73c1098 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/syn_net/data_generation/syntrees.py @@ -75,8 +75,9 @@ def __init__( self.IDX_RXNS = np.arange(len(self.rxns)) self.processes = processes self.verbose = verbose - if verbose: - logger.setLevel(logging.DEBUG) + if not verbose: + logger.setLevel('CRITICAL') # dont show error msgs + # Time intensive tasks self._init_rxns_with_reactants() From bfdebcdabde5226e74839a209dec0fbf44d8e4e5 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 15:10:40 -0400 Subject: [PATCH 288/302] "fix" unittests --- README.md | 8 -------- tests/README.md | 19 ++++++------------- ...reparation.py => _test_DataPreparation.py} | 0 tests/{test_Predict.py => _test_Predict.py} | 0 tests/{test_Training.py => _test_Training.py} | 0 5 files changed, 6 insertions(+), 21 deletions(-) rename tests/{test_DataPreparation.py => _test_DataPreparation.py} (100%) rename tests/{test_Predict.py => _test_Predict.py} (100%) rename tests/{test_Training.py => _test_Training.py} (100%) diff --git a/README.md b/README.md index bff53731..dff421f3 100644 --- a/README.md +++ b/README.md @@ -64,14 +64,6 @@ source activate synnet pip install -e . ``` -### Unit tests - -To check that everything has been set-up correctly, you can run the unit tests. If starting in the main directory, you can run the unit tests as follows: - -```python -python -m unittest -``` - ### Data SynNet relies on two datasources: diff --git a/tests/README.md b/tests/README.md index 1fcb317f..28fe8f2f 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,19 +1,12 @@ # Unit tests -## Instructions -To run the unit tests, start from the main SynNet directory and run: +Sadly, the only working unittests are for the genetic algorithm for molecular optimization. -``` -export PYTHONPATH=`pwd`:$PYTHONPATH -``` - -Then, activate the SynNet conda environment, and from the current unit tests directory, run: - -``` -python -m unittest -``` +> :warning: **TODO**: write/fix unittests and remove this todo (old tests prefixed with `_test*`) ## Dataset + The data used for unit testing consists of: -* 3 randomly sampled reaction templates from the Hartenfeller-Button dataset (*rxn_set_hb_test.txt*) -* 100 randomly sampled matching building blocks from Enamine (*building_blocks_matched.csv.gz*) + +- 3 randomly sampled reaction templates from the Hartenfeller-Button dataset (*rxn_set_hb_test.txt*) +- 100 randomly sampled matching building blocks from Enamine (*building_blocks_matched.csv.gz*) diff --git a/tests/test_DataPreparation.py b/tests/_test_DataPreparation.py similarity index 100% rename from tests/test_DataPreparation.py rename to tests/_test_DataPreparation.py diff --git a/tests/test_Predict.py b/tests/_test_Predict.py similarity index 100% rename from tests/test_Predict.py rename to tests/_test_Predict.py diff --git a/tests/test_Training.py b/tests/_test_Training.py similarity index 100% rename from tests/test_Training.py rename to tests/_test_Training.py From 2fb36b92b4055d253ded511a95ec89c210f6818c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 15:11:29 -0400 Subject: [PATCH 289/302] add comments --- src/syn_net/utils/predict_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/syn_net/utils/predict_utils.py b/src/syn_net/utils/predict_utils.py index 8437b86c..d6fffe6a 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/syn_net/utils/predict_utils.py @@ -227,6 +227,8 @@ def synthetic_tree_decoder( kdtree = mol_embedder # TODO: dont mis-use this arg # Start iteration + # TODO: tree decoder can exceed this an still return a tree, but action is not equal to 3 + # Raise error instead like in syntree generation? for i in range(max_step): # Encode current state state = tree.get_state() # a list From a6ac242f47846c7d79f115eee9ae8a216726d591 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 15:54:53 -0400 Subject: [PATCH 290/302] explicitly install `fcd_torch` --- environment.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 6858de93..6979cff9 100644 --- a/environment.yml +++ b/environment.yml @@ -20,7 +20,8 @@ dependencies: - pip - pip: - setuptools==59.5.0 # https://github.com/pytorch/pytorch/issues/69894 -# - dgllife # only needed fro gin, will force scikit-learn < 1.0 +# - dgllife # only needed for gin, will force scikit-learn < 1.0 - pathos - rich - pyyaml + - fcd_torch # for evaluators in pytdc From f4247a7e4bd66c7440b8f798ff896a95458130e1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 15:55:01 -0400 Subject: [PATCH 291/302] fix typo --- scripts/01-filter-building-blocks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py index 7fe73cc1..18027172 100644 --- a/scripts/01-filter-building-blocks.py +++ b/scripts/01-filter-building-blocks.py @@ -79,7 +79,7 @@ def get_args(): logger.info(f"Total number of building blocks {len(bblocks):d}") logger.info(f"Matched number of building blocks {len(bblocks_filtered):d}") logger.info( - f"{len(bblocks_filtered)/len(bblocks):.2%} of building blocks applicable for the reaction template." + f"{len(bblocks_filtered)/len(bblocks):.2%} of building blocks applicable for the reaction templates." ) logger.info("Completed.") From 4cd9ecf764b634f14425b0925e012244272c1b30 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 15:59:32 -0400 Subject: [PATCH 292/302] clean --- .gitignore | 154 +++++++++++++++++++++++++++++++---------------------- 1 file changed, 90 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index 9b6cec10..5eb7087f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,36 @@ -# Code TODOs -TODOs - -# Certain unittest files -tests/data/states_0_train.npz -tests/data/steps_0_train.npz -tests/data/rxns_hb.json.gz -tests/data/st_data.json.gz -tests/data/X_act_train.npz -tests/data/y_act_train.npz -tests/data/X_rt1_train.npz -tests/data/y_rt1_train.npz -tests/data/X_rxn_train.npz -tests/data/y_rxn_train.npz -tests/data/X_rt2_train.npz -tests/data/y_rt2_train.npz -tests/gin_supervised_contextpred_pre_trained.pth -tests/backup/ +# === custom === +data/ +figures/syntrees/ +results/ +logs/ +tmp/ +.dev/ +.old/ +.notes/ +.aliases +*.sh + +# === template === + +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,jupyternotebooks + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -39,7 +53,6 @@ parts/ sdist/ var/ wheels/ -pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg @@ -69,6 +82,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +cover/ # Translations *.mo @@ -91,17 +105,17 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ # Jupyter Notebook -.ipynb_checkpoints # IPython -profile_default/ -ipython_config.py # pyenv -.python-version +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -110,7 +124,22 @@ ipython_config.py # install all needed dependencies. #Pipfile.lock -# PEP 582; used by e.g. github.com/David-OConnor/pyflow +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff @@ -147,45 +176,42 @@ dmypy.json # Pyre type checker .pyre/ -# Vim -*~ - -# Data -# data/* -.DS_Store -oracle/* -*.json* -*.npy -*logs* -*.gz -*.csv - -# test Jupyter Notebook -*.ipynb - -# Output files -nohup.out -*.output -*.o -*.out -*.swp -*slurm* -*.sh -*.pth -*.ckpt -*_old* -results -synth_net/params +# pytype static type analyzer +.pytype/ -# Old files set to be deleted -tmp/ -scripts/oracle -temp.py +# Cython debug symbols +cython_debug/ -.dev/ -.old/ -.notes/ -.aliases -figures/ -*.html -*.data*/ \ No newline at end of file +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### VisualStudioCode ### +.vscode/* +# !.vscode/settings.json +# !.vscode/launch.json +!.vscode/tasks.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ###a +# Ignore all local history of files +.history +.ionide + +# Support for Project snippet scope +.vscode/*.code-snippets + +# Ignore code-workspaces +*.code-workspace + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks From 762023d6f26f0f2dc6673bc4625d6eecd8ac9311 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 16:03:19 -0400 Subject: [PATCH 293/302] rename --- tests/{filter_unmatch_tests.py => _filter_unmatch_tests.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{filter_unmatch_tests.py => _filter_unmatch_tests.py} (100%) diff --git a/tests/filter_unmatch_tests.py b/tests/_filter_unmatch_tests.py similarity index 100% rename from tests/filter_unmatch_tests.py rename to tests/_filter_unmatch_tests.py From b530328fff65f3172fe370b3d18842c70a265a44 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Tue, 11 Oct 2022 16:11:15 -0400 Subject: [PATCH 294/302] rename to `synnet` --- scripts/00-extract-smiles-from-sdf.py | 2 +- scripts/01-filter-building-blocks.py | 6 +++--- scripts/02-compute-embeddings.py | 8 ++++---- scripts/03-generate-syntrees.py | 8 ++++---- scripts/04-filter-syntrees.py | 6 +++--- scripts/05-split-syntrees.py | 4 ++-- scripts/06-featurize-syntrees.py | 6 +++--- scripts/07-split-data-for-networks.py | 2 +- scripts/20-predict-targets.py | 14 +++++++------- scripts/21-identify-similar-fps.py | 4 ++-- scripts/22-compute-mrr.py | 10 +++++----- scripts/23-evaluate-predictions.py | 2 +- scripts/_mp_decode.py | 6 +++--- scripts/optimize_ga.py | 4 ++-- setup.py | 4 ++-- src/{syn_net => synnet}/MolEmbedder.py | 2 +- src/{syn_net => synnet}/__init__.py | 0 src/{syn_net => synnet}/config.py | 0 .../data_generation/__init__.py | 0 .../data_generation/check_all_template.py | 0 .../data_generation/preprocessing.py | 4 ++-- .../data_generation/syntrees.py | 4 ++-- src/{syn_net => synnet}/encoding/distances.py | 2 +- src/{syn_net => synnet}/encoding/fingerprints.py | 0 src/{syn_net => synnet}/encoding/gins.py | 0 src/{syn_net => synnet}/encoding/utils.py | 0 src/{syn_net => synnet}/models/act.py | 4 ++-- src/{syn_net => synnet}/models/common.py | 0 src/{syn_net => synnet}/models/mlp.py | 2 +- src/{syn_net => synnet}/models/rt1.py | 8 ++++---- src/{syn_net => synnet}/models/rt2.py | 8 ++++---- src/{syn_net => synnet}/models/rxn.py | 6 +++--- src/{syn_net => synnet}/utils/__init__.py | 0 src/{syn_net => synnet}/utils/data_utils.py | 0 src/{syn_net => synnet}/utils/ga_utils.py | 0 src/{syn_net => synnet}/utils/predict_utils.py | 8 ++++---- src/{syn_net => synnet}/utils/prep_utils.py | 0 src/{syn_net => synnet}/visualize/drawers.py | 0 src/{syn_net => synnet}/visualize/visualizer.py | 12 ++++++------ src/{syn_net => synnet}/visualize/writers.py | 0 tests/_filter_unmatch_tests.py | 2 +- tests/_test_DataPreparation.py | 6 +++--- tests/_test_Predict.py | 4 ++-- tests/_test_Training.py | 4 ++-- tests/test_Optimization.py | 2 +- 45 files changed, 82 insertions(+), 82 deletions(-) rename src/{syn_net => synnet}/MolEmbedder.py (98%) rename src/{syn_net => synnet}/__init__.py (100%) rename src/{syn_net => synnet}/config.py (100%) rename src/{syn_net => synnet}/data_generation/__init__.py (100%) rename src/{syn_net => synnet}/data_generation/check_all_template.py (100%) rename src/{syn_net => synnet}/data_generation/preprocessing.py (98%) rename src/{syn_net => synnet}/data_generation/syntrees.py (99%) rename src/{syn_net => synnet}/encoding/distances.py (97%) rename src/{syn_net => synnet}/encoding/fingerprints.py (100%) rename src/{syn_net => synnet}/encoding/gins.py (100%) rename src/{syn_net => synnet}/encoding/utils.py (100%) rename src/{syn_net => synnet}/models/act.py (97%) rename src/{syn_net => synnet}/models/common.py (100%) rename src/{syn_net => synnet}/models/mlp.py (99%) rename src/{syn_net => synnet}/models/rt1.py (94%) rename src/{syn_net => synnet}/models/rt2.py (95%) rename src/{syn_net => synnet}/models/rxn.py (96%) rename src/{syn_net => synnet}/utils/__init__.py (100%) rename src/{syn_net => synnet}/utils/data_utils.py (100%) rename src/{syn_net => synnet}/utils/ga_utils.py (100%) rename src/{syn_net => synnet}/utils/predict_utils.py (98%) rename src/{syn_net => synnet}/utils/prep_utils.py (100%) rename src/{syn_net => synnet}/visualize/drawers.py (100%) rename src/{syn_net => synnet}/visualize/visualizer.py (94%) rename src/{syn_net => synnet}/visualize/writers.py (100%) diff --git a/scripts/00-extract-smiles-from-sdf.py b/scripts/00-extract-smiles-from-sdf.py index c6ac4ee9..a68ae2fc 100644 --- a/scripts/00-extract-smiles-from-sdf.py +++ b/scripts/00-extract-smiles-from-sdf.py @@ -4,7 +4,7 @@ import logging from pathlib import Path -from syn_net.utils.prep_utils import Sdf2SmilesExtractor +from synnet.utils.prep_utils import Sdf2SmilesExtractor logger = logging.getLogger(__name__) diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py index 18027172..823cd411 100644 --- a/scripts/01-filter-building-blocks.py +++ b/scripts/01-filter-building-blocks.py @@ -4,13 +4,13 @@ from rdkit import RDLogger -from syn_net.config import MAX_PROCESSES -from syn_net.data_generation.preprocessing import ( +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import ( BuildingBlockFileHandler, BuildingBlockFilter, ReactionTemplateFileHandler, ) -from syn_net.utils.data_utils import ReactionSet +from synnet.utils.data_utils import ReactionSet RDLogger.DisableLog("rdApp.*") logger = logging.getLogger(__name__) diff --git a/scripts/02-compute-embeddings.py b/scripts/02-compute-embeddings.py index 740a6e3a..78ee4af9 100644 --- a/scripts/02-compute-embeddings.py +++ b/scripts/02-compute-embeddings.py @@ -9,10 +9,10 @@ import logging from functools import partial -from syn_net.config import MAX_PROCESSES -from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -from syn_net.encoding.fingerprints import mol_fp -from syn_net.MolEmbedder import MolEmbedder +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import BuildingBlockFileHandler +from synnet.encoding.fingerprints import mol_fp +from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__file__) diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py index 53986643..66aaa8e8 100644 --- a/scripts/03-generate-syntrees.py +++ b/scripts/03-generate-syntrees.py @@ -8,13 +8,13 @@ from rdkit import RDLogger from tqdm import tqdm -from syn_net.config import MAX_PROCESSES -from syn_net.data_generation.preprocessing import ( +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import ( BuildingBlockFileHandler, ReactionTemplateFileHandler, ) -from syn_net.data_generation.syntrees import SynTreeGenerator, wraps_syntreegenerator_generate -from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet +from synnet.data_generation.syntrees import SynTreeGenerator, wraps_syntreegenerator_generate +from synnet.utils.data_utils import SyntheticTree, SyntheticTreeSet logger = logging.getLogger(__name__) from typing import Tuple, Union diff --git a/scripts/04-filter-syntrees.py b/scripts/04-filter-syntrees.py index 012a45ff..97fb195b 100644 --- a/scripts/04-filter-syntrees.py +++ b/scripts/04-filter-syntrees.py @@ -8,8 +8,8 @@ from rdkit import Chem, RDLogger from tqdm import tqdm -from syn_net.config import MAX_PROCESSES -from syn_net.utils.data_utils import SyntheticTree, SyntheticTreeSet +from synnet.config import MAX_PROCESSES +from synnet.utils.data_utils import SyntheticTree, SyntheticTreeSet logger = logging.getLogger(__name__) @@ -88,7 +88,7 @@ def get_args(): logger.info(f"Successfully loaded '{args.input_file}' with {len(syntree_collection)} syntrees.") # Filter trees - # TODO: Move to src/syn_net/data_generation/filters.py ? + # TODO: Move to src/synnet/data_generation/filters.py ? valid_root_mol_filter = ValidRootMolFilter() interesting_mol_filter = OracleFilter(threshold=0.5, rng=np.random.default_rng()) diff --git a/scripts/05-split-syntrees.py b/scripts/05-split-syntrees.py index 9fbef7b5..542d0afb 100644 --- a/scripts/05-split-syntrees.py +++ b/scripts/05-split-syntrees.py @@ -4,8 +4,8 @@ import logging from pathlib import Path -from syn_net.config import MAX_PROCESSES -from syn_net.utils.data_utils import SyntheticTreeSet +from synnet.config import MAX_PROCESSES +from synnet.utils.data_utils import SyntheticTreeSet logger = logging.getLogger(__name__) diff --git a/scripts/06-featurize-syntrees.py b/scripts/06-featurize-syntrees.py index 4a641df0..d83c88a4 100644 --- a/scripts/06-featurize-syntrees.py +++ b/scripts/06-featurize-syntrees.py @@ -7,16 +7,16 @@ from scipy import sparse from tqdm import tqdm -from syn_net.data_generation.syntrees import ( +from synnet.data_generation.syntrees import ( IdentityIntEncoder, MorganFingerprintEncoder, SynTreeFeaturizer, ) -from syn_net.utils.data_utils import SyntheticTreeSet +from synnet.utils.data_utils import SyntheticTreeSet logger = logging.getLogger(__file__) -from syn_net.config import MAX_PROCESSES +from synnet.config import MAX_PROCESSES def get_args(): diff --git a/scripts/07-split-data-for-networks.py b/scripts/07-split-data-for-networks.py index f429644c..c9220879 100644 --- a/scripts/07-split-data-for-networks.py +++ b/scripts/07-split-data-for-networks.py @@ -4,7 +4,7 @@ import logging from pathlib import Path -from syn_net.utils.prep_utils import split_data_into_Xy +from synnet.utils.prep_utils import split_data_into_Xy logger = logging.getLogger(__file__) diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index 0753331c..86d91b2d 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -10,13 +10,13 @@ import numpy as np import pandas as pd -from syn_net.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES -from syn_net.data_generation.preprocessing import BuildingBlockFileHandler -from syn_net.encoding.distances import cosine_distance -from syn_net.models.mlp import load_mlp_from_ckpt -from syn_net.MolEmbedder import MolEmbedder -from syn_net.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet -from syn_net.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search +from synnet.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES +from synnet.data_generation.preprocessing import BuildingBlockFileHandler +from synnet.encoding.distances import cosine_distance +from synnet.models.mlp import load_mlp_from_ckpt +from synnet.MolEmbedder import MolEmbedder +from synnet.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet +from synnet.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search logger = logging.getLogger(__name__) diff --git a/scripts/21-identify-similar-fps.py b/scripts/21-identify-similar-fps.py index 3d55a8bf..746fd812 100644 --- a/scripts/21-identify-similar-fps.py +++ b/scripts/21-identify-similar-fps.py @@ -12,11 +12,11 @@ from rdkit import Chem, DataStructs from rdkit.Chem import AllChem -from syn_net.utils.data_utils import SyntheticTreeSet +from synnet.utils.data_utils import SyntheticTreeSet logger = logging.getLogger(__file__) -from syn_net.config import MAX_PROCESSES +from synnet.config import MAX_PROCESSES def get_args(): diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py index 1ca6a53b..0e589738 100644 --- a/scripts/22-compute-mrr.py +++ b/scripts/22-compute-mrr.py @@ -7,11 +7,11 @@ import numpy as np from tqdm import tqdm -from syn_net.config import MAX_PROCESSES -from syn_net.encoding.distances import ce_distance, cosine_distance -from syn_net.models.common import xy_to_dataloader -from syn_net.models.mlp import load_mlp_from_ckpt -from syn_net.MolEmbedder import MolEmbedder +from synnet.config import MAX_PROCESSES +from synnet.encoding.distances import ce_distance, cosine_distance +from synnet.models.common import xy_to_dataloader +from synnet.models.mlp import load_mlp_from_ckpt +from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py index 5cd25b57..415f00b6 100644 --- a/scripts/23-evaluate-predictions.py +++ b/scripts/23-evaluate-predictions.py @@ -8,7 +8,7 @@ import pandas as pd from tdc import Evaluator -from syn_net.config import MAX_PROCESSES +from synnet.config import MAX_PROCESSES logger = logging.getLogger(__name__) diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py index 2c1824c5..b551c35c 100644 --- a/scripts/_mp_decode.py +++ b/scripts/_mp_decode.py @@ -5,9 +5,9 @@ import numpy as np import pandas as pd -from syn_net.utils.data_utils import ReactionSet -from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity -from syn_net.models.mlp import load_mlp_from_ckpt +from synnet.utils.data_utils import ReactionSet +from synnet.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity +from synnet.models.mlp import load_mlp_from_ckpt # define some constants (here, for the Hartenfeller-Button test set) nbits = 4096 diff --git a/scripts/optimize_ga.py b/scripts/optimize_ga.py index 4d3c6342..d0e909b7 100644 --- a/scripts/optimize_ga.py +++ b/scripts/optimize_ga.py @@ -12,8 +12,8 @@ from tdc import Oracle import scripts._mp_decode as decode -from syn_net.utils.ga_utils import crossover, mutation -from syn_net.utils.predict_utils import mol_fp +from synnet.utils.ga_utils import crossover, mutation +from synnet.utils.predict_utils import mol_fp def dock_drd3(smi): diff --git a/setup.py b/setup.py index f68f64da..f8480447 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ long_description = fh.read() setuptools.setup( - name="syn_net", + name="synnet", version="0.1.0", description="Synthetic tree generation using neural networks.", long_description=long_description, @@ -15,6 +15,6 @@ "Operating System :: OS Independent", ], package_dir={"": "src"}, - packages=setuptools.find_packages(where="src",exclude=["src/syn_net/encoding/gins.py"]), + packages=setuptools.find_packages(where="src",exclude=["src/synnet/encoding/gins.py"]), python_requires=">=3.9", ) \ No newline at end of file diff --git a/src/syn_net/MolEmbedder.py b/src/synnet/MolEmbedder.py similarity index 98% rename from src/syn_net/MolEmbedder.py rename to src/synnet/MolEmbedder.py index 2298b9fc..105962c3 100644 --- a/src/syn_net/MolEmbedder.py +++ b/src/synnet/MolEmbedder.py @@ -5,7 +5,7 @@ import numpy as np from sklearn.neighbors import BallTree -from syn_net.config import MAX_PROCESSES +from synnet.config import MAX_PROCESSES logger = logging.getLogger(__name__) diff --git a/src/syn_net/__init__.py b/src/synnet/__init__.py similarity index 100% rename from src/syn_net/__init__.py rename to src/synnet/__init__.py diff --git a/src/syn_net/config.py b/src/synnet/config.py similarity index 100% rename from src/syn_net/config.py rename to src/synnet/config.py diff --git a/src/syn_net/data_generation/__init__.py b/src/synnet/data_generation/__init__.py similarity index 100% rename from src/syn_net/data_generation/__init__.py rename to src/synnet/data_generation/__init__.py diff --git a/src/syn_net/data_generation/check_all_template.py b/src/synnet/data_generation/check_all_template.py similarity index 100% rename from src/syn_net/data_generation/check_all_template.py rename to src/synnet/data_generation/check_all_template.py diff --git a/src/syn_net/data_generation/preprocessing.py b/src/synnet/data_generation/preprocessing.py similarity index 98% rename from src/syn_net/data_generation/preprocessing.py rename to src/synnet/data_generation/preprocessing.py index 311508e8..e800a749 100644 --- a/src/syn_net/data_generation/preprocessing.py +++ b/src/synnet/data_generation/preprocessing.py @@ -2,8 +2,8 @@ from tqdm import tqdm -from syn_net.config import MAX_PROCESSES -from syn_net.utils.data_utils import Reaction +from synnet.config import MAX_PROCESSES +from synnet.utils.data_utils import Reaction class BuildingBlockFilter: diff --git a/src/syn_net/data_generation/syntrees.py b/src/synnet/data_generation/syntrees.py similarity index 99% rename from src/syn_net/data_generation/syntrees.py rename to src/synnet/data_generation/syntrees.py index c73c1098..d8f31931 100644 --- a/src/syn_net/data_generation/syntrees.py +++ b/src/synnet/data_generation/syntrees.py @@ -8,11 +8,11 @@ from scipy import sparse from tqdm import tqdm -from syn_net.config import MAX_PROCESSES +from synnet.config import MAX_PROCESSES logger = logging.getLogger(__name__) -from syn_net.utils.data_utils import Reaction, SyntheticTree +from synnet.utils.data_utils import Reaction, SyntheticTree class NoReactantAvailableError(Exception): diff --git a/src/syn_net/encoding/distances.py b/src/synnet/encoding/distances.py similarity index 97% rename from src/syn_net/encoding/distances.py rename to src/synnet/encoding/distances.py index 41d34429..fd5b5e92 100644 --- a/src/syn_net/encoding/distances.py +++ b/src/synnet/encoding/distances.py @@ -1,6 +1,6 @@ import numpy as np -from syn_net.encoding.fingerprints import mol_fp +from synnet.encoding.fingerprints import mol_fp def cosine_distance(v1, v2): diff --git a/src/syn_net/encoding/fingerprints.py b/src/synnet/encoding/fingerprints.py similarity index 100% rename from src/syn_net/encoding/fingerprints.py rename to src/synnet/encoding/fingerprints.py diff --git a/src/syn_net/encoding/gins.py b/src/synnet/encoding/gins.py similarity index 100% rename from src/syn_net/encoding/gins.py rename to src/synnet/encoding/gins.py diff --git a/src/syn_net/encoding/utils.py b/src/synnet/encoding/utils.py similarity index 100% rename from src/syn_net/encoding/utils.py rename to src/synnet/encoding/utils.py diff --git a/src/syn_net/models/act.py b/src/synnet/models/act.py similarity index 97% rename from src/syn_net/models/act.py rename to src/synnet/models/act.py index ffd56a4b..fbf6da3d 100644 --- a/src/syn_net/models/act.py +++ b/src/synnet/models/act.py @@ -10,8 +10,8 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import TQDMProgressBar -from syn_net.models.common import get_args, xy_to_dataloader -from syn_net.models.mlp import MLP +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem diff --git a/src/syn_net/models/common.py b/src/synnet/models/common.py similarity index 100% rename from src/syn_net/models/common.py rename to src/synnet/models/common.py diff --git a/src/syn_net/models/mlp.py b/src/synnet/models/mlp.py similarity index 99% rename from src/syn_net/models/mlp.py rename to src/synnet/models/mlp.py index d29ba3dc..b31b24c4 100644 --- a/src/syn_net/models/mlp.py +++ b/src/synnet/models/mlp.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from torch import nn -from syn_net.MolEmbedder import MolEmbedder +from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) diff --git a/src/syn_net/models/rt1.py b/src/synnet/models/rt1.py similarity index 94% rename from src/syn_net/models/rt1.py rename to src/synnet/models/rt1.py index 690847c7..cbe97ecf 100644 --- a/src/syn_net/models/rt1.py +++ b/src/synnet/models/rt1.py @@ -10,10 +10,10 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import TQDMProgressBar -from syn_net.models.common import get_args, xy_to_dataloader -from syn_net.models.mlp import MLP -from syn_net.encoding.distances import cosine_distance -from syn_net.MolEmbedder import MolEmbedder +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP +from synnet.encoding.distances import cosine_distance +from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem diff --git a/src/syn_net/models/rt2.py b/src/synnet/models/rt2.py similarity index 95% rename from src/syn_net/models/rt2.py rename to src/synnet/models/rt2.py index 95e849b6..51db0732 100644 --- a/src/syn_net/models/rt2.py +++ b/src/synnet/models/rt2.py @@ -10,10 +10,10 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import TQDMProgressBar -from syn_net.models.common import get_args, xy_to_dataloader -from syn_net.models.mlp import MLP -from syn_net.encoding.distances import cosine_distance -from syn_net.MolEmbedder import MolEmbedder +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP +from synnet.encoding.distances import cosine_distance +from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem diff --git a/src/syn_net/models/rxn.py b/src/synnet/models/rxn.py similarity index 96% rename from src/syn_net/models/rxn.py rename to src/synnet/models/rxn.py index 0ed6601c..d4ded03c 100644 --- a/src/syn_net/models/rxn.py +++ b/src/synnet/models/rxn.py @@ -11,9 +11,9 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import TQDMProgressBar -from syn_net.config import CHECKPOINTS_DIR -from syn_net.models.common import get_args, xy_to_dataloader -from syn_net.models.mlp import MLP +from synnet.config import CHECKPOINTS_DIR +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP logger = logging.getLogger(__name__) MODEL_ID = Path(__file__).stem diff --git a/src/syn_net/utils/__init__.py b/src/synnet/utils/__init__.py similarity index 100% rename from src/syn_net/utils/__init__.py rename to src/synnet/utils/__init__.py diff --git a/src/syn_net/utils/data_utils.py b/src/synnet/utils/data_utils.py similarity index 100% rename from src/syn_net/utils/data_utils.py rename to src/synnet/utils/data_utils.py diff --git a/src/syn_net/utils/ga_utils.py b/src/synnet/utils/ga_utils.py similarity index 100% rename from src/syn_net/utils/ga_utils.py rename to src/synnet/utils/ga_utils.py diff --git a/src/syn_net/utils/predict_utils.py b/src/synnet/utils/predict_utils.py similarity index 98% rename from src/syn_net/utils/predict_utils.py rename to src/synnet/utils/predict_utils.py index d6fffe6a..147b10cb 100644 --- a/src/syn_net/utils/predict_utils.py +++ b/src/synnet/utils/predict_utils.py @@ -11,10 +11,10 @@ from rdkit import Chem from sklearn.neighbors import BallTree -from syn_net.encoding.distances import cosine_distance, tanimoto_similarity -from syn_net.encoding.fingerprints import mol_fp -from syn_net.encoding.utils import one_hot_encoder -from syn_net.utils.data_utils import Reaction, SyntheticTree +from synnet.encoding.distances import cosine_distance, tanimoto_similarity +from synnet.encoding.fingerprints import mol_fp +from synnet.encoding.utils import one_hot_encoder +from synnet.utils.data_utils import Reaction, SyntheticTree # create a random seed for NumPy np.random.seed(6) diff --git a/src/syn_net/utils/prep_utils.py b/src/synnet/utils/prep_utils.py similarity index 100% rename from src/syn_net/utils/prep_utils.py rename to src/synnet/utils/prep_utils.py diff --git a/src/syn_net/visualize/drawers.py b/src/synnet/visualize/drawers.py similarity index 100% rename from src/syn_net/visualize/drawers.py rename to src/synnet/visualize/drawers.py diff --git a/src/syn_net/visualize/visualizer.py b/src/synnet/visualize/visualizer.py similarity index 94% rename from src/syn_net/visualize/visualizer.py rename to src/synnet/visualize/visualizer.py index df9079aa..7d5e9593 100644 --- a/src/syn_net/visualize/visualizer.py +++ b/src/synnet/visualize/visualizer.py @@ -1,9 +1,9 @@ from pathlib import Path from typing import Union -from syn_net.utils.data_utils import NodeChemical, NodeRxn, SyntheticTree -from syn_net.visualize.drawers import MolDrawer -from syn_net.visualize.writers import subgraph +from synnet.utils.data_utils import NodeChemical, NodeRxn, SyntheticTree +from synnet.visualize.drawers import MolDrawer +from synnet.visualize.writers import subgraph class SynTreeVisualizer: @@ -154,9 +154,9 @@ def demo(): st = SyntheticTree() st.read(data) - from syn_net.visualize.drawers import MolDrawer - from syn_net.visualize.visualizer import SynTreeVisualizer - from syn_net.visualize.writers import SynTreeWriter + from synnet.visualize.drawers import MolDrawer + from synnet.visualize.visualizer import SynTreeVisualizer + from synnet.visualize.writers import SynTreeWriter outpath = Path("./figures/syntrees/generation/st") outpath.mkdir(parents=True, exist_ok=True) diff --git a/src/syn_net/visualize/writers.py b/src/synnet/visualize/writers.py similarity index 100% rename from src/syn_net/visualize/writers.py rename to src/synnet/visualize/writers.py diff --git a/tests/_filter_unmatch_tests.py b/tests/_filter_unmatch_tests.py index e77f18db..ca968cc4 100644 --- a/tests/_filter_unmatch_tests.py +++ b/tests/_filter_unmatch_tests.py @@ -4,7 +4,7 @@ """ import pandas as pd from tqdm import tqdm -from syn_net.utils.data_utils import * +from synnet.utils.data_utils import * if __name__ == '__main__': diff --git a/tests/_test_DataPreparation.py b/tests/_test_DataPreparation.py index 27342fc4..0697064e 100644 --- a/tests/_test_DataPreparation.py +++ b/tests/_test_DataPreparation.py @@ -12,9 +12,9 @@ from scipy import sparse from tqdm import tqdm -from syn_net.encoding.gins import get_mol_embedding -from syn_net.utils.prep_utils import organize, synthetic_tree_generator, prep_data -from syn_net.utils.data_utils import SyntheticTreeSet, Reaction, ReactionSet +from synnet.encoding.gins import get_mol_embedding +from synnet.utils.prep_utils import organize, synthetic_tree_generator, prep_data +from synnet.utils.data_utils import SyntheticTreeSet, Reaction, ReactionSet TEST_DIR = Path(__file__).parent diff --git a/tests/_test_Predict.py b/tests/_test_Predict.py index 83854e66..626fe49a 100644 --- a/tests/_test_Predict.py +++ b/tests/_test_Predict.py @@ -7,11 +7,11 @@ import numpy as np import pandas as pd -from syn_net.utils.predict_utils import ( +from synnet.utils.predict_utils import ( synthetic_tree_decoder_greedy_search, mol_fp, ) -from syn_net.utils.data_utils import SyntheticTreeSet, ReactionSet +from synnet.utils.data_utils import SyntheticTreeSet, ReactionSet from syn_net.models.chkpt_loader import load_modules_from_checkpoint TEST_DIR = Path(__file__).parent diff --git a/tests/_test_Training.py b/tests/_test_Training.py index 3765f5a6..0cfbc039 100644 --- a/tests/_test_Training.py +++ b/tests/_test_Training.py @@ -10,8 +10,8 @@ from scipy import sparse import torch -from syn_net.models.mlp import MLP, load_array -from syn_net.MolEmbedder import MolEmbedder +from synnet.models.mlp import MLP, load_array +from synnet.MolEmbedder import MolEmbedder TEST_DIR = Path(__file__).parent diff --git a/tests/test_Optimization.py b/tests/test_Optimization.py index fabdd14f..9d69b139 100644 --- a/tests/test_Optimization.py +++ b/tests/test_Optimization.py @@ -3,7 +3,7 @@ """ import unittest import numpy as np -from syn_net.utils.ga_utils import crossover, mutation, fitness_sum +from synnet.utils.ga_utils import crossover, mutation, fitness_sum class TestOptimization(unittest.TestCase): From 0f3a8322a024be312d011125b96335caf37ed650 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 08:51:04 -0400 Subject: [PATCH 295/302] fix loading pre-trained models --- src/synnet/models/mlp.py | 58 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/src/synnet/models/mlp.py b/src/synnet/models/mlp.py index b31b24c4..4ad28bd9 100644 --- a/src/synnet/models/mlp.py +++ b/src/synnet/models/mlp.py @@ -2,6 +2,7 @@ Multi-layer perceptron (MLP) class. """ import logging +from pathlib import Path import numpy as np import pytorch_lightning as pl @@ -131,10 +132,65 @@ def nn_search_list(y, kdtree): ind = kdtree.query(y, k=1, return_distance=False) # (n_samples, 1) return ind + def load_mlp_from_ckpt(ckpt_file: str): """Load a model from a checkpoint for inference.""" - model = MLP.load_from_checkpoint(ckpt_file) + try: + model = MLP.load_from_checkpoint(ckpt_file) + except TypeError: + model = _load_mlp_from_iclr_ckpt(ckpt_file) return model.eval() + +def _load_mlp_from_iclr_ckpt(ckpt_file: str): + """Load a model from a checkpoint for inference. + Info: hparams were not saved, so we specify the ones needed for inference again.""" + model = Path(ckpt_file).parent.name # assume "//.ckpt" + if model == "act": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=3*4096, + output_dim=4, + hidden_dim=1000, + num_layers=5, + task="classification", + dropout=0.5, + ) + elif model == "rt1": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=3 * 4096, + output_dim=256, + hidden_dim=1200, + num_layers=5, + task="regression", + dropout=0.5, + ) + elif model == "rxn": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=4 * 4096, + output_dim=91, + hidden_dim=3000, + num_layers=5, + task="classification", + dropout=0.5, + ) + elif model == "rt2": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=4 * 4096 + 91, + output_dim=256, + hidden_dim=3000, + num_layers=5, + task="regression", + dropout=0.5, + ) + + else: + raise ValueError + return model.eval() + + if __name__ == "__main__": pass From 7158ed6dd02978e550d00575f4e8c6f66e7e1f4c Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 08:57:54 -0400 Subject: [PATCH 296/302] move all code into single file --- scripts/_mp_decode.py | 107 ---------------------------------------- scripts/optimize_ga.py | 108 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 109 deletions(-) delete mode 100644 scripts/_mp_decode.py diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py deleted file mode 100644 index b551c35c..00000000 --- a/scripts/_mp_decode.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -This file contains a function to decode a single synthetic tree. -TODO: Used in `scripts/optimize_ga.py`, refactor. -""" -import numpy as np -import pandas as pd - -from synnet.utils.data_utils import ReactionSet -from synnet.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity -from synnet.models.mlp import load_mlp_from_ckpt - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = "hb" -featurize = "fp" -param_dir = "hb_fp_2_4096_256" -ncpu = 16 - -def _fetch_gin_molembedder(): - from dgllife.model import load_pretrained - # define model to use for molecular embedding - model_type = "gin_supervised_contextpred" - device = "cpu" - mol_embedder = load_pretrained(model_type).to(device) - return mol_embedder.eval() - -def _fetch_molembedder(featurize:str): - """Fetch molembedder.""" - if featurize=="fp": - return None # not in use - else: - raise NotImplementedError - return _fetch_gin_molembedder() - -mol_embedder = _fetch_molembedder(featurize) - -# load the purchasable building block embeddings -bb_emb = np.load("/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy") - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = ( - f"/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz" -) -path_to_building_blocks = ( - f"/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz" -) - -# define paths to pretrained modules -param_path = f"/home/whgao/synth_net/synth_net/params/{param_dir}/" -act_path = f"{param_path}act.ckpt" -rt1_path = f"{param_path}rt1.ckpt" -rxn_path = f"{param_path}rxn.ckpt" -rt2_path = f"{param_path}rt2.ckpt" - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")["SMILES"].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet().load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net = load_mlp_from_ckpt(act_path) -rt1_net = load_mlp_from_ckpt(rt1_path) -rxn_net = load_mlp_from_ckpt(rxn_path) -rt2_net = load_mlp_from_ckpt(rt2_path) - - -def func(emb): - """ - Generates the synthetic tree for the input molecular embedding. - - Args: - emb (np.ndarray): Molecular embedding to decode. - - Returns: - str: SMILES for the final chemical node in the tree. - SyntheticTree: The generated synthetic tree. - """ - emb = emb.reshape((1, -1)) - try: - tree, action = synthetic_tree_decoder( - z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=nbits, - max_step=15, - ) - except Exception as e: - print(e) - action = -1 - if action != 3: - return None, None - else: - scores = np.array(tanimoto_similarity(emb, [node.smiles for node in tree.chemicals])) - max_score_idx = np.where(scores == np.max(scores))[0][0] - return tree.chemicals[max_score_idx].smiles, tree diff --git a/scripts/optimize_ga.py b/scripts/optimize_ga.py index d0e909b7..c36d0026 100644 --- a/scripts/optimize_ga.py +++ b/scripts/optimize_ga.py @@ -11,11 +11,115 @@ import pandas as pd from tdc import Oracle -import scripts._mp_decode as decode from synnet.utils.ga_utils import crossover, mutation from synnet.utils.predict_utils import mol_fp +import numpy as np +import pandas as pd + +from synnet.utils.data_utils import ReactionSet +from synnet.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity +from synnet.models.mlp import load_mlp_from_ckpt + +# define some constants (here, for the Hartenfeller-Button test set) +nbits = 4096 +out_dim = 256 +rxn_template = "hb" +featurize = "fp" +param_dir = "hb_fp_2_4096_256" +ncpu = 16 + +def _fetch_gin_molembedder(): + from dgllife.model import load_pretrained + # define model to use for molecular embedding + model_type = "gin_supervised_contextpred" + device = "cpu" + mol_embedder = load_pretrained(model_type).to(device) + return mol_embedder.eval() + +def _fetch_molembedder(featurize:str): + """Fetch molembedder.""" + if featurize=="fp": + return None # not in use + else: + raise NotImplementedError + return _fetch_gin_molembedder() + +mol_embedder = _fetch_molembedder(featurize) + +# load the purchasable building block embeddings +bb_emb = np.load("/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy") + +# define path to the reaction templates and purchasable building blocks +path_to_reaction_file = ( + f"/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz" +) +path_to_building_blocks = ( + f"/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz" +) + +# define paths to pretrained modules +param_path = f"/home/whgao/synth_net/synth_net/params/{param_dir}/" +act_path = f"{param_path}act.ckpt" +rt1_path = f"{param_path}rt1.ckpt" +rxn_path = f"{param_path}rxn.ckpt" +rt2_path = f"{param_path}rt2.ckpt" + +# load the purchasable building block SMILES to a dictionary +building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")["SMILES"].tolist() +bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} + +# load the reaction templates as a ReactionSet object +rxn_set = ReactionSet().load(path_to_reaction_file) +rxns = rxn_set.rxns + +# load the pre-trained modules +act_net = load_mlp_from_ckpt(act_path) +rt1_net = load_mlp_from_ckpt(rt1_path) +rxn_net = load_mlp_from_ckpt(rxn_path) +rt2_net = load_mlp_from_ckpt(rt2_path) + + +def func(emb): + """ + Generates the synthetic tree for the input molecular embedding. + + Args: + emb (np.ndarray): Molecular embedding to decode. + + Returns: + str: SMILES for the final chemical node in the tree. + SyntheticTree: The generated synthetic tree. + """ + emb = emb.reshape((1, -1)) + try: + tree, action = synthetic_tree_decoder( + z_target=emb, + building_blocks=building_blocks, + bb_dict=bb_dict, + reaction_templates=rxns, + mol_embedder=mol_embedder, + action_net=act_net, + reactant1_net=rt1_net, + rxn_net=rxn_net, + reactant2_net=rt2_net, + bb_emb=bb_emb, + rxn_template=rxn_template, + n_bits=nbits, + max_step=15, + ) + except Exception as e: + print(e) + action = -1 + if action != 3: + return None, None + else: + scores = np.array(tanimoto_similarity(emb, [node.smiles for node in tree.chemicals])) + max_score_idx = np.where(scores == np.max(scores))[0][0] + return tree.chemicals[max_score_idx].smiles, tree + + def dock_drd3(smi): """ Returns the docking score for the DRD3 target. @@ -80,7 +184,7 @@ def fitness(embs, _pool, obj): trees (list): Contains the synthetic trees generated from the input embeddings. """ - results = _pool.map(decode.func, embs) + results = _pool.map(func, embs) smiles = [r[0] for r in results] trees = [r[1] for r in results] From 0ecebf56108ce52c912aaa02a7ad2d6af7a393e1 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 09:39:11 -0400 Subject: [PATCH 297/302] move code to load from ckpts to module --- scripts/20-predict-targets.py | 39 ++--------------- src/synnet/models/common.py | 81 +++++++++++++++++++++++++++++++++++ src/synnet/models/mlp.py | 59 ------------------------- 3 files changed, 84 insertions(+), 95 deletions(-) diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index 86d91b2d..d74c729b 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -13,7 +13,7 @@ from synnet.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES from synnet.data_generation.preprocessing import BuildingBlockFileHandler from synnet.encoding.distances import cosine_distance -from synnet.models.mlp import load_mlp_from_ckpt +from synnet.models.common import load_mlp_from_ckpt, find_best_model_ckpt from synnet.MolEmbedder import MolEmbedder from synnet.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet from synnet.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search @@ -48,39 +48,6 @@ def _fetch_data(name: str) -> list[str]: smiles = _fetch_data_from_file(name) return smiles - -def find_best_model_ckpt(path: str) -> Union[Path, None]: # TODO: move to utils.py - """Find checkpoint with lowest val_loss. - - Poor man's regex: - somepath/act/ckpts.epoch=70-val_loss=0.03.ckpt - ^^^^--extract this as float - """ - ckpts = Path(path).rglob("*.ckpt") - best_model_ckpt = None - lowest_loss = 10_000 # ~ math.inf - for file in ckpts: - stem = file.stem - val_loss = float(stem.split("val_loss=")[-1]) - if val_loss < lowest_loss: - best_model_ckpt = file - lowest_loss = val_loss - return best_model_ckpt - - -def _load_pretrained_model(path_to_checkpoints: list[Path]): - """Wrapper to load modules from checkpoint.""" - # Define paths to pretrained models. - act_path, rt1_path, rxn_path, rt2_path = path_to_checkpoints - - # Load the pre-trained models. - act_net = load_mlp_from_ckpt(act_path) - rt1_net = load_mlp_from_ckpt(rt1_path) - rxn_net = load_mlp_from_ckpt(rxn_path) - rt2_net = load_mlp_from_ckpt(rt2_path) - return act_net, rt1_net, rxn_net, rt2_net - - def wrapper_decoder(smiles: str) -> Tuple[str, float, SyntheticTree]: """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) @@ -188,8 +155,8 @@ def get_args(): # ... models logger.info("Start loading models from checkpoints...") path = Path(args.ckpt_dir) - paths = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] - act_net, rt1_net, rxn_net, rt2_net = _load_pretrained_model(paths) + ckpt_files = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] + act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] logger.info("...loading models completed.") # Decode queries, i.e. the target molecules. diff --git a/src/synnet/models/common.py b/src/synnet/models/common.py index 8cd2848a..301be528 100644 --- a/src/synnet/models/common.py +++ b/src/synnet/models/common.py @@ -1,12 +1,15 @@ """Common methods and params shared by all models. """ +from pathlib import Path from typing import Union import numpy as np import torch from scipy import sparse +from synnet.models.mlp import MLP + def get_args(): import argparse @@ -72,6 +75,84 @@ def xy_to_dataloader( return torch.utils.data.DataLoader(dataset, **kwargs) +def load_mlp_from_ckpt(ckpt_file: str): + """Load a model from a checkpoint for inference.""" + try: + model = MLP.load_from_checkpoint(ckpt_file) + except TypeError: + model = _load_mlp_from_iclr_ckpt(ckpt_file) + return model.eval() + + +def find_best_model_ckpt(path: str) -> Union[Path, None]: + """Find checkpoint with lowest val_loss. + + Poor man's regex: + somepath/act/ckpts.epoch=70-val_loss=0.03.ckpt + ^^^^--extract this as float + """ + ckpts = Path(path).rglob("*.ckpt") + best_model_ckpt = None + lowest_loss = 10_000 # ~ math.inf + for file in ckpts: + stem = file.stem + val_loss = float(stem.split("val_loss=")[-1]) + if val_loss < lowest_loss: + best_model_ckpt = file + lowest_loss = val_loss + return best_model_ckpt + + +def _load_mlp_from_iclr_ckpt(ckpt_file: str): + """Load a model from a checkpoint for inference. + Info: hparams were not saved, so we specify the ones needed for inference again.""" + model = Path(ckpt_file).parent.name # assume "//.ckpt" + if model == "act": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=3 * 4096, + output_dim=4, + hidden_dim=1000, + num_layers=5, + task="classification", + dropout=0.5, + ) + elif model == "rt1": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=3 * 4096, + output_dim=256, + hidden_dim=1200, + num_layers=5, + task="regression", + dropout=0.5, + ) + elif model == "rxn": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=4 * 4096, + output_dim=91, + hidden_dim=3000, + num_layers=5, + task="classification", + dropout=0.5, + ) + elif model == "rt2": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=4 * 4096 + 91, + output_dim=256, + hidden_dim=3000, + num_layers=5, + task="regression", + dropout=0.5, + ) + + else: + raise ValueError + return model.eval() + + if __name__ == "__main__": import json diff --git a/src/synnet/models/mlp.py b/src/synnet/models/mlp.py index 4ad28bd9..a946e2a9 100644 --- a/src/synnet/models/mlp.py +++ b/src/synnet/models/mlp.py @@ -133,64 +133,5 @@ def nn_search_list(y, kdtree): return ind -def load_mlp_from_ckpt(ckpt_file: str): - """Load a model from a checkpoint for inference.""" - try: - model = MLP.load_from_checkpoint(ckpt_file) - except TypeError: - model = _load_mlp_from_iclr_ckpt(ckpt_file) - return model.eval() - - -def _load_mlp_from_iclr_ckpt(ckpt_file: str): - """Load a model from a checkpoint for inference. - Info: hparams were not saved, so we specify the ones needed for inference again.""" - model = Path(ckpt_file).parent.name # assume "//.ckpt" - if model == "act": - model = MLP.load_from_checkpoint( - ckpt_file, - input_dim=3*4096, - output_dim=4, - hidden_dim=1000, - num_layers=5, - task="classification", - dropout=0.5, - ) - elif model == "rt1": - model = MLP.load_from_checkpoint( - ckpt_file, - input_dim=3 * 4096, - output_dim=256, - hidden_dim=1200, - num_layers=5, - task="regression", - dropout=0.5, - ) - elif model == "rxn": - model = MLP.load_from_checkpoint( - ckpt_file, - input_dim=4 * 4096, - output_dim=91, - hidden_dim=3000, - num_layers=5, - task="classification", - dropout=0.5, - ) - elif model == "rt2": - model = MLP.load_from_checkpoint( - ckpt_file, - input_dim=4 * 4096 + 91, - output_dim=256, - hidden_dim=3000, - num_layers=5, - task="regression", - dropout=0.5, - ) - - else: - raise ValueError - return model.eval() - - if __name__ == "__main__": pass From 8a0a105cb67908434634614aa67087b336773f2e Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 10:57:42 -0400 Subject: [PATCH 298/302] fix imports & paths --- scripts/optimize_ga.py | 139 +++++++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 62 deletions(-) diff --git a/scripts/optimize_ga.py b/scripts/optimize_ga.py index c36d0026..183e90b9 100644 --- a/scripts/optimize_ga.py +++ b/scripts/optimize_ga.py @@ -1,34 +1,29 @@ """ Generates synthetic trees where the root molecule optimizes for a specific objective -based on Therapeutic Data Commons (TDC) oracle functions. Uses a genetic algorithm -to optimize embeddings before decoding. -""" +based on Therapeutics Data Commons (TDC) oracle functions. +Uses a genetic algorithm to optimize embeddings before decoding. +""" # TODO: Refactor/Consolidate with generic inference script import json import multiprocessing as mp import time - +from pathlib import Path import numpy as np import pandas as pd from tdc import Oracle - +from synnet.config import MAX_PROCESSES +from synnet.encoding.distances import cosine_distance +from synnet.MolEmbedder import MolEmbedder from synnet.utils.ga_utils import crossover, mutation from synnet.utils.predict_utils import mol_fp - +from pathlib import Path +from synnet.data_generation.preprocessing import BuildingBlockFileHandler import numpy as np import pandas as pd from synnet.utils.data_utils import ReactionSet from synnet.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity -from synnet.models.mlp import load_mlp_from_ckpt - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = "hb" -featurize = "fp" -param_dir = "hb_fp_2_4096_256" -ncpu = 16 +from synnet.models.common import load_mlp_from_ckpt, find_best_model_ckpt def _fetch_gin_molembedder(): from dgllife.model import load_pretrained @@ -46,41 +41,6 @@ def _fetch_molembedder(featurize:str): raise NotImplementedError return _fetch_gin_molembedder() -mol_embedder = _fetch_molembedder(featurize) - -# load the purchasable building block embeddings -bb_emb = np.load("/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy") - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = ( - f"/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz" -) -path_to_building_blocks = ( - f"/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz" -) - -# define paths to pretrained modules -param_path = f"/home/whgao/synth_net/synth_net/params/{param_dir}/" -act_path = f"{param_path}act.ckpt" -rt1_path = f"{param_path}rt1.ckpt" -rxn_path = f"{param_path}rxn.ckpt" -rt2_path = f"{param_path}rt2.ckpt" - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")["SMILES"].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet().load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net = load_mlp_from_ckpt(act_path) -rt1_net = load_mlp_from_ckpt(rt1_path) -rxn_net = load_mlp_from_ckpt(rxn_path) -rt2_net = load_mlp_from_ckpt(rt2_path) - - def func(emb): """ Generates the synthetic tree for the input molecular embedding. @@ -99,7 +59,7 @@ def func(emb): building_blocks=building_blocks, bb_dict=bb_dict, reaction_templates=rxns, - mol_embedder=mol_embedder, + mol_embedder=bblocks_molembedder.kdtree, # TODO: fix this, currently misused, action_net=act_net, reactant1_net=rt1_net, rxn_net=rxn_net, @@ -271,15 +231,31 @@ def mut_probability_scheduler(n, total): else: return 0.5 - -if __name__ == "__main__": - +def get_args(): import argparse parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxns-collection-file", + type=str, + help="Input file for the collection of reactions matched with building-blocks.", + ) parser.add_argument( - "-i", - "--input_file", + "--embeddings-knn-file", + type=str, + help="Input file for the pre-computed embeddings (*.npy).", + ) + parser.add_argument( + "--ckpt-dir", type=str, help="Directory with checkpoints for {act,rt1,rxn,rt2}-model." + ) + parser.add_argument( + "--input-file", type=str, default=None, help="A file contains the starting mating pool.", @@ -301,7 +277,7 @@ def mut_probability_scheduler(n, total): help="Number of offsprings to generate each iteration.", ) parser.add_argument("--num_gen", type=int, default=30, help="Number of generations to proceed.") - parser.add_argument("--ncpu", type=int, default=16, help="Number of cpus") + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") parser.add_argument( "--mut_probability", type=float, @@ -316,10 +292,9 @@ def mut_probability_scheduler(n, total): ) parser.add_argument("--restart", action="store_true") parser.add_argument("--seed", type=int, default=1, help="Random seed.") - args = parser.parse_args() - - np.random.seed(args.seed) + return parser.parse_args() +def fetch_population(args) -> np.ndarray: if args.restart: population = np.load(args.input_file) print(f"Starting with {len(population)} fps from {args.input_file}") @@ -332,9 +307,49 @@ def mut_probability_scheduler(n, total): starting_smiles = starting_smiles["smiles"].tolist() population = np.array([mol_fp(smi, args.radius, args.nbits) for smi in starting_smiles]) print(f"Starting with {len(starting_smiles)} fps from {args.input_file}") + return population + +if __name__ == "__main__": + + args = get_args() + np.random.seed(args.seed) + # define some constants (here, for the Hartenfeller-Button test set) + nbits = 4096 + out_dim = 256 + rxn_template = "hb" + featurize = "fp" + param_dir = "hb_fp_2_4096_256" + + # Load data + mol_embedder = _fetch_molembedder(featurize) + + # load the purchasable building block embeddings + bblocks_molembedder = ( + MolEmbedder().load_precomputed(args.embeddings_knn_file).init_balltree(cosine_distance) + ) + bb_emb = bblocks_molembedder.get_embeddings() + + # load the purchasable building block SMILES to a dictionary + building_blocks = BuildingBlockFileHandler().load(args.building_blocks_file) + # A dict is used as lookup table for 2nd reactant during inference: + bb_dict = {block: i for i, block in enumerate(building_blocks)} + + # load the reaction templates as a ReactionSet object + rxns = ReactionSet().load(args.rxns_collection_file).rxns + + # load the pre-trained modules + path = Path(args.ckpt_dir) + ckpt_files = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] + act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] + + # Get initial population + population = fetch_population(args) + + # Evaluation initial population with mp.Pool(processes=args.ncpu) as pool: scores, mols, trees = fitness(embs=population, _pool=pool, obj=args.objective) + scores = np.array(scores) score_x = np.argsort(scores) population = population[score_x[::-1]] @@ -344,10 +359,9 @@ def mut_probability_scheduler(n, total): print(f"Scores: {scores}") print(f"Top-3 Smiles: {mols[:3]}") + # Genetic Algorithm: loop over generations recent_scores = [] - for n in range(args.num_gen): - t = time.time() dist_ = distribution_schedule(n, args.num_gen) @@ -414,6 +428,7 @@ def mut_probability_scheduler(n, total): print("Early Stop!") break + # Save results data = { "objective": args.objective, "top1": np.mean(scores[:1]), From 0f9164dea85d31ccb5af85babb2c797bbdb966cf Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 10:58:57 -0400 Subject: [PATCH 299/302] format --- scripts/20-predict-targets.py | 13 ++++++----- scripts/optimize_ga.py | 31 ++++++++++++++------------ src/synnet/data_generation/syntrees.py | 3 +-- src/synnet/models/rt1.py | 2 +- src/synnet/models/rt2.py | 2 +- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py index d74c729b..ed88f98c 100644 --- a/scripts/20-predict-targets.py +++ b/scripts/20-predict-targets.py @@ -5,7 +5,7 @@ import logging import multiprocessing as mp from pathlib import Path -from typing import Tuple, Union +from typing import Tuple import numpy as np import pandas as pd @@ -13,7 +13,7 @@ from synnet.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES from synnet.data_generation.preprocessing import BuildingBlockFileHandler from synnet.encoding.distances import cosine_distance -from synnet.models.common import load_mlp_from_ckpt, find_best_model_ckpt +from synnet.models.common import find_best_model_ckpt, load_mlp_from_ckpt from synnet.MolEmbedder import MolEmbedder from synnet.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet from synnet.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search @@ -48,6 +48,7 @@ def _fetch_data(name: str) -> list[str]: smiles = _fetch_data_from_file(name) return smiles + def wrapper_decoder(smiles: str) -> Tuple[str, float, SyntheticTree]: """Generate a synthetic tree for the input molecular embedding.""" emb = mol_fp(smiles) @@ -156,7 +157,7 @@ def get_args(): logger.info("Start loading models from checkpoints...") path = Path(args.ckpt_dir) ckpt_files = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] - act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] + act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] logger.info("...loading models completed.") # Decode queries, i.e. the target molecules. @@ -172,9 +173,9 @@ def get_args(): # Print some results from the prediction # Note: If a syntree cannot be decoded within `max_depth` steps (15), # we will count it as unsuccessful. The similarity will be 0. - decoded = [smi for smi, _, _ in results ] - similarities = [sim for _, sim, _ in results ] - trees = [tree for _, _, tree in results ] + decoded = [smi for smi, _, _ in results] + similarities = [sim for _, sim, _ in results] + trees = [tree for _, _, tree in results] recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) avg_similarity = np.mean(similarities) diff --git a/scripts/optimize_ga.py b/scripts/optimize_ga.py index 183e90b9..76f5271d 100644 --- a/scripts/optimize_ga.py +++ b/scripts/optimize_ga.py @@ -2,45 +2,45 @@ Generates synthetic trees where the root molecule optimizes for a specific objective based on Therapeutics Data Commons (TDC) oracle functions. Uses a genetic algorithm to optimize embeddings before decoding. -""" # TODO: Refactor/Consolidate with generic inference script +""" # TODO: Refactor/Consolidate with generic inference script import json import multiprocessing as mp import time from pathlib import Path + import numpy as np import pandas as pd from tdc import Oracle + from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import BuildingBlockFileHandler from synnet.encoding.distances import cosine_distance +from synnet.models.common import find_best_model_ckpt, load_mlp_from_ckpt from synnet.MolEmbedder import MolEmbedder +from synnet.utils.data_utils import ReactionSet from synnet.utils.ga_utils import crossover, mutation -from synnet.utils.predict_utils import mol_fp -from pathlib import Path -from synnet.data_generation.preprocessing import BuildingBlockFileHandler - -import numpy as np -import pandas as pd +from synnet.utils.predict_utils import mol_fp, synthetic_tree_decoder, tanimoto_similarity -from synnet.utils.data_utils import ReactionSet -from synnet.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity -from synnet.models.common import load_mlp_from_ckpt, find_best_model_ckpt def _fetch_gin_molembedder(): from dgllife.model import load_pretrained + # define model to use for molecular embedding model_type = "gin_supervised_contextpred" device = "cpu" mol_embedder = load_pretrained(model_type).to(device) return mol_embedder.eval() -def _fetch_molembedder(featurize:str): + +def _fetch_molembedder(featurize: str): """Fetch molembedder.""" - if featurize=="fp": - return None # not in use + if featurize == "fp": + return None # not in use else: raise NotImplementedError return _fetch_gin_molembedder() + def func(emb): """ Generates the synthetic tree for the input molecular embedding. @@ -231,6 +231,7 @@ def mut_probability_scheduler(n, total): else: return 0.5 + def get_args(): import argparse @@ -294,6 +295,7 @@ def get_args(): parser.add_argument("--seed", type=int, default=1, help="Random seed.") return parser.parse_args() + def fetch_population(args) -> np.ndarray: if args.restart: population = np.load(args.input_file) @@ -309,6 +311,7 @@ def fetch_population(args) -> np.ndarray: print(f"Starting with {len(starting_smiles)} fps from {args.input_file}") return population + if __name__ == "__main__": args = get_args() @@ -341,7 +344,7 @@ def fetch_population(args) -> np.ndarray: # load the pre-trained modules path = Path(args.ckpt_dir) ckpt_files = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] - act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] + act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] # Get initial population population = fetch_population(args) diff --git a/src/synnet/data_generation/syntrees.py b/src/synnet/data_generation/syntrees.py index d8f31931..5e071aea 100644 --- a/src/synnet/data_generation/syntrees.py +++ b/src/synnet/data_generation/syntrees.py @@ -76,8 +76,7 @@ def __init__( self.processes = processes self.verbose = verbose if not verbose: - logger.setLevel('CRITICAL') # dont show error msgs - + logger.setLevel("CRITICAL") # dont show error msgs # Time intensive tasks self._init_rxns_with_reactants() diff --git a/src/synnet/models/rt1.py b/src/synnet/models/rt1.py index cbe97ecf..8bb3e9b9 100644 --- a/src/synnet/models/rt1.py +++ b/src/synnet/models/rt1.py @@ -10,9 +10,9 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import TQDMProgressBar +from synnet.encoding.distances import cosine_distance from synnet.models.common import get_args, xy_to_dataloader from synnet.models.mlp import MLP -from synnet.encoding.distances import cosine_distance from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) diff --git a/src/synnet/models/rt2.py b/src/synnet/models/rt2.py index 51db0732..2ea69453 100644 --- a/src/synnet/models/rt2.py +++ b/src/synnet/models/rt2.py @@ -10,9 +10,9 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.progress import TQDMProgressBar +from synnet.encoding.distances import cosine_distance from synnet.models.common import get_args, xy_to_dataloader from synnet.models.mlp import MLP -from synnet.encoding.distances import cosine_distance from synnet.MolEmbedder import MolEmbedder logger = logging.getLogger(__name__) From 80fc464ca512e04efc236d07a165ea148d366979 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 10:59:16 -0400 Subject: [PATCH 300/302] update gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 5eb7087f..e2f1b9d8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ data/ figures/syntrees/ results/ +checkpoints/ +oracle/ logs/ tmp/ .dev/ From e88551f1ffda8575645b6133a570a5667b93953b Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 11:01:17 -0400 Subject: [PATCH 301/302] update readme --- README.md | 85 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 46 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index dff421f3..ba642508 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # SynNet -This repo contains the code and analysis scripts for our amortized approach to synthetic tree generation using neural networks. Our model can serve as both a synthesis planning tool and as a tool for synthesizable molecular design. +This repo contains the code and analysis scripts for our amortized approach to synthetic tree generation using neural networks. +Our model can serve as both a synthesis planning tool and as a tool for synthesizable molecular design. The method is described in detail in the publication "Amortized tree generation for bottom-up synthesis planning and synthesizable molecular design" available on the [arXiv](https://arxiv.org/abs/2110.06389) and summarized below. @@ -30,25 +31,31 @@ The model consists of four modules, each containing a multi-layer perceptron (ML ![the model](./figures/network.png "model scheme") -These four modules predict the probability distributions of actions to be taken within a single reaction step, and determine the nodes to be added to the synthetic tree under construction. All of these networks are conditioned on the target molecule embedding. +These four modules predict the probability distributions of actions to be taken within a single reaction step, and determine the nodes to be added to the synthetic tree under construction. +All of these networks are conditioned on the target molecule embedding. ### Synthesis planning -This task is to infer the synthetic pathway to a given target molecule. We formulate this problem as generating a synthetic tree such that the product molecule it produces (i.e., the molecule at the root node) matches the desired target molecule. +This task is to infer the synthetic pathway to a given target molecule. +We formulate this problem as generating a synthetic tree such that the product molecule it produces (i.e., the molecule at the root node) matches the desired target molecule. -For this task, we can take a molecular embedding for the desired product, and use it as input to our model to produce a synthetic tree. If the desired product is successfully recovered, then the final root molecule will match the desired molecule used to create the input embedding. If the desired product is not successully recovered, it is possible the final root molecule may still be *similar* to the desired molecule used to create the input embedding, and thus our tool can also be used for *synthesizable analog recommendation*. +For this task, we can take a molecular embedding for the desired product, and use it as input to our model to produce a synthetic tree. +If the desired product is successfully recovered, then the final root molecule will match the desired molecule used to create the input embedding. +If the desired product is not successully recovered, it is possible the final root molecule may still be *similar* to the desired molecule used to create the input embedding, and thus our tool can also be used for *synthesizable analog recommendation*. ![the generation process](./figures/generation_process.png "generation process") ### Synthesizable molecular design -This task is to optimize a molecular structure with respect to an oracle function (e.g. bioactivity), while ensuring the synthetic accessibility of the molecules. We formulate this problem as optimizing the structure of a synthetic tree with respect to the desired properties of the product molecule it produces. +This task is to optimize a molecular structure with respect to an oracle function (e.g. bioactivity), while ensuring the synthetic accessibility of the molecules. +We formulate this problem as optimizing the structure of a synthetic tree with respect to the desired properties of the product molecule it produces. -To do this, we optimize the molecular embedding of the molecule using a genetic algorithm and the desired oracle function. The optimized molecule embedding can then be used as input to our model to produce a synthetic tree, where the final root molecule corresponds to the optimized molecule. +To do this, we optimize the molecular embedding of the molecule using a genetic algorithm and the desired oracle function. +The optimized molecule embedding can then be used as input to our model to produce a synthetic tree, where the final root molecule corresponds to the optimized molecule. ## Setup instructions -### Setting up the environment +### Environment Conda is used to create the environment for running SynNet. @@ -57,13 +64,22 @@ Conda is used to create the environment for running SynNet. conda env create -f environment.yml ``` -Before running any SynNet code, activate the environment and install this package in development mode. This ensures the scripts can find the right files. You can do this by typing: +Before running any SynNet code, activate the environment and install this package in development mode: ```bash source activate synnet pip install -e . ``` +The model implementations can be found in `src/syn_net/models/`. + +The pre-processing and analysis scripts are in `scripts/`. + +### Train the model from scratch + +Before training any models, you will first need to some data preprocessing. +Please see [INSTRUCTIONS.md](INSTRUCTIONS.md) for a complete guide. + ### Data SynNet relies on two datasources: @@ -77,11 +93,6 @@ The building blocks are not freely available. To obtain the data, go to [https://enamine.net/building-blocks/building-blocks-catalog](https://enamine.net/building-blocks/building-blocks-catalog). We used the "Building Blocks, US Stock" data. You need to first register and then request access to download the dataset. The people from enamine.net manually approve you, so please be nice and patient. -## Code Structure - -The model implementations can be found in [src/syn_net/models/](src/syn_net/models/). -The pre-processing and analysis scripts are in [scripts/](scripts/). - ## Reproducing results Before running anything, set up the environment as decribed above. @@ -95,11 +106,18 @@ For further details, please see the publication. To download the pre-trained model to `./checkpoints`: ```bash -mkdir -p checkpoints && cd checkpoints # Download wget -O hb_fp_2_4096_256.tar.gz https://figshare.com/ndownloader/files/31067692 # Extract tar -vxf hb_fp_2_4096_256.tar.gz +# Rename files to match new scripts (...) +mv hb_fp_2_4096_256/ checkpoints/ +for model in "act" "rt1" "rxn" "rt2" +do + mkdir checkpoints/$model + mv "checkpoints/$model.ckpt" "checkpoints/$model/ckpts.dummy-val_loss=0.00.ckpt" +done +rm -f hb_fp_2_4096_256.tar.gz ``` The following scripts are run from the command line. @@ -109,23 +127,23 @@ Use `python some_script.py --help` or check the source code to see the instructi In addition to the necessary data, we will need to pre-compute an embedding of the building blocks. To do so, please follow steps 0-2 from the [INSTRUCTIONS.md](INSTRUCTIONS.md). +Then, replace the environment variables in the commands below. #### Synthesis Planning To perform synthesis planning described in the main text: ```bash -python scripts/predict_multireactant_mp.py \ - -n -1 \ +python scripts/20-predict-targets.py \ + --building-blocks-file $BUILDING_BLOCKS_FILE \ + --rxns-collection-file $RXN_COLLECTION_FILE \ + --embeddings-knn-file $EMBEDDINGS_KNN_FILE \ --data "data/assets/molecules/sample-targets.txt" \ - --ncpu 10 + --ckpt-dir "checkpoints/" \ + --output-dir "results/demo-inference/" ``` -This script will feed a list of ten randomly selected molecules from the validation to SynNet. -The decoded results, i.e. the predicted synthesis trees, are saved to `DATA_RESULT_DIR`. -(Paths are defined in [src/syn_net/config.py](src/syn_net/config.py).) - -*Note*: To do synthesis planning, you will need a list of target molecules (provided), building blocks (need to download) and embeddings (need to compute). +This script will feed a list of ten molecules to SynNet. #### Synthesizable Molecular Design @@ -133,7 +151,11 @@ To perform synthesizable molecular design, run: ```bash python scripts/optimize_ga.py \ - -i path/to/zinc.csv \ + --ckpt-dir "checkpoints/" \ + --building-blocks-file $BUILDING_BLOCKS_FILE \ + --rxns-collection-file $RXN_COLLECTION_FILE \ + --embeddings-knn-file $EMBEDDINGS_KNN_FILE \ + --input-file path/to/zinc.csv \ --radius 2 --nbits 4096 \ --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk \ --ncpu 32 @@ -141,19 +163,4 @@ python scripts/optimize_ga.py \ This script uses a genetic algorithm to optimize molecular embeddings and returns the predicted synthetic trees for the optimized molecular embedding. -If user wants to start from a checkpoint of previous run, run: - -```bash -python scripts/optimize_ga.py \ - -i path/to/population.npy \ - --radius 2 --nbits 4096 \ - --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk --restart \ - --ncpu 32 -``` - -Note: the input file indicated by `-i` contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. - -### Train the model from scratch - -Before training any models, you will first need to some data preprocessing. -Please see [INSTRUCTIONS.md](INSTRUCTIONS.md) for a complete guide. +Note: `input-file` contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. If omitted, a random fingerprint will be chosen. From a12744ea7a6b0003fef8fb25e48d89d99ae7f1e8 Mon Sep 17 00:00:00 2001 From: Christian Ulmer <39857842+chrulm@users.noreply.github.com> Date: Wed, 12 Oct 2022 11:11:55 -0400 Subject: [PATCH 302/302] update instructions --- INSTRUCTIONS.md | 51 ++++++++++++++++++++++++------------------------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md index 3e477099..9f2032f7 100644 --- a/INSTRUCTIONS.md +++ b/INSTRUCTIONS.md @@ -2,13 +2,13 @@ This documents outlines the process to train SynNet from scratch step-by-step. -> :warning: It is still a WIP to match the filenames of the scripts to the instructions here and to simplify the dependency on parameters/filenames. +> :warning: It is still a WIP. You can use any set of reaction templates and building blocks, but we will illustrate the process with the *Hartenfeller-Button* reaction templates and *Enamine building blocks*. *Note*: This project depends on a lot of exact filenames. For example, one script will save to file, the next will read that file for further processing. -It is not a perfect approach - we are open to feedback - and advise to revise the parameters defined in each script. +It is not a perfect approach - we are open to feedback. Let's start. @@ -20,7 +20,8 @@ Let's start. ```shell python scripts/00-extract-smiles-from-sdf.py \ - --input-file="data/assets/building-blocks/enamine-us.sdf" + --input-file="data/assets/building-blocks/enamine-us.sdf" \ + --output-file="data/assets/building-blocks/enamine-us-smiles.csv.gz" ``` 1. Filter building blocks. @@ -49,8 +50,9 @@ Let's start. ```bash python scripts/02-compute-embeddings.py \ - --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ - --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" + --building-blocks-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ + --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" \ + --featurization-fct "fp_256" ``` 3. Generate *synthetic trees* @@ -61,10 +63,10 @@ Let's start. ```bash # Generate synthetic trees python scripts/03-generate-syntrees.py \ - --building-blocks-file "data/pre-process/building-blocks/enamine-us-smiles.csv.gz" \ - --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ - --output-file "data/pre-process/synthetic-trees.json.gz" \ - --number-syntrees 600000 + --building-blocks-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --output-file "data/pre-process/syntrees/synthetic-trees.json.gz" \ + --number-syntrees "600000" ``` In a second step, we filter out some synthetic trees to make the data pharmaceutically more interesting. @@ -73,13 +75,14 @@ Let's start. ```bash # Filter python scripts/04-filter-syntrees.py \ - --input-file "data/pre-process/synthetic-trees.json.gz" \ - --output-file "data/pre-process/synthetic-trees-filtered.json.gz" + --input-file "data/pre-process/syntrees/synthetic-trees.json.gz" \ + --output-file "data/pre-process/syntrees/synthetic-trees-filtered.json.gz" \ + --verbose ``` Each *synthetic tree* is serializable and so we save all trees in a compressed `.json` file. -4. Split *synthetic trees* into train,valid,test-data +5. Split *synthetic trees* into train,valid,test-data We load the `.json`-file with all *synthetic trees* and straightforward split it into three files: `{train,test,valid}.json`. @@ -87,11 +90,11 @@ Let's start. ```bash python scripts/05-split-syntrees.py \ - --input-file "data/pre-process/syntrees/synthetic-trees-filtered.json.gz" \ - --output-dir "data/pre-process/syntrees/" + --input-file "data/pre-process/syntrees/synthetic-trees-filtered.json.gz" \ + --output-dir "data/pre-process/syntrees/" --verbose ``` -5. Featurization +6. Featurization We featurize each *synthetic tree*. That is, we break down each tree to each iteration step ("Add", "Expand", "Extend", "End") and featurize it. @@ -100,8 +103,8 @@ Let's start. ```bash python scripts/06-featurize-syntrees.py \ - --input-dir "data/pre-process/syntrees/" - --output-dir "data/featurized" --verbose + --input-dir "data/pre-process/syntrees/" \ + --output-dir "data/featurized/" --verbose ``` This script will load the `{train,valid,test}` data, featurize it, and save it in @@ -111,7 +114,7 @@ Let's start. The encoders for the molecules must be provided in the script. A short text summary of the encoders will be saved as well. -6. Split features +7. Split features Up to this point, we worked with a (featurized) *synthetic tree* as a whole, now we split it up to into "consumable" input/output data for each of the four networks. @@ -125,12 +128,12 @@ Let's start. This will create 24 new files (3 splits, 4 networks, X + y). All new files will be saved in `/Xy`. -7. Train the networks +8. Train the networks - Finally, we can train each of the four networks in `src/syn_net/models/` separately: + Finally, we can train each of the four networks in `src/synnet/models/` separately. For example: ```bash - python src/syn_net/models/act.py + python src/synnet/models/act.py ``` After training a new model, you can then use the trained model to make predictions and construct synthetic trees for a list given set of molecules. @@ -148,7 +151,7 @@ To visualize trees, there is a hacky script that represents *Synthetic Trees* as To demo it: ```bash -python src/syn_net/visualize/visualizer.py +python src/synnet/visualize/visualizer.py ``` Still to be implemented: i) target molecule, ii) "end" action @@ -156,7 +159,3 @@ Still to be implemented: i) target molecule, ii) "end" action To render the markdown file incl. the diagram directly in VS Code, install the extension [vscode-markdown-mermaid](https://github.com/mjbvz/vscode-markdown-mermaid) and use the built-in markdown preview. *Info*: If the images of the molecules do not load, edit + save the markdown file anywhere. For example add and delete a character with the preview open. Not sure why this happens. - -### Mean reciprocal rank - -To be added.