diff --git a/.gitignore b/.gitignore index 9f6319f1..e2f1b9d8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,22 +1,38 @@ -# Code TODOs -TODOs - -# Certain unittest files -tests/data/states_0_train.npz -tests/data/steps_0_train.npz -tests/data/rxns_hb.json.gz -tests/data/st_data.json.gz -tests/data/X_act_train.npz -tests/data/y_act_train.npz -tests/data/X_rt1_train.npz -tests/data/y_rt1_train.npz -tests/data/X_rxn_train.npz -tests/data/y_rxn_train.npz -tests/data/X_rt2_train.npz -tests/data/y_rt2_train.npz -tests/gin_supervised_contextpred_pre_trained.pth -tests/backup/ +# === custom === + +data/ +figures/syntrees/ +results/ +checkpoints/ +oracle/ +logs/ +tmp/ +.dev/ +.old/ +.notes/ +.aliases +*.sh + +# === template === + +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,jupyternotebooks + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -39,7 +55,6 @@ parts/ sdist/ var/ wheels/ -pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg @@ -69,6 +84,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +cover/ # Translations *.mo @@ -91,17 +107,17 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ # Jupyter Notebook -.ipynb_checkpoints # IPython -profile_default/ -ipython_config.py # pyenv -.python-version +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. @@ -110,7 +126,22 @@ ipython_config.py # install all needed dependencies. #Pipfile.lock -# PEP 582; used by e.g. github.com/David-OConnor/pyflow +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff @@ -147,37 +178,42 @@ dmypy.json # Pyre type checker .pyre/ -# Vim -*~ - -# Data -# data/* -.DS_Store -oracle/* -*.json* -*.npy -*logs* -*.gz -*.csv - -# test Jupyter Notebook -*.ipynb - -# Output files -nohup.out -*.output -*.o -*.out -*.swp -*slurm* -*.sh -*.pth -*.ckpt -*_old* -results -synth_net/params +# pytype static type analyzer +.pytype/ -# Old files set to be deleted -tmp/ -scripts/oracle -temp.py +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### VisualStudioCode ### +.vscode/* +# !.vscode/settings.json +# !.vscode/launch.json +!.vscode/tasks.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# Support for Project snippet scope +.vscode/*.code-snippets + +# Ignore code-workspaces +*.code-workspace + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,jupyternotebooks diff --git a/INSTRUCTIONS.md b/INSTRUCTIONS.md new file mode 100644 index 00000000..9f2032f7 --- /dev/null +++ b/INSTRUCTIONS.md @@ -0,0 +1,161 @@ +# Instructions + +This document outlines the process to train SynNet from scratch, step by step. + +> :warning: It is still a WIP. + +You can use any set of reaction templates and building blocks, but we will illustrate the process with the *Hartenfeller-Button* reaction templates and *Enamine building blocks*. + +*Note*: This project depends on a lot of exact filenames. +For example, one script will save to file, the next will read that file for further processing. +It is not a perfect approach - we are open to feedback. + +Let's start. + +## Step-by-Step + +0. Prepare reaction templates and building blocks. + + Extract SMILES from the `.sdf` file from enamine.net. + + ```shell + python scripts/00-extract-smiles-from-sdf.py \ + --input-file="data/assets/building-blocks/enamine-us.sdf" \ + --output-file="data/assets/building-blocks/enamine-us-smiles.csv.gz" + ``` + +1. Filter building blocks. + + We preprocess the building blocks to identify applicable reactants for each reaction template. + In other words, we filter out all building blocks that do not match any reaction template. + There is no need to keep them, as they cannot act as a reactant. + In a first step, we match all building blocks with each reaction template. + In a second step, we save all matched building blocks + and a collection of `Reaction`s with their available building blocks. + + ```bash + python scripts/01-filter-building-blocks.py \ + --building-blocks-file "data/assets/building-blocks/enamine-us-smiles.csv.gz" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --output-bblock-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ + --output-rxns-collection-file "data/pre-process/building-blocks-rxns/rxns-hb-enamine-us.json.gz" --verbose + ``` + + > :bulb: All following steps use this matched building blocks <-> reaction template data. You have to specify the correct files for every script so that it can load the right data. It can save some time to store these paths as environment variables.
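For example, a minimal snippet along these lines (the paths are the outputs of steps 1 and 2 below, and the variable names match the ones reused in the README's inference commands; adjust both to your setup):

```bash
# Store the paths to the matched building-block/reaction data once,
# then reuse them in all later steps (paths assume the defaults used in this guide)
export BUILDING_BLOCKS_FILE="data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz"
export RXN_COLLECTION_FILE="data/pre-process/building-blocks-rxns/rxns-hb-enamine-us.json.gz"
export EMBEDDINGS_KNN_FILE="data/pre-process/embeddings/hb-enamine-embeddings.npy"
```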
2. Pre-compute embeddings + + We use the embeddings of the building blocks a lot. + Hence, we pre-compute them once and store them to file. + + ```bash + python scripts/02-compute-embeddings.py \ + --building-blocks-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ + --output-file "data/pre-process/embeddings/hb-enamine-embeddings.npy" \ + --featurization-fct "fp_256" + ``` + +3. Generate *synthetic trees* + + Herein we generate the data used for training the networks. + The data is generated by randomly selecting building blocks, reaction templates and directives to grow a synthetic tree. + Each *synthetic tree* is serializable, so we save all trees in a compressed `.json` file. + + ```bash + # Generate synthetic trees + python scripts/03-generate-syntrees.py \ + --building-blocks-file "data/pre-process/building-blocks-rxns/bblocks-enamine-us.csv.gz" \ + --rxn-templates-file "data/assets/reaction-templates/hb.txt" \ + --output-file "data/pre-process/syntrees/synthetic-trees.json.gz" \ + --number-syntrees "600000" + ``` + +4. Filter *synthetic trees* + + In a second step, we filter out some synthetic trees to make the data pharmaceutically more interesting. + That is, we filter out trees whose root molecule has a QED < 0.5, keeping each such tree only with a probability of QED/0.5. + + ```bash + # Filter + python scripts/04-filter-syntrees.py \ + --input-file "data/pre-process/syntrees/synthetic-trees.json.gz" \ + --output-file "data/pre-process/syntrees/synthetic-trees-filtered.json.gz" \ + --verbose + ``` + +5. Split *synthetic trees* into train, valid, and test data + + We load the `.json` file with all *synthetic trees* and + split it into three files: `{train,test,valid}.json`. + The default split ratio is 6:2:2. + + ```bash + python scripts/05-split-syntrees.py \ + --input-file "data/pre-process/syntrees/synthetic-trees-filtered.json.gz" \ + --output-dir "data/pre-process/syntrees/" --verbose + ``` + +6. Featurization + + We featurize each *synthetic tree*. + That is, we break down each tree into its iteration steps ("Add", "Expand", "Merge", "End") and featurize each step. + This results in a "state" vector and a corresponding "super step" vector. + We call it "super step" here, as it contains all featurized data for all networks. + + ```bash + python scripts/06-featurize-syntrees.py \ + --input-dir "data/pre-process/syntrees/" \ + --output-dir "data/featurized/" --verbose + ``` + + This script will load the `{train,valid,test}` data, featurize it, and save it in + - `<output-dir>/{train,valid,test}_states.npz` and + - `<output-dir>/{train,valid,test}_steps.npz`. + + The encoders for the molecules must be provided in the script. + A short text summary of the encoders will be saved as well. + +7. Split features + + Up to this point, we worked with a (featurized) *synthetic tree* as a whole; now we split it up into "consumable" input/output data for each of the four networks. + This includes picking the right featurized data from the "super step" vector from the previous step. + + ```bash + python scripts/07-split-data-for-networks.py \ + --input-dir "data/featurized/" + ``` + + This will create 24 new files (3 splits, 4 networks, X + y). + All new files will be saved in `<input-dir>/Xy`. + +8. Train the networks + + Finally, we can train each of the four networks in `src/synnet/models/` separately. For example: + + ```bash + python src/synnet/models/act.py + ```
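To train all four networks in sequence, a small shell loop along these lines should work (a convenience sketch; it assumes each script is run with its default arguments, as above):

```bash
# Train the action, reactant-1, reaction, and reactant-2 networks one after
# another (assumes each script's built-in defaults)
for model in act rt1 rxn rt2
do
    python src/synnet/models/$model.py
done
```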
After training a new model, you can then use the trained model to make predictions and construct synthetic trees for a given set of molecules. + +You can also perform molecular optimization using a genetic algorithm. + +Please refer to the [README.md](./README.md) for inference instructions. + +## Auxiliary Scripts + +### Visualizing trees + +To visualize trees, there is a hacky script that represents *Synthetic Trees* as [mermaid](https://github.com/mermaid-js/mermaid) diagrams. + +To demo it: + +```bash +python src/synnet/visualize/visualizer.py +``` + +Still to be implemented: i) target molecule, ii) "end" action + +To render the markdown file incl. the diagram directly in VS Code, install the extension [vscode-markdown-mermaid](https://github.com/mjbvz/vscode-markdown-mermaid) and use the built-in markdown preview. + +*Info*: If the images of the molecules do not load, edit + save the markdown file anywhere. For example, add and delete a character with the preview open. We are not sure why this happens. diff --git a/README.md b/README.md index 325f1b90..ba642508 100644 --- a/README.md +++ b/README.md @@ -1,277 +1,166 @@ # SynNet -This repo contains the code and analysis scripts for our amortized approach to synthetic tree generation using neural networks. Our model can serve as both a synthesis planning tool and as a tool for synthesizable molecular design. + +This repo contains the code and analysis scripts for our amortized approach to synthetic tree generation using neural networks. +Our model can serve as both a synthesis planning tool and as a tool for synthesizable molecular design. The method is described in detail in the publication "Amortized tree generation for bottom-up synthesis planning and synthesizable molecular design" available on the [arXiv](https://arxiv.org/abs/2110.06389) and summarized below. ## Summary -### Overview -We model synthetic pathways as tree structures called *synthetic trees*. A valid synthetic tree has one root node (the final product molecule) linked to purchasable building blocks (encoded as SMILES strings) via feasible reactions according to a list of discrete reaction templates (examples of templates encoded as SMARTS strings in [data/rxn_set_hb.txt](./data/rxn_set_hb.txt)). At a high level, each synthetic tree is constructed one reaction step at a time in a bottom-up manner, starting from purchasable building blocks. -The model consists of four modules, each containing a multi-layer perceptron (MLP): +We model synthetic pathways as tree structures called *synthetic trees*. +A synthetic tree has a single root node and one or more child nodes. +Every node is a chemical molecule: -1. An *Action Type* selection function that classifies action types among the four possible actions (“Add”, “Expand”, “Merge”, and “End”) in building the synthetic tree. -2. A *First Reactant* selection function that predicts an embedding for the first reactant. A candidate molecule is identified for the first reactant through a k-nearest neighbors (k-NN) search from the list of potential building blocks. -3. A *Reaction* selection function whose output is a probability distribution over available reaction templates, from which inapplicable reactions are masked (based on reactant 1) and a suitable template is then sampled using a greedy search. -4. A *Second Reactant* selection function that identifies the second reactant if the sampled template is bi-molecular.
The model predicts an embedding for the second reactant, and a candidate is then sampled via a k-NN search from the masked set of building blocks. +- The root node is the final product molecule. +- The leaf nodes consist of purchasable building blocks. +- All other inner nodes are constrained to be a product of allowed chemical reactions. -![the model](./figures/network.png "model scheme") +At a high level, each synthetic tree is constructed one reaction step at a time in a bottom-up manner, that is, starting from purchasable building blocks. -These four modules predict the probability distributions of actions to be taken within a single reaction step, and determine the nodes to be added to the synthetic tree under construction. All of these networks are conditioned on the target molecule embedding. +### Overview -### Synthesis planning -This task is to infer the synthetic pathway to a given target molecule. We formulate this problem as generating a synthetic tree such that the product molecule it produces (i.e., the molecule at the root node) matches the desired target molecule. +The model consists of four modules, each containing a multi-layer perceptron (MLP): -For this task, we can take a molecular embedding for the desired product, and use it as input to our model to produce a synthetic tree. If the desired product is successfully recovered, then the final root molecule will match the desired molecule used to create the input embedding. If the desired product is not successully recovered, it is possible the final root molecule may still be *similar* to the desired molecule used to create the input embedding, and thus our tool can also be used for *synthesizable analog recommendation*. +1. An *Action Type* selection function that classifies action types among the four possible actions (“Add”, “Expand”, “Merge”, and “End”) in building the synthetic tree. Each action increases the depth of the synthetic tree by one. ![the generation process](./figures/generation_process.png "generation process") +2. A *First Reactant* selection function that selects the first reactant. An MLP predicts a molecular embedding, and a first reactant is identified from the pool of building blocks through a k-nearest neighbors (k-NN) search (see the sketch after this list). -### Synthesizable molecular design -This task is to optimize a molecular structure with respect to an oracle function (e.g. bioactivity), while ensuring the synthetic accessibility of the molecules. We formulate this problem as optimizing the structure of a synthetic tree with respect to the desired properties of the product molecule it produces. +3. A *Reaction* selection function whose output is a probability distribution over available reaction templates. Inapplicable reactions are masked based on reactant 1. A suitable template is then sampled using a greedy search. -To do this, we optimize the molecular embedding of the molecule using a genetic algorithm and the desired oracle function. The optimized molecule embedding can then be used as input to our model to produce a synthetic tree, where the final root molecule corresponds to the optimized molecule. +4. A *Second Reactant* selection function that identifies the second reactant if the sampled template is bi-molecular. The model predicts an embedding for the second reactant, and a candidate is then sampled via a k-NN search from the masked set of building blocks.
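The k-NN searches in modules 2 and 4 are plain nearest-neighbor lookups in the matrix of pre-computed building-block embeddings. A minimal sketch (illustrative only, not the exact API of this repo; it assumes the embeddings are a numpy array with one row per building block):

```python
# Illustrative sketch of the k-NN reactant lookup, not the repo's exact API
import numpy as np

def knn_building_blocks(query: np.ndarray, bblock_embeddings: np.ndarray, k: int = 1) -> np.ndarray:
    """Return the indices of the k building blocks closest to the predicted embedding.

    query:             embedding predicted by the MLP, shape (d,)
    bblock_embeddings: pre-computed building-block embeddings, shape (n, d)
    """
    # Euclidean distance from the query to every building-block embedding
    distances = np.linalg.norm(bblock_embeddings - query, axis=1)
    return np.argsort(distances)[:k]
```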
-## Setup instructions +![the model](./figures/network.png "model scheme") -### Setting up the environment -You can use conda to create an environment containing the necessary packages and dependencies for running SynNet by using the provided YAML file: +These four modules predict the probability distributions of actions to be taken within a single reaction step, and determine the nodes to be added to the synthetic tree under construction. +All of these networks are conditioned on the target molecule embedding. -``` -conda env create -f environment.yml -``` +### Synthesis planning -If you update the environment and would like to save the updated environment as a new YAML file using conda, use: +This task is to infer the synthetic pathway to a given target molecule. +We formulate this problem as generating a synthetic tree such that the product molecule it produces (i.e., the molecule at the root node) matches the desired target molecule. -``` -conda env export > path/to/env.yml -``` +For this task, we can take a molecular embedding for the desired product, and use it as input to our model to produce a synthetic tree. +If the desired product is successfully recovered, then the final root molecule will match the desired molecule used to create the input embedding. +If the desired product is not successfully recovered, it is possible the final root molecule may still be *similar* to the desired molecule used to create the input embedding, and thus our tool can also be used for *synthesizable analog recommendation*. -Before running any SynNet code, activate the environment and update the Python path so that the scripts can find the right files. You can do this by typing: +![the generation process](./figures/generation_process.png "generation process") -``` -source activate synthenv -export PYTHONPATH=`pwd`:$PYTHONPATH -``` +### Synthesizable molecular design -### Unit tests -To check that everything has been set-up correctly, you can run the unit tests from within the [tests/](./tests/). If starting in the main SynNet/ directory, you can run the unit tests as follows: +This task is to optimize a molecular structure with respect to an oracle function (e.g. bioactivity), while ensuring the synthetic accessibility of the molecules. +We formulate this problem as optimizing the structure of a synthetic tree with respect to the desired properties of the product molecule it produces. -``` -source activate synthenv -export PYTHONPATH=`pwd`:$PYTHONPATH -cd tests/ -python -m unittest -``` - -You should get no errors if everything ran correctly. +To do this, we optimize the molecular embedding of the molecule using a genetic algorithm and the desired oracle function. +The optimized molecule embedding can then be used as input to our model to produce a synthetic tree, where the final root molecule corresponds to the optimized molecule. -### Data +## Setup instructions -#### Templates -The Hartenfeller-Button templates are available in the [./data/](./data/) directory. -#### Building blocks -The Enamine data can be freely downloaded from https://enamine.net/building-blocks/building-blocks-catalog for academic purposes. After downloading the Enamine building blocks, you will need to replace the paths to the Enamine building blocks in the code. This can be done by searching for the string "enamine". +### Environment -## Code Structure -The code is structured as follows: +Conda is used to create the environment for running SynNet.
+```bash +# Install environment from file +conda env create -f environment.yml ``` -SynNet/ -├── data -│ └── rxn_set_hb.txt -├── environment.yml -├── LICENSE -├── README.md -├── scripts -│ ├── compute_embedding_mp.py -│ ├── compute_embedding.py -│ ├── generation_fp.py -│ ├── generation.py -│ ├── gin_supervised_contextpred_pre_trained.pth -│ ├── _mp_decode.py -│ ├── _mp_predict_beam.py -│ ├── _mp_predict_multireactant.py -│ ├── _mp_predict.py -│ ├── _mp_search_similar.py -│ ├── _mp_sum.py -│ ├── mrr.py -│ ├── optimize_ga.py -│ ├── predict-beam-fullTree.py -│ ├── predict_beam_mp.py -│ ├── predict-beam-reactantOnly.py -│ ├── predict_mp.py -│ ├── predict_multireactant_mp.py -│ ├── predict.py -│ ├── read_st_data.py -│ ├── sample_from_original.py -│ ├── search_similar.py -│ ├── sketch-synthetic-trees.py -│ ├── st2steps.py -│ ├── st_split.py -│ └── temp.py -├── syn_net -│ ├── data_generation -│ │ ├── check_all_template.py -│ │ ├── filter_unmatch.py -│ │ ├── __init__.py -│ │ ├── make_dataset_mp.py -│ │ ├── make_dataset.py -│ │ ├── _mp_make.py -│ │ ├── _mp_process.py -│ │ └── process_rxn_mp.py -│ ├── __init__.py -│ ├── models -│ │ ├── act.py -│ │ ├── mlp.py -│ │ ├── prepare_data.py -│ │ ├── rt1.py -│ │ ├── rt2.py -│ │ └── rxn.py -│ └── utils -│ ├── data_utils.py -│ ├── ga_utils.py -│ ├── predict_beam_utils.py -│ ├── predict_utils.py -│ └── __init__.py -└── tests - ├── create-unittest-data.py - └── test_DataPreparation.py -``` - -The model implementations can be found in [syn_net/models/](syn_net/models/), with processing and analysis scripts located in [scripts/](./scripts/). -## Instructions -Before running anything, you need to add the root directory to the Python path. One option for doing this is to run the following command in the root `SynNet` directory: +Before running any SynNet code, activate the environment and install this package in development mode: -``` -export PYTHONPATH=`pwd`:$PYTHONPATH +```bash +source activate synnet +pip install -e . ``` -## Using pre-trained models -We have made available a set of pre-trained models at the following [link](https://figshare.com/articles/software/Trained_model_parameters_for_SynNet/16799413). The pretrained models correspond to the Action, Reactant 1, Reaction, and Reactant 2 networks, trained on the Hartenfeller-Button dataset using radius 2, length 4096 Morgan fingerprints for the molecular node embeddings, and length 256 fingerprints for the k-NN search. For further details, please see the publication. +The model implementations can be found in `src/syn_net/models/`. -The models can be uncompressed with: -``` -tar -zxvf hb_fp_2_4096_256.tar.gz -``` - -### Synthesis Planning -To perform synthesis planning described in the main text: -``` -python predict_multireactant_mp.py -n -1 --ncpu 36 --data test -``` -This script will feed a list of molecules from the test data and save the decoded results (predicted synthesis trees) to [./results/](./results/). -One can use --help to see the instruction of each argument. -Note: this file reads parameters from a directory, please specify the path to parameters previously. - -### Synthesizable Molecular Design -To perform synthesizable molecular design, under [./scripts/](./scripts/), run: -``` -optimize_ga.py -i path/to/zinc.csv --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk -``` -This script uses a genetic algorithm to optimize molecular embeddings and returns the predicted synthetic trees for the optimized molecular embedding. 
-One can use --help to see the instruction of each argument. -If user wants to start from a checkpoint of previous run, run: -``` -optimize_ga.py -i path/to/population.npy --radius 2 --nbits 4096 --num_population 128 --num_offspring 512 --num_gen 200 --ncpu 32 --objective gsk --restart -``` -Note: the input file indicated by -i contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. +The pre-processing and analysis scripts are in `scripts/`. ### Train the model from scratch -Before training any models, you will first need to preprocess the set of reaction templates which you would like to use. You can use either a new set of reaction templates, or the provided Hartenfeller-Button (HB) set of reaction templates (see [data/rxn_set_hb.txt](data/rxn_set_hb.txt)). To preprocess a new dataset, you will need to: -1. Preprocess the data to identify applicable reactants for each reaction template -2. Generate the synthetic trees by random selection -3. Split the synthetic trees into training, testing, and validation splits -4. Featurize the nodes in the synthetic trees using molecular fingerprints -5. Prepare the training data for each of the four networks - -Once you have preprocessed a training set, you can begin to train a model by training each of the four networks separately (the *Action*, *First Reactant*, *Reaction*, and *Second Reactant* networks). - -After training a new model, you can then use the trained model to make predictions and construct synthetic trees for a list given set of molecules. -You can also perform molecular optimization using a genetic algorithm. +Before training any models, you will first need to do some data preprocessing. +Please see [INSTRUCTIONS.md](INSTRUCTIONS.md) for a complete guide. -Instructions for all of the aforementioned steps are described in detail below. - -In addition to the aforementioned types of jobs, we have also provide below instructions for (1) sketching synthetic trees and (2) calculating the mean reciprocal rank of reactant 1. - -### Processing the data: reaction templates and applicable reactants - -Given a set of reaction templates and a list of buyable building blocks, we first need to assign applicable reactants for each template. Under [./syn_net/data_generation/](./syn_net/data_generation/), run: - -``` -python process_rxn_mp.py -``` ### Data -This will save the reaction templates and their corresponding building blocks in a JSON file. Then, run: +SynNet relies on two data sources: -``` -python filter_unmatch.py -``` +1. reaction templates and +2. building blocks. -This will filter out buyable building blocks which didn't match a single template. +The data used for the publication are 1) the *Hartenfeller-Button* reaction templates, which are available under [data/assets/reaction-templates/hb.txt](data/assets/reaction-templates/hb.txt) and 2) *Enamine building blocks*. +The building blocks are not freely available. -### Generating the synthetic path data by random selection -Under [./syn_net/data_generation/](./syn_net/data_generation/), run: +To obtain the data, go to [https://enamine.net/building-blocks/building-blocks-catalog](https://enamine.net/building-blocks/building-blocks-catalog). +We used the "Building Blocks, US Stock" data. You need to first register and then request access to download the dataset. The people from enamine.net manually approve you, so please be nice and patient.
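For reference, the SMILES extraction in step 0 of [INSTRUCTIONS.md](INSTRUCTIONS.md) essentially boils down to the following RDKit sketch (illustrative; the actual script `scripts/00-extract-smiles-from-sdf.py` also handles compression and other I/O details):

```python
# Sketch of step 0; the real script adds compression and I/O handling
from rdkit import Chem

def sdf_to_smiles(sdf_path: str) -> list[str]:
    """Read an .sdf catalog and return a canonical SMILES string per parseable molecule."""
    supplier = Chem.SDMolSupplier(sdf_path)
    return [Chem.MolToSmiles(mol) for mol in supplier if mol is not None]

# e.g. the downloaded Enamine "Building Blocks, US Stock" catalog
smiles = sdf_to_smiles("data/assets/building-blocks/enamine-us.sdf")
```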
-``` -python make_dataset_mp.py -``` +## Reproducing results -This will generate synthetic path data saved in a JSON file. Then, to make the dataset more pharmaceutically revelant, we can change to [./scripts/](./scripts/) and run: +Before running anything, set up the environment as described above. -``` -python sample_from_original.py -``` ### Using pre-trained models -This will filter out the samples where the root node QED is less than 0.5, or randomly with a probability less than 1 - QED/0.5. +We have made available a set of pre-trained models at the following [link](https://figshare.com/articles/software/Trained_model_parameters_for_SynNet/16799413). -### Splitting data into training, validation, and testing sets, and removing duplicates +The pretrained models correspond to the Action, Reactant 1, Reaction, and Reactant 2 networks, trained on the *Hartenfeller-Button* dataset and *Enamine* building blocks using radius 2, length 4096 Morgan fingerprints for the molecular node embeddings, and length 256 fingerprints for the k-NN search. -Under [./scripts/](./scripts/), run: +For further details, please see the publication. -``` -python st_split.py +To download the pre-trained model to `./checkpoints`: -``` +```bash +# Download +wget -O hb_fp_2_4096_256.tar.gz https://figshare.com/ndownloader/files/31067692 +# Extract +tar -vxf hb_fp_2_4096_256.tar.gz +# Rename files to match new scripts (...) +mv hb_fp_2_4096_256/ checkpoints/ +for model in "act" "rt1" "rxn" "rt2" +do + mkdir checkpoints/$model + mv "checkpoints/$model.ckpt" "checkpoints/$model/ckpts.dummy-val_loss=0.00.ckpt" +done +rm -f hb_fp_2_4096_256.tar.gz ``` -The default split ratio is 6:2:2 for training, validation, and testing sets. +The following scripts are run from the command line. +Use `python some_script.py --help` or check the source code to see the instructions for each argument. -### Featurizing data -Under [./scripts/](./scripts/), run: ### Prerequisites -``` -python st2steps.py -r 2 -b 4096 -d train -``` +In addition to the necessary data, we will need to pre-compute an embedding of the building blocks. +To do so, please follow steps 0-2 from the [INSTRUCTIONS.md](INSTRUCTIONS.md). +Then, replace the environment variables in the commands below. -This will featurize the synthetic tree data into step-by-step data which can be used for training. The flag *-r* indicates the fingerprint radius, *-b* indicates the number of bits to use for the fingerprints, and *-d* indicates which dataset split to featurize. #### Synthesis Planning -### Preparing training data for each network -Under [./syn_net/models/](./syn_net/models/), run: To perform synthesis planning described in the main text: -``` -python prepare_data.py --radius 2 --nbits 4096 +```bash +python scripts/20-predict-targets.py \ + --building-blocks-file $BUILDING_BLOCKS_FILE \ + --rxns-collection-file $RXN_COLLECTION_FILE \ + --embeddings-knn-file $EMBEDDINGS_KNN_FILE \ + --data "data/assets/molecules/sample-targets.txt" \ + --ckpt-dir "checkpoints/" \ + --output-dir "results/demo-inference/" ``` -This will prepare the training data for the networks. This script will feed a list of ten molecules to SynNet. -Each is a training script and can be used as follows (using the action network as an example): #### Synthesizable Molecular Design -``` -python act.py --radius 2 --nbits 4096 To perform synthesizable molecular design, run: -``` +```bash +python scripts/optimize_ga.py \ + --ckpt-dir "checkpoints/" \ + --building-blocks-file $BUILDING_BLOCKS_FILE \ + --rxns-collection-file $RXN_COLLECTION_FILE \ + --embeddings-knn-file $EMBEDDINGS_KNN_FILE \ + --input-file path/to/zinc.csv \ + --radius 2 --nbits 4096 \ + --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk \ + --ncpu 32 ``` -This will train the network and save the model parameters at the state with the best validation loss in a logging directory, e.g., **`act_hb_fp_2_4096_logs`**.
One can use tensorboard to monitor the training and validation loss. +To perform synthesizable molecular design, run: -### Sketching synthetic trees -To visualize the synthetic trees, run: - -``` -python scripts/sketch-synthetic-trees.py --file /path/to/st_hb/st_train.json.gz --saveto ./ --nsketches 5 --actions 3 +```bash +python scripts/optimize_ga.py \ + --ckpt-dir "checkpoints/" \ + --building-blocks-file $BUILDING_BLOCKS_FILE \ + --rxns-collection-file $RXN_COLLECTION_FILE \ + --embeddings-knn-file $EMBEDDINGS_KNN_FILE \ + --input-file path/to/zinc.csv \ + --radius 2 --nbits 4096 \ + --num_population 128 --num_offspring 512 --num_gen 200 --objective gsk \ + --ncpu 32 ``` -This will sketch 5 synthetic trees with 3 or more actions to the current ("./") directory (you can play around with these variables or just also leave them out to use the defaults). - -### Testing the mean reciprocal rank (MRR) of reactant 1 -Under [./scripts/](./scripts/), run: +This script uses a genetic algorithm to optimize molecular embeddings and returns the predicted synthetic trees for the optimized molecular embedding. -``` -python mrr.py --distance cosine -``` +Note: `input-file` contains the seed molecules in CSV format for an initial run, and as a pre-saved numpy array of the population for restarting the run. If omitted, a random fingerprint will be chosen. diff --git a/data/assets/building-blocks/.gitkeep b/data/assets/building-blocks/.gitkeep new file mode 100644 index 00000000..ef7c1b53 --- /dev/null +++ b/data/assets/building-blocks/.gitkeep @@ -0,0 +1 @@ +Placeholder for building block molecules. diff --git a/data/assets/molecules/sample-targets.txt b/data/assets/molecules/sample-targets.txt new file mode 100644 index 00000000..4d4aa219 --- /dev/null +++ b/data/assets/molecules/sample-targets.txt @@ -0,0 +1,10 @@ +COc1cc(Cn2c(C)c(Cc3ccccc3)c3c2CCCC3)ccc1OCC(=O)N(C)C +CCC1CCCC(Nc2cc(C(F)(F)F)c(Cl)cc2SC)CC1 +Clc1cc(Cl)c(C2=NC(c3cccc4c(Br)cccc34)=NN2)nn1 +COc1ccc(S(=O)(=O)c2ccc(-c3nc(-c4cc(B(O)O)ccc4O)no3)cn2)cc1 +CNS(=O)(=O)c1ccc(-c2cc3c4c(ccc3[nH]2)CCCN4C(N)=O)cc1 +CC(NC(=O)C1Cn2c(O)nnc2CN1)c1cc(F)ccc1N1CCC(n2nnn(-c3ccc(Br)cc3)c2=S)CC1 +COc1cc(-c2nc(-c3ccccc3)c(-c3ccccc3)s2)ccn1 +CCCn1c(C)nnc1CC(C)(O)C(=C(C)C)c1nccnc1S(=O)(=O)F +CN(c1ccccc1)c1ccc(-c2nc3ncccc3s2)cn1 +COc1cc(-c2nc(-c3ccc(F)cc3)c(-c3ccc(F)cc3)n2c2cc(Cl)ccc2Cl)ccc1Oc1ccc(S(=O)(=O)N2CCCCC2)cc1[N+](=O)[O-] diff --git a/data/assets/reaction-templates/hb.txt b/data/assets/reaction-templates/hb.txt new file mode 100644 index 00000000..ff4b4727 --- /dev/null +++ b/data/assets/reaction-templates/hb.txt @@ -0,0 +1,91 @@ +[cH1:1]1:[c:2](-[CH2:7]-[CH2:8]-[NH2:9]):[c:3]:[c:4]:[c:5]:[c:6]:1.[#6:11]-[CH1;R0:10]=[OD1]>>[c:1]12:[c:2](-[CH2:7]-[CH2:8]-[NH1:9]-[C:10]-2(-[#6:11])):[c:3]:[c:4]:[c:5]:[c:6]:1 +[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[#8;H1,$(O-[CH3])]>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 +[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 +[c;r6:1](-[SH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[s:2]:[c:5](-[#6:6]):[n:4]2 +[c:1](-[OH1;$(Oc1ccccc1):2]):[c;r6:3](-[NH2:4]).[c:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[o:2]:[c:5](-[c:6]):[n:4]2 +[c;r6:1](-[OH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[OH1]>>[c:3]2:[c:1]:[o:2]:[c:5](-[#6:6]):[n:4]2 +[#6:6]-[C;R0:1](=[OD1])-[CH1;R0:5](-[#6:7])-[*;#17,#35,#53].[NH2:2]-[C:3]=[SD1:4]>>[c:1]2(-[#6:6]):[n:2]:[c:3]:[s:4][c:5]([#6:7]):2 
+[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[OH1]):[c:4](-[NH2:5]).[N;!H0;!$(N-N);!$(N-C=N);!$(N(-C=O)-C=O):6]-[C;H1,$(C-[#6]):7]=[OD1]>>[c:4]2:[c:1]-[C:2](=[O:3])-[N:6]-[C:7]=[N:5]-2 +[CH0;$(C-[#6]):1]#[NH0:2]>>[C:1]1=[N:2]-N-N=N-1 +[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N(-[C:3])-N=N-1 +[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N=N-N-1(-[C:3]) +[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N(-[C:3])-N=N-1 +[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 +[CH0;$(C-[#6]):1]#[CH0;$(C-[#6]):2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 +[CH0;$(C-[#6]):1]#[NH0:2].[NH2:3]-[NH1:4]-[CH0;$(C-[#6]);R0:5]=[OD1]>>[N:2]1-[C:1]=[N:3]-[N:4]-[C:5]=1 +[CH0;$(C-[#6]):1]#[NH0:2].[CH0;$(C-[#6]);R0:5](=[OD1])-[#8;H1,$(O-[CH3]),$(O-[CH2]-[CH3])]>>[N:2]1-[C:1]=N-N-[C:5]=1 +[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[CH3:4]):[c:5](-[OH1:6]).[C;$(C1-[CH2]-[CH2]-[N,C]-[CH2]-[CH2]-1):7](=[OD1])>>[O:6]1-[c:5]:[c:1]-[C:2](=[OD1:3])-[C:4]-[C:7]-1 +[c;r6:1](-[C;$(C=O):6]-[OH1]):[c;r6:2]-[C;H1,$(C-C):3]=[OD1].[NH2:4]-[NH1;$(N-[#6]);!$(NC=[O,S,N]):5]>>[c:1]1:[c:2]-[C:3]=[N:4]-[N:5]-[C:6]-1 +[C;$(C-c1ccccc1):1](=[OD1])-[C;D3;$(C-c1ccccc1):2]~[O;D1,H1].[CH1;$(C-c):3]=[OD1]>>[C:1]1-N=[C:3]-[NH1]-[C:2]=1 +[NH1;$(N-c1ccccc1):1](-[NH2])-[c:5]:[cH1:4].[C;$(C([#6])[#6]):2](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):3]>>[C:5]1-[N:1]-[C:2]=[C:3]-[C:4]:1 +[NH2;$(N-c1ccccc1):1]-[c:2]:[c:3]-[CH1:4]=[OD1].[C;$(C([#6])[#6]):6](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):5]>>[N:1]1-[c:2]:[c:3]-[C:4]=[C:5]-[C:6]:1 +[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[OH1:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[O:3]-[C:4]=[C:5]-1 +[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[SD2:3]-[CH3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[S:3]-[C:4]=[C:5]-1 +[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[NH2:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[N:3]-[C:4]=[C:5]-1 +[#6:6][C:5]#[#7;D1:4].[#6:1][C:2](=[OD1:3])[OH1]>>[#6:6][c:5]1[n:4][o:3][c:2]([#6:1])n1 +[#6;$([#6]~[#6]);!$([#6]=O):2][#8;H1:3].[Cl,Br,I][#6;H2;$([#6]~[#6]):4]>>[CH2:4][O:3][#6:2] +[#6;H0;D3;$([#6](~[#6])~[#6]):1]B(O)O.[#6;H0;D3;$([#6](~[#6])~[#6]):2][Cl,Br,I]>>[#6:2][#6:1] +[c;H1:3]1:[c:4]:[c:5]:[c;H1:6]:[c:7]2:[nH:8]:[c:9]:[c;H1:1]:[c:2]:1:2.O=[C:10]1[#6;H2:11][#6;H2:12][N:13][#6;H2:14][#6;H2:15]1>>[#6;H2:12]3[#6;H1:11]=[C:10]([c:1]1:[c:9]:[n:8]:[c:7]2:[c:6]:[c:5]:[c:4]:[c:3]:[c:2]:1:2)[#6;H2:15][#6;H2:14][N:13]3 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N(C=O)C=O):2]>>[C:1][N:2] +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[OH1;$(Oc1ccccc1):2]>>[C:1][O:2] +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N([#6])S(=O)=O):2]>>[C:1][N:2] +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[C:1][#7:2]1:[#7:3]:[#7:4]:[#7:5]:[#6:6]:1 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[#7H0:2]1:[#7:3]:[#7H0:4]([C:1]):[#7:5]:[#6:6]:1 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[C:1][#7H0:2]1:[#7:3]:[#7H0:4]:[#7:5]:[#6:6]:1 +[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[#7:2]1:[#7:3]:[#7:4]([C:1]):[#7:5]:[#6:6]:1 +[#6;$(C=C-[#6]),$(c:c):1][Br,I].[Cl,Br,I][c:2]>>[c:2][#6:1] +[#6:1][C:2]#[#7;D1].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):3]>>[#6:1][C:2](=O)[#6:3] +[#6:1][C;H1,$([C]([#6])[#6]):2]=[OD1:3].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):4]>>[C:1][#6:2]([OH1:3])[#6:4] 
+[S;$(S(=O)(=O)[C,N]):1][Cl].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[S:1][N+0:2] +[c:1]B(O)O.[nH1;+0;r5;!$(n[#6]=[O,S,N]);!$(n~n~n);!$(n~n~c~n);!$(n~c~n~n):2]>>[c:1][n:2] +[#6:3]-[C;H1,$([CH0](-[#6])[#6]);!$(CC=O):1]=[OD1].[Cl,Br,I][C;H2;$(C-[#6]);!$(CC[I,Br]);!$(CCO[CH3]):2]>>[C:3][C:1]=[C:2] +[Cl,Br,I][c;$(c1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):1].[N;$(NC)&!$(N=*)&!$([N-])&!$(N#*)&!$([ND3])&!$([ND4])&!$(N[c,O])&!$(N[C,S]=[S,O,N]),H2&$(Nc1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):2]>>[c:1][N:2] +[C;$(C([#6])[#6;!$([#6]Br)]):4](=[OD1])[CH;$(C([#6])[#6]):5]Br.[#7;H2:3][C;$(C(=N)(N)[c,#7]):2]=[#7;H1;D1:1]>>[C:4]1=[CH0:5][NH:3][C:2]=[N:1]1 +[c;$(c1[c;$(c[C,S,N](=[OD1])[*;R0;!OH1])]cccc1):1][C;$(C(=O)[O;H1])].[c;$(c1aaccc1):2][Cl,Br,I]>>[c:1][c:2] +[c;!$(c1ccccc1);$(c1[n,c]c[n,c]c[n,c]1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] +[c;$(c1c(N(~O)~O)cccc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] +[c;$(c1ccc(N(~O)~O)cc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] +[N;$(N-[#6]):3]=[C;$(C=O):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] +[N;$(N-[#6]):3]=[C;$(C=S):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] +[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[C:10]-1=[N+0:9]-[C:8]-[C:7]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 +[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$([C](c)(C)(C)),C$([CH](c)(C)):7]([O$(OC),OH])-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[c:10]-1[n:9][c:8][c:7][c:6]2[c:5][c:4][c:3][c:2][c:1]-12 +[NH3+,NH2]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[c:1][c:2][nH:3][cH:5]1.[CH:10](-[CX4:12])=[O:11]>>[c,C:12]-[CH:10]-1-[N]-[C:8]-[C:7]-[c:6]2[c:1][c:2][nH:3][c:5]-12 +[NH2,NH3+1:8]-[c:5]1[cH:4][c:3][c:2][c:1][c:6]1.[Br:18][C$([CH2](C)(Br)),C$([CH](C)(C)(Br)):17]-[C:15](=[O:16])-[c:10]1[c:11][c:12][c:13][c:14][c:9]1>>[c:13]1[c:12][c:11][c:10]([c:9][c:14]1)-[c:15]1[c:17][c:4]2[c:3][c:2][c:1][c:6][c:5]2[nH+0:8]1 +[Cl:1][CH2:2]-[C$([CH](C)),C$(C(C)(C)):3]=[O:4].[OH:12]-[c:11]1[c:6][c:7][c:8][c:9][c:10]1-[CH:13]=[O:14]>>[C:3](=[O:4])-[c:2]1[c:13][c:10]2[c:9][c:8][c:7][c:6][c:11]2[o:12]1 +[NH2,NH3+]-[C$([CX4](N)([c,C])([c,C])([c,C])),C$([CH](N)([c,C])([c,C])),C$([CH2](N)([c,C])),C$([CH3](N)):2].[NH2:12]-[c:7]1[c:6][c:5][c:4][c:3][c:8]1-[C:9](-[OH,O-:11])=[O:10]>>[C:2]-[n+0]-1[c:13][n:12][c:7]2[c:6][c:5][c:4][c:3][c:8]2[c:9]-1=[O:10] +[N$([NH2]([CX4])),N$([NH3+1]([CX4])):1].[O:5]-[C$([CH]([CX4])(C)(O)),C$([CH2]([CX4])(O)):3][C$(C([CX4])(=O)([CX4])),C$([CH]([CX4])(=O)):4]=[O:6]>[O:15]=[C:9]-1-[CH2:10]-[CH2:11]-[CH2:12]-[CH2:13]-[CH2:14]-1>[c:4]1[c:3][n+0:1][c:10]2-[C:11]-[C:12]-[C:13]-[C:14]-[c:9]12 +[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):2](=[O:6])-[C$([CH]([CX4])),C$([CH2]):3]-[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):4]=[O:7].[NH2:8]-[C:9](=[O:10])-[CH2:11][C:12]#[N:13]>>[OH:10]-[c:9]1[n:8][c:4][c:3][c:2][c:11]1[C:12]#[N:13] +[C$(C(#C)([CX4])):2]#[C$(C(#C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[c:2]1[c:1][n:5][n][n]1 
+[C$(C(=C)([CX4])):2]=[C$(C(=C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[C:2]1[C:1][N:5][N]=[N]1 +[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):1]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1[C:2][C:3][C:4]=[C:5][C:6]1 +[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):1]#[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1=[C:2][C:3][C:4]=[C:5][C:6]1 +[NH2,NH3+:3]-[N$([NH](N)([CX4])):2].[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):6](-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):5]=[O:9])-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):7]=[O:10]>>[c:7]1[n:3][n:2][c:5][c:6]1 +[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):1](=[O:2])-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):3]-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):4]-[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):5]=[O:6].[N$([NH2,NH3+1]([CX4])):7]>>[c:5]1[c:4][c:3][c:1][n+0:7]1 +[CH:7](=[O:8])-[c:1]1[c:2][c:3][c:4][c:5][c:6]1.[O:24]=[C:23](-[C:22](=[O:25])-[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1>[NH4].[O-]C(=O)C>[nH:27]-1[c:7]([n:26][c:23]([c:22]-1[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1)-[c:1]1[c:2][c:3][c:4][c:5][c:6]1 +[OH:7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1.[O$(O(C)([CX4])):12]-[C:11](=[O:15])-[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):10]-[C:8]=[O:16]>>[C:8]-1=[C:10]-[C:11](=[O:15])-[O]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 +[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:4][C:3][C:2]1=[O:1] +[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:11][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:11][C:4][C:3][C:2]1=[O:1] +[Cl:9][C:7](=[O:8])-[c:3]1[c:2][c:1][c:6][c:5][c:4]1.[C$([CH2](C)([CX4])),C$([CH3](C)):18]-[C:16](=[O:17])-[c:14]1[c:13][c:12][c:11][c:10][c:15]1-[OH:19]>>[O:17]=[C:16]-1-[C:18]=[C:7](-[O:8]-[c:15]2[c:10][c:11][c:12][c:13][c:14]-12)-[c:3]1[c:2][c:1][c:6][c:5][c:4]1 +[C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):1][C$([CH]([C,N])([C,N])([CX4])),C$([CH2]([C,N])([C,N])):2][C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):3].[C$(C(C)(#N)),C$(C(C)([CX4,OX2&H0])([CX4,OX2&H0])([OX2&H0])),C$([CH](C)([CX4,OX2&H0])([OX2&H0])),C$([CH2](C)([OX2&H0])),C$(C(C)(=O)([OX2&H0])):6][CH:5]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):4]>>[C:6][C:5][C:4][C:2]([C:1])[C:3] +[C$([C](O)([CX4])([CX4])([CX4])),C$([CH](O)([CX4])([CX4])),C$([CH2](O)([CX4])):4]-[O:3]-[C$(C(=O)([CX4])),C$([CH](=O)):2]=[O:5].[C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):7]-[C$(C(=O)([CX4])),C$([CH](=O)):8]=[O:9]>>[C:7](-[C:2]=[O:5])-[C:8]=[O:9] +[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[O$([OH]([CX4,c])),O$([OH]([CX4,c])([CX4,c])),S$([SH]([CX4,c])),S$([SH]([CX4,c])([CX4,c])):6]>>[*:6]-[C:2]=[O:4] +[C$(C(=O)([CX4,c])([CX4,c])),C$([CH](=O)([CX4,c])):1]=[O:2].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):3]>>[N+0:3][C:1] 
+[Br:1][c$(c(Br)),n$(n(Br)),o$(o(Br)),C$([CH](Br)(=C)):2].[C$(C(B)([CX4])([CX4])([CX4])),C$([CH](B)([CX4])([CX4])),C$([CH2](B)([CX4])),C$([CH2](B)),C$(C(B)(=C)),c$(c(B)),o$(o(B)),n$(n(B)):3][B$(B([C,c,n,o])([OH,$(OC)])([OH,$(OC)])),B$([B-1]([C,c,n,o])(N)([OH,$(OC)])([OH,$(OC)])):4]>>[C,c,n,o:2][C,c,n,o:3] +[Br,I:1][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):2].[Br,I:3][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):4]>>[C,c:2][C,c:4] +[OH,O-]-[C$(C(=O)(O)([CX4,c])):2]=[O:3].[OH:8]-[C$([CH](O)([CX4,c])([CX4,c])),C$([CH2](O)([CX4,c])),C$([CH3](O)):6]>>[C:6][O]-[C:2]=[O:3] +[C$([CH](=C)([CX4])),C$([CH2](=C)):2]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):3].[Br,I:7][C$([CX4]([Br,I])),c$([c]([Br,I])):4]>>[C,c:4][C:2]=[C:3] +[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):6]>>[N+0:6]-[C:2]=[O:4] +[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):1]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):2].[SH:4]-[CX4:5][Br,Cl,I]>>[C:1]-[C:2]-[S:4][C:5] +[C$([C](=O)([CX4])),C$([CH](=O)):2](=[O:1])[OH,Cl,O-:6].[SH:4]-[CX4:5][Br,Cl,I]>>[CH2:2]-[S:4][C:5] +[I:1][C$(C(I)([CX4,c])([CX4,c])([CX4,c])),C$([CH](I)([CX4,c])([CX4,c])),C$([CH2](I)([CX4,c])),C$([CH3](I)):2].[C$(C(=O)([Cl,OH,O-])([CX4,c])),C$([CH]([Cl,OH,O-])(=O)):3](=[O:6])[Cl,OH,O-:5]>>[C:2]-[C:3]=[O:6] +[Cl:5][S$(S(=O)(=O)(Cl)([CX4])):2](=[O:3])=[O:4].[NH2+0,NH3+:6]-[C$(C(N)([CX4,c])([CX4,c])([CX4,c])),C$([CH](N)([CX4,c])([CX4,c])),C$([CH2](N)([CX4,c])),C$([CH3](N)),c$(c(N)):7]>>[C,c:7]-[NH+0:6][S:2](=[O:4])=[O:3] +[*:1][C:2]#[CH:3].[Br,I:4][C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):5]>>[C,c:5][C:3]#[C:2][*:1] +[C$(C(C)([CX4])([CX4])([CX4])),C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):1][C:2]#[CH:3].[Br,I:4][C$(C(=O)([Br,I])([CX4])),C$([CH](=O)([Br,I])):5]=[O:6]>>[C:1][C:2]#[C:3][C:5]=[O:6] +[OH,O-:4]-[C$(C(=O)([OH,O-])([CX4])),C$([CH](=O)([OH,O-])):2]=[O:3]>>[Cl:5][C:2]=[O:3] +[OH:2]-[$([CX4]),c:1]>>[Br:3][C,c:1] +[OH:2]-[$([CX4]),c:1]>>[Cl:3][C,c:1] +[OH,O-:3][S$(S([CX4])):2](=[O:4])=[O:5]>>[Cl:6][S:2](=[O:5])=[O:4] +[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Br:6]) +[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Cl:6]) +[Cl,I,Br:7][c:1]1[c:2][c:3][c:4][c:5][c:6]1>>[N:9]#[C:8][c:1]1[c:2][c:3][c:4][c:5][c:6]1 +[OH,NH2,NH3+:3]-[CH2:2]-[C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):1]>>[C,c:1][C:2]#[N:4] diff --git a/data/rxn_set_hb.txt b/data/rxn_set_hb.txt deleted file mode 100644 index fa917eff..00000000 --- a/data/rxn_set_hb.txt +++ /dev/null @@ -1,91 +0,0 @@ -|[cH1:1]1:[c:2](-[CH2:7]-[CH2:8]-[NH2:9]):[c:3]:[c:4]:[c:5]:[c:6]:1.[#6:11]-[CH1;R0:10]=[OD1]>>[c:1]12:[c:2](-[CH2:7]-[CH2:8]-[NH1:9]-[C:10]-2(-[#6:11])):[c:3]:[c:4]:[c:5]:[c:6]:1 -|[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[#8;H1,$(O-[CH3])]>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 -|[c;r6:1](-[NH1;$(N-[#6]):2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[n:2]:[c:5](-[#6:6]):[n:4]2 
-|[c;r6:1](-[SH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[s:2]:[c:5](-[#6:6]):[n:4]2 -|[c:1](-[OH1;$(Oc1ccccc1):2]):[c;r6:3](-[NH2:4]).[c:6]-[CH1;R0:5](=[OD1])>>[c:3]2:[c:1]:[o:2]:[c:5](-[c:6]):[n:4]2 -|[c;r6:1](-[OH1:2]):[c;r6:3](-[NH2:4]).[#6:6]-[C;R0:5](=[OD1])-[OH1]>>[c:3]2:[c:1]:[o:2]:[c:5](-[#6:6]):[n:4]2 -|[#6:6]-[C;R0:1](=[OD1])-[CH1;R0:5](-[#6:7])-[*;#17,#35,#53].[NH2:2]-[C:3]=[SD1:4]>>[c:1]2(-[#6:6]):[n:2]:[c:3]:[s:4][c:5]([#6:7]):2 -|[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[OH1]):[c:4](-[NH2:5]).[N;!H0;!$(N-N);!$(N-C=N);!$(N(-C=O)-C=O):6]-[C;H1,$(C-[#6]):7]=[OD1]>>[c:4]2:[c:1]-[C:2](=[O:3])-[N:6]-[C:7]=[N:5]-2 -|[CH0;$(C-[#6]):1]#[NH0:2]>>[C:1]1=[N:2]-N-N=N-1 -|[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N(-[C:3])-N=N-1 -|[CH0;$(C-[#6]):1]#[NH0:2].[C;A;!$(C=O):3]-[*;#17,#35,#53]>>[C:1]1=[N:2]-N=N-N-1(-[C:3]) -|[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N(-[C:3])-N=N-1 -|[CH0;$(C-[#6]):1]#[CH1:2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 -|[CH0;$(C-[#6]):1]#[CH0;$(C-[#6]):2].[C;H1,H2;A;!$(C=O):3]-[*;#17,#35,#53,OH1]>>[C:1]1=[C:2]-N=NN(-[C:3])-1 -|[CH0;$(C-[#6]):1]#[NH0:2].[NH2:3]-[NH1:4]-[CH0;$(C-[#6]);R0:5]=[OD1]>>[N:2]1-[C:1]=[N:3]-[N:4]-[C:5]=1 -|[CH0;$(C-[#6]):1]#[NH0:2].[CH0;$(C-[#6]);R0:5](=[OD1])-[#8;H1,$(O-[CH3]),$(O-[CH2]-[CH3])]>>[N:2]1-[C:1]=N-N-[C:5]=1 -|[c:1](-[C;$(C-c1ccccc1):2](=[OD1:3])-[CH3:4]):[c:5](-[OH1:6]).[C;$(C1-[CH2]-[CH2]-[N,C]-[CH2]-[CH2]-1):7](=[OD1])>>[O:6]1-[c:5]:[c:1]-[C:2](=[OD1:3])-[C:4]-[C:7]-1 -|[c;r6:1](-[C;$(C=O):6]-[OH1]):[c;r6:2]-[C;H1,$(C-C):3]=[OD1].[NH2:4]-[NH1;$(N-[#6]);!$(NC=[O,S,N]):5]>>[c:1]1:[c:2]-[C:3]=[N:4]-[N:5]-[C:6]-1 -|[C;$(C-c1ccccc1):1](=[OD1])-[C;D3;$(C-c1ccccc1):2]~[O;D1,H1].[CH1;$(C-c):3]=[OD1]>>[C:1]1-N=[C:3]-[NH1]-[C:2]=1 -|[NH1;$(N-c1ccccc1):1](-[NH2])-[c:5]:[cH1:4].[C;$(C([#6])[#6]):2](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):3]>>[C:5]1-[N:1]-[C:2]=[C:3]-[C:4]:1 -|[NH2;$(N-c1ccccc1):1]-[c:2]:[c:3]-[CH1:4]=[OD1].[C;$(C([#6])[#6]):6](=[OD1])-[CH2;$(C([#6])[#6]);!$(C(C=O)C=O):5]>>[N:1]1-[c:2]:[c:3]-[C:4]=[C:5]-[C:6]:1 -|[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[OH1:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[O:3]-[C:4]=[C:5]-1 -|[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[SD2:3]-[CH3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[S:3]-[C:4]=[C:5]-1 -|[*;Br,I;$(*c1ccccc1)]-[c:1]:[c:2]-[NH2:3].[CH1:5]#[C;$(C-[#6]):4]>>[c:1]1:[c:2]-[N:3]-[C:4]=[C:5]-1 -|[#6:6][C:5]#[#7;D1:4].[#6:1][C:2](=[OD1:3])[OH1]>>[#6:6][c:5]1[n:4][o:3][c:2]([#6:1])n1 -|[#6;$([#6]~[#6]);!$([#6]=O):2][#8;H1:3].[Cl,Br,I][#6;H2;$([#6]~[#6]):4]>>[CH2:4][O:3][#6:2] -|[#6;H0;D3;$([#6](~[#6])~[#6]):1]B(O)O.[#6;H0;D3;$([#6](~[#6])~[#6]):2][Cl,Br,I]>>[#6:2][#6:1] -|[c;H1:3]1:[c:4]:[c:5]:[c;H1:6]:[c:7]2:[nH:8]:[c:9]:[c;H1:1]:[c:2]:1:2.O=[C:10]1[#6;H2:11][#6;H2:12][N:13][#6;H2:14][#6;H2:15]1>>[#6;H2:12]3[#6;H1:11]=[C:10]([c:1]1:[c:9]:[n:8]:[c:7]2:[c:6]:[c:5]:[c:4]:[c:3]:[c:2]:1:2)[#6;H2:15][#6;H2:14][N:13]3 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N(C=O)C=O):2]>>[C:1][N:2] -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[OH1;$(Oc1ccccc1):2]>>[C:1][O:2] -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[NH1;$(N([#6])S(=O)=O):2]>>[C:1][N:2] -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[C:1][#7:2]1:[#7:3]:[#7:4]:[#7:5]:[#6:6]:1 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7H1:2]1~[#7:3]~[#7:4]~[#7:5]~[#6:6]~1>>[#7H0:2]1:[#7:3]:[#7H0:4]([C:1]):[#7:5]:[#6:6]:1 
-|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[C:1][#7H0:2]1:[#7:3]:[#7H0:4]:[#7:5]:[#6:6]:1 -|[C;H1&$(C([#6])[#6]),H2&$(C[#6]):1][OH1].[#7:2]1~[#7:3]~[#7H1:4]~[#7:5]~[#6:6]~1>>[#7:2]1:[#7:3]:[#7:4]([C:1]):[#7:5]:[#6:6]:1 -|[#6;$(C=C-[#6]),$(c:c):1][Br,I].[Cl,Br,I][c:2]>>[c:2][#6:1] -|[#6:1][C:2]#[#7;D1].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):3]>>[#6:1][C:2](=O)[#6:3] -|[#6:1][C;H1,$([C]([#6])[#6]):2]=[OD1:3].[Cl,Br,I][#6;$([#6]~[#6]);!$([#6]([Cl,Br,I])[Cl,Br,I]);!$([#6]=O):4]>>[C:1][#6:2]([OH1:3])[#6:4] -|[S;$(S(=O)(=O)[C,N]):1][Cl].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[S:1][N+0:2] -|[c:1]B(O)O.[nH1;+0;r5;!$(n[#6]=[O,S,N]);!$(n~n~n);!$(n~n~c~n);!$(n~c~n~n):2]>>[c:1][n:2] -|[#6:3]-[C;H1,$([CH0](-[#6])[#6]);!$(CC=O):1]=[OD1].[Cl,Br,I][C;H2;$(C-[#6]);!$(CC[I,Br]);!$(CCO[CH3]):2]>>[C:3][C:1]=[C:2] -|[Cl,Br,I][c;$(c1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):1].[N;$(NC)&!$(N=*)&!$([N-])&!$(N#*)&!$([ND3])&!$([ND4])&!$(N[c,O])&!$(N[C,S]=[S,O,N]),H2&$(Nc1:[c,n]:[c,n]:[c,n]:[c,n]:[c,n]:1):2]>>[c:1][N:2] -|[C;$(C([#6])[#6;!$([#6]Br)]):4](=[OD1])[CH;$(C([#6])[#6]):5]Br.[#7;H2:3][C;$(C(=N)(N)[c,#7]):2]=[#7;H1;D1:1]>>[C:4]1=[CH0:5][NH:3][C:2]=[N:1]1 -|[c;$(c1[c;$(c[C,S,N](=[OD1])[*;R0;!OH1])]cccc1):1][C;$(C(=O)[O;H1])].[c;$(c1aaccc1):2][Cl,Br,I]>>[c:1][c:2] -|[c;!$(c1ccccc1);$(c1[n,c]c[n,c]c[n,c]1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] -|[c;$(c1c(N(~O)~O)cccc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] -|[c;$(c1ccc(N(~O)~O)cc1):1][Cl,F].[N;$(NC);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[c,O]);!$(N[C,S]=[S,O,N]):2]>>[c:1][N:2] -|[N;$(N-[#6]):3]=[C;$(C=O):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] -|[N;$(N-[#6]):3]=[C;$(C=S):1].[N;$(N[#6]);!$(N=*);!$([N-]);!$(N#*);!$([ND3]);!$([ND4]);!$(N[O,N]);!$(N[C,S]=[S,O,N]):2]>>[N:3]-[C:1]-[N+0:2] -|[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[C:10]-1=[N+0:9]-[C:8]-[C:7]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 -|[$(C([CH2,CH3])),CH:10](=[O:11])-[NH+0:9]-[C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$([C](c)(C)(C)),C$([CH](c)(C)):7]([O$(OC),OH])-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1>>[c:10]-1[n:9][c:8][c:7][c:6]2[c:5][c:4][c:3][c:2][c:1]-12 -|[NH3+,NH2]-[C$(C(N)(C)(C)(C)),C$([CH](N)(C)(C)),C$([CH2](N)(C)):8]-[C$(C(c)(C)(C)(C)),C$([CH](c)(C)(C)),C$([CH2](c)(C)):7]-[c:6]1[c:1][c:2][nH:3][cH:5]1.[CH:10](-[CX4:12])=[O:11]>>[c,C:12]-[CH:10]-1-[N]-[C:8]-[C:7]-[c:6]2[c:1][c:2][nH:3][c:5]-12 -|[NH2,NH3+1:8]-[c:5]1[cH:4][c:3][c:2][c:1][c:6]1.[Br:18][C$([CH2](C)(Br)),C$([CH](C)(C)(Br)):17]-[C:15](=[O:16])-[c:10]1[c:11][c:12][c:13][c:14][c:9]1>>[c:13]1[c:12][c:11][c:10]([c:9][c:14]1)-[c:15]1[c:17][c:4]2[c:3][c:2][c:1][c:6][c:5]2[nH+0:8]1 -|[Cl:1][CH2:2]-[C$([CH](C)),C$(C(C)(C)):3]=[O:4].[OH:12]-[c:11]1[c:6][c:7][c:8][c:9][c:10]1-[CH:13]=[O:14]>>[C:3](=[O:4])-[c:2]1[c:13][c:10]2[c:9][c:8][c:7][c:6][c:11]2[o:12]1 -|[NH2,NH3+]-[C$([CX4](N)([c,C])([c,C])([c,C])),C$([CH](N)([c,C])([c,C])),C$([CH2](N)([c,C])),C$([CH3](N)):2].[NH2:12]-[c:7]1[c:6][c:5][c:4][c:3][c:8]1-[C:9](-[OH,O-:11])=[O:10]>>[C:2]-[n+0]-1[c:13][n:12][c:7]2[c:6][c:5][c:4][c:3][c:8]2[c:9]-1=[O:10] 
-|[N$([NH2]([CX4])),N$([NH3+1]([CX4])):1].[O:5]-[C$([CH]([CX4])(C)(O)),C$([CH2]([CX4])(O)):3][C$(C([CX4])(=O)([CX4])),C$([CH]([CX4])(=O)):4]=[O:6]>[O:15]=[C:9]-1-[CH2:10]-[CH2:11]-[CH2:12]-[CH2:13]-[CH2:14]-1>[c:4]1[c:3][n+0:1][c:10]2-[C:11]-[C:12]-[C:13]-[C:14]-[c:9]12 -|[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):2](=[O:6])-[C$([CH]([CX4])),C$([CH2]):3]-[C$(C(=O)([CX4])([CX4])),C$([CH](=O)([CX4])):4]=[O:7].[NH2:8]-[C:9](=[O:10])-[CH2:11][C:12]#[N:13]>>[OH:10]-[c:9]1[n:8][c:4][c:3][c:2][c:11]1[C:12]#[N:13] -|[C$(C(#C)([CX4])):2]#[C$(C(#C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[c:2]1[c:1][n:5][n][n]1 -|[C$(C(=C)([CX4])):2]=[C$(C(=C)([CX4])):1].[N$(N(~N)([CX4])):5]~[N]~[N]>>[C:2]1[C:1][N:5][N]=[N]1 -|[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):1]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1[C:2][C:3][C:4]=[C:5][C:6]1 -|[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):1]#[C$(C(#C)([CX4,OX2,NX3])),C$([CH](#C)):2].[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):3]=[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):4]-[C$([C](=C)(C)([CX4,OX2,NX3])),C$([CH](=C)(C)):5]=[C$(C(=C)([CX4,OX2,NX3])([CX4,OX2,NX3])),C$([CH](=C)([CX4,OX2,NX3])),C$([CH2](=C)):6]>>[C:1]1=[C:2][C:3][C:4]=[C:5][C:6]1 -|[NH2,NH3+:3]-[N$([NH](N)([CX4])):2].[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):6](-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):5]=[O:9])-[C$(C(=O)(C)([CX4])),C$([CH](=O)(C)):7]=[O:10]>>[c:7]1[n:3][n:2][c:5][c:6]1 -|[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):1](=[O:2])-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):3]-[$([CH](C)(C)([CX4])),$([CH2](C)(C)):4]-[C$(C(=O)(C)([CX4])),C$(C[H](=O)(C)):5]=[O:6].[N$([NH2,NH3+1]([CX4])):7]>>[c:5]1[c:4][c:3][c:1][n+0:7]1 -|[CH:7](=[O:8])-[c:1]1[c:2][c:3][c:4][c:5][c:6]1.[O:24]=[C:23](-[C:22](=[O:25])-[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1>[NH4].[O-]C(=O)C>[nH:27]-1[c:7]([n:26][c:23]([c:22]-1[c:15]1[c:10][c:11][c:12][c:13][c:14]1)-[c:20]1[c:21][c:16][c:17][c:18][c:19]1)-[c:1]1[c:2][c:3][c:4][c:5][c:6]1 -|[OH:7]-[c:6]1[cH:1][c:2][c:3][c:4][c:5]1.[O$(O(C)([CX4])):12]-[C:11](=[O:15])-[C$([CH](C)(C)([CX4])),C$([CH2](C)(C)):10]-[C:8]=[O:16]>>[C:8]-1=[C:10]-[C:11](=[O:15])-[O]-[c:6]2[c:5][c:4][c:3][c:2][c:1]-12 -|[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:4][C:3][C:2]1=[O:1] -|[O$(O(C)([CX4])):8][C:7](=[O:9])[CH:6][C:5][C:11][C:4][C:3][C:2]([O$(O(C)([CX4])):10])=[O:1]>>[O:8][C:7](=[O:9])[C:6]1[C:5][C:11][C:4][C:3][C:2]1=[O:1] -|[Cl:9][C:7](=[O:8])-[c:3]1[c:2][c:1][c:6][c:5][c:4]1.[C$([CH2](C)([CX4])),C$([CH3](C)):18]-[C:16](=[O:17])-[c:14]1[c:13][c:12][c:11][c:10][c:15]1-[OH:19]>>[O:17]=[C:16]-1-[C:18]=[C:7](-[O:8]-[c:15]2[c:10][c:11][c:12][c:13][c:14]-12)-[c:3]1[c:2][c:1][c:6][c:5][c:4]1 
-|[C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):1][C$([CH]([C,N])([C,N])([CX4])),C$([CH2]([C,N])([C,N])):2][C$(C(C)(=O)([CX4,OX2&H0])),C$(C(C)(#N)),N$([N+1](C)(=O)([O-1])):3].[C$(C(C)(#N)),C$(C(C)([CX4,OX2&H0])([CX4,OX2&H0])([OX2&H0])),C$([CH](C)([CX4,OX2&H0])([OX2&H0])),C$([CH2](C)([OX2&H0])),C$(C(C)(=O)([OX2&H0])):6][CH:5]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):4]>>[C:6][C:5][C:4][C:2]([C:1])[C:3] -|[C$([C](O)([CX4])([CX4])([CX4])),C$([CH](O)([CX4])([CX4])),C$([CH2](O)([CX4])):4]-[O:3]-[C$(C(=O)([CX4])),C$([CH](=O)):2]=[O:5].[C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):7]-[C$(C(=O)([CX4])),C$([CH](=O)):8]=[O:9]>>[C:7](-[C:2]=[O:5])-[C:8]=[O:9] -|[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[O$([OH]([CX4,c])),O$([OH]([CX4,c])([CX4,c])),S$([SH]([CX4,c])),S$([SH]([CX4,c])([CX4,c])):6]>>[*:6]-[C:2]=[O:4] -|[C$(C(=O)([CX4,c])([CX4,c])),C$([CH](=O)([CX4,c])):1]=[O:2].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):3]>>[N+0:3][C:1] -|[Br:1][c$(c(Br)),n$(n(Br)),o$(o(Br)),C$([CH](Br)(=C)):2].[C$(C(B)([CX4])([CX4])([CX4])),C$([CH](B)([CX4])([CX4])),C$([CH2](B)([CX4])),C$([CH2](B)),C$(C(B)(=C)),c$(c(B)),o$(o(B)),n$(n(B)):3][B$(B([C,c,n,o])([OH,$(OC)])([OH,$(OC)])),B$([B-1]([C,c,n,o])(N)([OH,$(OC)])([OH,$(OC)])):4]>>[C,c,n,o:2][C,c,n,o:3] -|[Br,I:1][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):2].[Br,I:3][C$(C([Br,I])([CX4])([CX4])([CX4])),C$([CH]([Br,I])([CX4])([CX4])),C$([CH2]([Br,I])([CX4])),C$([CH3]([Br,I])),C$([C]([Br,I])(=C)([CX4])),C$([CH]([Br,I])(=C)),C$(C([Br,I])(#C)),c$(c([Br,I])):4]>>[C,c:2][C,c:4] -|[OH,O-]-[C$(C(=O)(O)([CX4,c])):2]=[O:3].[OH:8]-[C$([CH](O)([CX4,c])([CX4,c])),C$([CH2](O)([CX4,c])),C$([CH3](O)):6]>>[C:6][O]-[C:2]=[O:3] -|[C$([CH](=C)([CX4])),C$([CH2](=C)):2]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):3].[Br,I:7][C$([CX4]([Br,I])),c$([c]([Br,I])):4]>>[C,c:4][C:2]=[C:3] -|[Cl,OH,O-:3][C$(C(=O)([CX4,c])),C$([CH](=O)):2]=[O:4].[N$([NH2,NH3+1]([CX4,c])),N$([NH]([CX4,c])([CX4,c])):6]>>[N+0:6]-[C:2]=[O:4] -|[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):1]=[C$(C(=C)([CX4])([CX4])),C$([CH](=C)([CX4])),C$([CH2](=C)):2].[SH:4]-[CX4:5][Br,Cl,I]>>[C:1]-[C:2]-[S:4][C:5] -|[C$([C](=O)([CX4])),C$([CH](=O)):2](=[O:1])[OH,Cl,O-:6].[SH:4]-[CX4:5][Br,Cl,I]>>[CH2:2]-[S:4][C:5] -|[I:1][C$(C(I)([CX4,c])([CX4,c])([CX4,c])),C$([CH](I)([CX4,c])([CX4,c])),C$([CH2](I)([CX4,c])),C$([CH3](I)):2].[C$(C(=O)([Cl,OH,O-])([CX4,c])),C$([CH]([Cl,OH,O-])(=O)):3](=[O:6])[Cl,OH,O-:5]>>[C:2]-[C:3]=[O:6] -|[Cl:5][S$(S(=O)(=O)(Cl)([CX4])):2](=[O:3])=[O:4].[NH2+0,NH3+:6]-[C$(C(N)([CX4,c])([CX4,c])([CX4,c])),C$([CH](N)([CX4,c])([CX4,c])),C$([CH2](N)([CX4,c])),C$([CH3](N)),c$(c(N)):7]>>[C,c:7]-[NH+0:6][S:2](=[O:4])=[O:3] -|[*:1][C:2]#[CH:3].[Br,I:4][C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):5]>>[C,c:5][C:3]#[C:2][*:1] -|[C$(C(C)([CX4])([CX4])([CX4])),C$([CH](C)([CX4])([CX4])),C$([CH2](C)([CX4])),C$([CH3](C)):1][C:2]#[CH:3].[Br,I:4][C$(C(=O)([Br,I])([CX4])),C$([CH](=O)([Br,I])):5]=[O:6]>>[C:1][C:2]#[C:3][C:5]=[O:6] -|[OH,O-:4]-[C$(C(=O)([OH,O-])([CX4])),C$([CH](=O)([OH,O-])):2]=[O:3]>>[Cl:5][C:2]=[O:3] -|[OH:2]-[$([CX4]),c:1]>>[Br:3][C,c:1] -|[OH:2]-[$([CX4]),c:1]>>[Cl:3][C,c:1] -|[OH,O-:3][S$(S([CX4])):2](=[O:4])=[O:5]>>[Cl:6][S:2](=[O:5])=[O:4] 
-|[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Br:6]) -|[OH+0,O-:5]-[C:3](=[O:4])-[C$([CH]([CX4])),C$([CH2]):2]>>[OH+0,O-:5]-[C:3](=[O:4])-[C:2]([Cl:6]) -|[Cl,I,Br:7][c:1]1[c:2][c:3][c:4][c:5][c:6]1>>[N:9]#[C:8][c:1]1[c:2][c:3][c:4][c:5][c:6]1 -|[OH,NH2,NH3+:3]-[CH2:2]-[C$(C([CX4,c])([CX4,c])([CX4,c])),C$([CH]([CX4,c])([CX4,c])),C$([CH2]([CX4,c])),C$([CH3]),c$(c):1]>>[C,c:1][C:2]#[N:4] diff --git a/environment.yml b/environment.yml index d4e5fcc7..6979cff9 100644 --- a/environment.yml +++ b/environment.yml @@ -1,204 +1,27 @@ -name: synthenv +name: synnet channels: - pytorch - - nvidia + # - dglteam # only needed for gin - conda-forge - - defaults dependencies: - - _libgcc_mutex=0.1=conda_forge - - _openmp_mutex=4.5=1_gnu - - blas=1.0=mkl - - boost=1.74.0=py39h5472131_3 - - boost-cpp=1.74.0=h312852a_4 - - bzip2=1.0.8=h7f98852_4 - - ca-certificates=2021.5.30=ha878542_0 - - cairo=1.16.0=h6cf1ce9_1008 - - certifi=2021.5.30=py39hf3d152e_0 - - cudatoolkit=11.1.74=h6bb024c_0 - - cycler=0.10.0=py_2 - - fontconfig=2.13.1=hba837de_1005 - - freetype=2.10.4=h0708190_1 - - gettext=0.19.8.1=h0b5b191_1005 - - greenlet=1.1.0=py39he80948d_0 - - icu=68.1=h58526e2_0 - - intel-openmp=2021.3.0=h06a4308_3350 - - jbig=2.1=h7f98852_2003 - - jpeg=9d=h36c2ea0_0 - - kiwisolver=1.3.1=py39h1a9c180_1 - - lcms2=2.12=hddcbb42_0 - - ld_impl_linux-64=2.36.1=hea4e1c9_1 - - lerc=2.2.1=h9c3ff4c_0 - - libdeflate=1.7=h7f98852_5 - - libffi=3.3=h58526e2_2 - - libgcc-ng=9.3.0=h2828fa1_19 - - libgfortran-ng=9.3.0=hff62375_19 - - libgfortran5=9.3.0=hff62375_19 - - libglib=2.68.3=h3e27bee_0 - - libgomp=9.3.0=h2828fa1_19 - - libiconv=1.16=h516909a_0 - - libopenblas=0.3.15=pthreads_h8fe5266_1 - - libpng=1.6.37=h21135ba_2 - - libstdcxx-ng=9.3.0=h6de172a_19 - - libtiff=4.3.0=hf544144_1 - - libuuid=2.32.1=h7f98852_1000 - - libuv=1.40.0=h7b6447c_0 - - libwebp-base=1.2.0=h7f98852_2 - - libxcb=1.13=h7f98852_1003 - - libxml2=2.9.12=h72842e0_0 - - lz4-c=1.9.3=h9c3ff4c_0 - - matplotlib-base=3.4.2=py39h2fa2bec_0 - - mkl=2021.3.0=h06a4308_520 - - mkl-service=2.4.0=py39h7f8727e_0 - - mkl_fft=1.3.0=py39h42c9631_2 - - mkl_random=1.2.2=py39h51133e4_0 - - ncurses=6.2=h58526e2_4 - - ninja=1.10.2=hff7bd54_1 - - numpy=1.20.3=py39hf144106_0 - - numpy-base=1.20.3=py39h74d4b33_0 - - olefile=0.46=pyh9f0ad1d_1 - - openjpeg=2.4.0=hb52868f_1 - - openssl=1.1.1k=h7f98852_0 - - pandas=1.3.0=py39hde0f152_0 - - pcre=8.45=h9c3ff4c_0 - - pillow=8.3.1=py39ha612740_0 - - pip=21.1.3=pyhd8ed1ab_0 - - pixman=0.40.0=h36c2ea0_0 - - pthread-stubs=0.4=h36c2ea0_1001 - - pycairo=1.20.1=py39hedcb9fc_0 - - pyparsing=2.4.7=pyh9f0ad1d_0 - - python=3.9.6=h49503c6_1_cpython - - python-dateutil=2.8.2=pyhd8ed1ab_0 - - python_abi=3.9=2_cp39 - - pytorch=1.9.0=py3.9_cuda11.1_cudnn8.0.5_0 - - pytz=2021.1=pyhd8ed1ab_0 - - rdkit=2021.03.4=py39hccf6a74_0 - - readline=8.1=h46c0cb4_0 - - reportlab=3.5.68=py39he59360d_0 - - setuptools=49.6.0=py39hf3d152e_3 - - six=1.16.0=pyh6c4a22f_0 - - sqlalchemy=1.4.21=py39h3811e60_0 - - sqlite=3.36.0=h9cd32fc_0 - - tk=8.6.10=h21135ba_1 - - torchaudio=0.9.0=py39 - - torchvision=0.2.2=py_3 - - tornado=6.1=py39h3811e60_1 - - typing_extensions=3.10.0.0=pyh06a4308_0 - - tzdata=2021a=he74cb21_1 - - wheel=0.36.2=pyhd3deb0d_0 - - xorg-kbproto=1.0.7=h7f98852_1002 - - xorg-libice=1.0.10=h7f98852_0 - - xorg-libsm=1.2.3=hd9c2040_1000 - - xorg-libx11=1.7.2=h7f98852_0 - - xorg-libxau=1.0.9=h7f98852_0 - - xorg-libxdmcp=1.1.3=h7f98852_0 - - xorg-libxext=1.3.4=h7f98852_1 - - xorg-libxrender=0.9.10=h7f98852_1003 - - 
xorg-renderproto=0.11.1=h7f98852_1002 - - xorg-xextproto=7.3.0=h7f98852_1002 - - xorg-xproto=7.0.31=h7f98852_1007 - - xz=5.2.5=h516909a_1 - - zlib=1.2.11=h516909a_1010 - - zstd=1.5.0=ha95c52a_0 + - python=3.9.* + - pytorch::torchvision + - pytorch::pytorch=1.9.* + - pytorch-lightning + - rdkit=2021.03.* + # - dglteam::dgl-cuda11.1 # only needed for gin + - pytdc + - scikit-learn>=1.1.* + - ipykernel=6.15.* + - nb_conda_kernels + - black=22.6.* + - black-jupyter=22.6.* + - isort=5.10.* + - pip - pip: - - absl-py==0.13.0 - - aiohttp==3.7.4.post0 - - anyio==3.3.0 - - argon2-cffi==20.1.0 - - async-generator==1.10 - - async-timeout==3.0.1 - - attrs==21.2.0 - - babel==2.9.1 - - backcall==0.2.0 - - bleach==4.0.0 - - cachetools==4.2.2 - - cffi==1.14.6 - - chardet==4.0.0 - - charset-normalizer==2.0.3 - - cloudpickle==1.6.0 - - debugpy==1.4.1 - - decorator==4.4.2 - - defusedxml==0.7.1 - - dgl-cu110==0.6.1 - - dgllife==0.2.8 - - dill==0.3.4 - - entrypoints==0.3 - - fsspec==2021.7.0 - - future==0.18.2 - - fuzzywuzzy==0.18.0 - - google-auth==1.33.1 - - google-auth-oauthlib==0.4.4 - - grpcio==1.38.1 - - hyperopt==0.2.5 - - idna==3.2 - - ipdb==0.13.9 - - ipykernel==6.1.0 - - ipython==7.25.0 - - ipython-genutils==0.2.0 - - jedi==0.18.0 - - jinja2==3.0.1 - - joblib==1.0.1 - - json5==0.9.6 - - jsonschema==3.2.0 - - jupyter-client==6.1.12 - - jupyter-core==4.7.1 - - jupyter-server==1.10.2 - - jupyterlab==3.1.6 - - jupyterlab-pygments==0.1.2 - - jupyterlab-server==2.7.0 - - markdown==3.3.4 - - markupsafe==2.0.1 - - matplotlib-inline==0.1.2 - - mistune==0.8.4 - - multidict==5.1.0 - - nbclassic==0.3.1 - - nbclient==0.5.3 - - nbconvert==6.1.0 - - nbformat==5.1.3 - - nest-asyncio==1.5.1 - - networkx==2.5.1 - - notebook==6.4.3 - - oauthlib==3.1.1 - - packaging==21.0 - - pandocfilters==1.4.3 - - parso==0.8.2 - - pexpect==4.8.0 - - pickleshare==0.7.5 - - prometheus-client==0.11.0 - - prompt-toolkit==3.0.19 - - protobuf==3.17.3 - - ptyprocess==0.7.0 - - pyasn1==0.4.8 - - pyasn1-modules==0.2.8 - - pycparser==2.20 - - pydeprecate==0.3.0 - - pygments==2.9.0 - - pyrsistent==0.18.0 - - pytdc==0.2.0 - - pytorch-lightning==1.3.8 - - pyyaml==5.4.1 - - pyzmq==22.2.1 - - requests==2.26.0 - - requests-oauthlib==1.3.0 - - requests-unixsocket==0.2.0 - - rsa==4.7.2 - - scikit-learn==0.24.2 - - scipy==1.7.0 - - seaborn==0.11.1 - - send2trash==1.8.0 - - shutup==0.1.1 - - sniffio==1.2.0 - - tensorboard==2.4.1 - - tensorboard-plugin-wit==1.8.0 - - terminado==0.11.0 - - testpath==0.5.0 - - threadpoolctl==2.2.0 - - toml==0.10.2 - - torchmetrics==0.4.1 - - tqdm==4.61.2 - - traitlets==5.0.5 - - urllib3==1.26.6 - - wcwidth==0.2.5 - - webencodings==0.5.1 - - websocket-client==1.2.0 - - werkzeug==2.0.1 - - yarl==1.6.3 + - setuptools==59.5.0 # https://github.com/pytorch/pytorch/issues/69894 +# - dgllife # only needed for gin, will force scikit-learn < 1.0 + - pathos + - rich + - pyyaml + - fcd_torch # for evaluators in pytdc diff --git a/scripts/00-extract-smiles-from-sdf.py b/scripts/00-extract-smiles-from-sdf.py new file mode 100644 index 00000000..a68ae2fc --- /dev/null +++ b/scripts/00-extract-smiles-from-sdf.py @@ -0,0 +1,44 @@ +"""Extract chemicals as SMILES from a downloaded `*.sdf*` file. 
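+The output file has a first row `SMILES`, followed by one SMILES string per line.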
+""" +import json +import logging +from pathlib import Path + +from synnet.utils.prep_utils import Sdf2SmilesExtractor + +logger = logging.getLogger(__name__) + + +def main(): + if not input_file.exists(): + raise FileNotFoundError(input_file) + logger.info(f"Start parsing {input_file}") + Sdf2SmilesExtractor().from_sdf(input_file).to_file(outfile) + logger.info(f"Parsed file. Output written to {outfile}.") + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--input-file", type=str, help="An *.sdf file") + parser.add_argument( + "--output-file", + type=str, + help="File with SMILES strings (First row `SMILES`, then one per line).", + ) + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + input_file = Path(args.input_file) + outfile = Path(args.output_file) + main() + + logger.info(f"Complete.") diff --git a/scripts/01-filter-building-blocks.py b/scripts/01-filter-building-blocks.py new file mode 100644 index 00000000..823cd411 --- /dev/null +++ b/scripts/01-filter-building-blocks.py @@ -0,0 +1,85 @@ +"""Filter out building blocks that cannot react with any template. +""" +import logging + +from rdkit import RDLogger + +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import ( + BuildingBlockFileHandler, + BuildingBlockFilter, + ReactionTemplateFileHandler, +) +from synnet.utils.data_utils import ReactionSet + +RDLogger.DisableLog("rdApp.*") +logger = logging.getLogger(__name__) +import json + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="File with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxn-templates-file", + type=str, + help="Input file with reaction templates as SMARTS(No header, one per line).", + ) + parser.add_argument( + "--output-bblock-file", + type=str, + help="Output file for the filtered building-blocks.", + ) + parser.add_argument( + "--output-rxns-collection-file", + type=str, + help="Output file for the collection of reactions matched with building-blocks.", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load assets + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + rxn_templates = ReactionTemplateFileHandler().load(args.rxn_templates_file) + + bbf = BuildingBlockFilter( + building_blocks=bblocks, + rxn_templates=rxn_templates, + verbose=args.verbose, + processes=args.ncpu, + ) + # Time intensive task... + bbf.filter() + + # ... 
and save to disk + bblocks_filtered = bbf.building_blocks_filtered + BuildingBlockFileHandler().save(args.output_bblock_file, bblocks_filtered) + + # Save collection of reactions which have "available reactants" set (for convenience) + rxn_collection = ReactionSet(bbf.rxns) + rxn_collection.save(args.output_rxns_collection_file) + + logger.info(f"Total number of building blocks {len(bblocks):d}") + logger.info(f"Matched number of building blocks {len(bblocks_filtered):d}") + logger.info( + f"{len(bblocks_filtered)/len(bblocks):.2%} of building blocks applicable for the reaction templates." + ) + + logger.info("Completed.") diff --git a/scripts/02-compute-embeddings.py b/scripts/02-compute-embeddings.py new file mode 100644 index 00000000..78ee4af9 --- /dev/null +++ b/scripts/02-compute-embeddings.py @@ -0,0 +1,74 @@ +""" +Computes the molecular embeddings of the purchasable building blocks. + +The embeddings are also referred to as "output embedding". +In the embedding space, a kNN-search will identify the 1st or 2nd reactant. +""" + +import json +import logging +from functools import partial + +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import BuildingBlockFileHandler +from synnet.encoding.fingerprints import mol_fp +from synnet.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__file__) + +FUNCTIONS = { + "fp_4096": partial(mol_fp, _radius=2, _nBits=4096), + "fp_2048": partial(mol_fp, _radius=2, _nBits=2048), + "fp_1024": partial(mol_fp, _radius=2, _nBits=1024), + "fp_512": partial(mol_fp, _radius=2, _nBits=512), + "fp_256": partial(mol_fp, _radius=2, _nBits=256), +} # TODO: think about refactor/merge with `MorganFingerprintEncoder` + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--output-file", + type=str, + help="Output file for the computed embeddings file. (*.npy)", + ) + parser.add_argument( + "--featurization-fct", + type=str, + choices=FUNCTIONS.keys(), + help="Featurization function applied to each molecule.", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load building blocks + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + logger.info(f"Successfully read {args.building_blocks_file}.") + logger.info(f"Total number of building blocks: {len(bblocks)}.") + + # Compute embeddings + func = FUNCTIONS[args.featurization_fct] + molembedder = MolEmbedder(processes=args.ncpu).compute_embeddings(func, bblocks) + + # Save? + molembedder.save_precomputed(args.output_file) + + logger.info("Completed.") diff --git a/scripts/03-generate-syntrees.py b/scripts/03-generate-syntrees.py new file mode 100644 index 00000000..66aaa8e8 --- /dev/null +++ b/scripts/03-generate-syntrees.py @@ -0,0 +1,129 @@ +"""Generate synthetic trees. 
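+Failed generation attempts are tolerated: every outcome is tallied, and a summary is written to `results-summary.txt` next to the output file.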
+""" # TODO: clean up this mess +import json +import logging +from collections import Counter +from pathlib import Path + +from rdkit import RDLogger +from tqdm import tqdm + +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import ( + BuildingBlockFileHandler, + ReactionTemplateFileHandler, +) +from synnet.data_generation.syntrees import SynTreeGenerator, wraps_syntreegenerator_generate +from synnet.utils.data_utils import SyntheticTree, SyntheticTreeSet + +logger = logging.getLogger(__name__) +from typing import Tuple, Union + +RDLogger.DisableLog("rdApp.*") + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + default="data/pre-process/building-blocks/enamine-us-smiles.csv.gz", # TODO: change + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxn-templates-file", + type=str, + default="data/assets/reaction-templates/hb.txt", # TODO: change + help="Input file with reaction templates as SMARTS(No header, one per line).", + ) + parser.add_argument( + "--output-file", + type=str, + default="data/pre-precess/synthetic-trees.json.gz", + help="Output file for the generated synthetic trees (*.json.gz)", + ) + # Parameters + parser.add_argument( + "--number-syntrees", type=int, default=100, help="Number of SynTrees to generate." + ) + + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +def generate_mp() -> Tuple[dict[int, str], list[Union[SyntheticTree, None]]]: + from functools import partial + + import numpy as np + from pathos import multiprocessing as mp + + def wrapper(stgen, _): + stgen.rng = np.random.default_rng() # TODO: Think about this... 
+ return wraps_syntreegenerator_generate(stgen) + + func = partial(wrapper, stgen) + + with mp.Pool(processes=args.ncpu) as pool: + results = pool.map(func, range(args.number_syntrees)) + + outcomes = { + i: e.__class__.__name__ if e is not None else "success" for i, (_, e) in enumerate(results) + } + syntrees = [st for (st, e) in results if e is None] + return outcomes, syntrees + + +def generate() -> Tuple[dict[int, str], list[Union[SyntheticTree, None]]]: + outcomes: dict[int, str] = dict() + syntrees: list[Union[SyntheticTree, None]] = [] + myrange = tqdm(range(args.number_syntrees)) if args.verbose else range(args.number_syntrees) + for i in myrange: + st, e = wraps_syntreegenerator_generate(stgen) + outcomes[i] = e.__class__.__name__ if e is not None else "success" + syntrees.append(st) + + return outcomes, syntrees + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load assets + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + rxn_templates = ReactionTemplateFileHandler().load(args.rxn_templates_file) + logger.info("Loaded building block & rxn-template assets.") + + # Init SynTree Generator + logger.info("Start initializing SynTreeGenerator...") + stgen = SynTreeGenerator( + building_blocks=bblocks, rxn_templates=rxn_templates, verbose=args.verbose + ) + logger.info("Successfully initialized SynTreeGenerator.") + + # Generate synthetic trees + logger.info(f"Start generation of {args.number_syntrees} SynTrees...") + if args.ncpu > 1: + outcomes, syntrees = generate_mp() + else: + outcomes, syntrees = generate() + result_summary = Counter(outcomes.values()) + logger.info(f"SynTree generation completed. Results: {result_summary}") + + summary_file = Path(args.output_file).parent / "results-summary.txt" + summary_file.parent.mkdir(parents=True, exist_ok=True) + summary_file.write_text(json.dumps(result_summary, indent=2)) + + # Save synthetic trees on disk + syntree_collection = SyntheticTreeSet(syntrees) + syntree_collection.save(args.output_file) + + logger.info(f"Completed.") diff --git a/scripts/04-filter-syntrees.py b/scripts/04-filter-syntrees.py new file mode 100644 index 00000000..97fb195b --- /dev/null +++ b/scripts/04-filter-syntrees.py @@ -0,0 +1,120 @@ +"""Filter Synthetic Trees. +""" + +import json +import logging + +import numpy as np +from rdkit import Chem, RDLogger +from tqdm import tqdm + +from synnet.config import MAX_PROCESSES +from synnet.utils.data_utils import SyntheticTree, SyntheticTreeSet + +logger = logging.getLogger(__name__) + +RDLogger.DisableLog("rdApp.*") + + +class Filter: + def filter(self, st: SyntheticTree, **kwargs) -> bool: + ... + + +class ValidRootMolFilter(Filter): + def filter(self, st: SyntheticTree, **kwargs) -> bool: + return Chem.MolFromSmiles(st.root.smiles) is not None + + +class OracleFilter(Filter): + def __init__( + self, + name: str = "qed", + threshold: float = 0.5, + rng=np.random.default_rng(42), + ) -> None: + super().__init__() + from tdc import Oracle + + self.oracle_fct = Oracle(name=name) + self.threshold = threshold + self.rng = rng + + def _qed(self, st: SyntheticTree): + """Filter for molecules with a high qed.""" + return self.oracle_fct(st.root.smiles) > self.threshold + + def _random(self, st: SyntheticTree): + """Filter molecules that fail the `_qed` filter; i.e. 
randomly select low qed molecules.""" + return self.rng.random() < (self.oracle_fct(st.root.smiles) / self.threshold) + + def filter(self, st: SyntheticTree) -> bool: + return self._qed(st) or self._random(st) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-file", + type=str, + default="data/pre-process/synthetic-trees.json.gz", + help="Input file for the filtered generated synthetic trees (*.json.gz)", + ) + parser.add_argument( + "--output-file", + type=str, + default="data/pre-process/synthetic-trees-filtered.json.gz", + help="Output file for the filtered generated synthetic trees (*.json.gz)", + ) + + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load previously generated synthetic trees + syntree_collection = SyntheticTreeSet().load(args.input_file) + logger.info(f"Successfully loaded '{args.input_file}' with {len(syntree_collection)} syntrees.") + + # Filter trees + # TODO: Move to src/synnet/data_generation/filters.py ? + valid_root_mol_filter = ValidRootMolFilter() + interesting_mol_filter = OracleFilter(threshold=0.5, rng=np.random.default_rng()) + + logger.info(f"Start filtering syntrees...") + syntrees = [] + syntree_collection = [s for s in syntree_collection if s is not None] + syntree_collection = tqdm(syntree_collection) if args.verbose else syntree_collection + outcomes: dict[int, str] = dict() # TODO: think about what metrics to track here + for i, st in enumerate(syntree_collection): + + # Filter 1: Is root molecule valid? + keep_tree = valid_root_mol_filter.filter(st) + if not keep_tree: + continue + + # Filter 2: Is root molecule "pharmaceutically interesting?" + keep_tree = interesting_mol_filter.filter(st) + if not keep_tree: + continue + + # We passed all filters. This tree ascended to our dataset + syntrees.append(st) + logger.info(f"Successfully filtered syntrees.") + + # Save filtered synthetic trees on disk + SyntheticTreeSet(syntrees).save(args.output_file) + logger.info(f"Successfully saved '{args.output_file}' with {len(syntrees)} syntrees.") + + logger.info(f"Completed.") diff --git a/scripts/05-split-syntrees.py b/scripts/05-split-syntrees.py new file mode 100644 index 00000000..542d0afb --- /dev/null +++ b/scripts/05-split-syntrees.py @@ -0,0 +1,74 @@ +"""Reads synthetic tree data and splits it into training, validation and testing sets. 
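+The split is sequential (no shuffling) with a fixed 60:20:20 ratio, see `SPLIT_RATIO` below.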
+""" +import json +import logging +from pathlib import Path + +from synnet.config import MAX_PROCESSES +from synnet.utils.data_utils import SyntheticTreeSet + +logger = logging.getLogger(__name__) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-file", + type=str, + help="Input file for the filtered generated synthetic trees (*.json.gz)", + ) + parser.add_argument( + "--output-dir", + type=str, + help="Output directory for the splitted synthetic trees (*.json.gz)", + ) + + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load filtered synthetic trees + logger.info(f"Reading data from {args.input_file}") + syntree_collection = SyntheticTreeSet().load(args.input_file) + syntrees = syntree_collection.sts + + num_total = len(syntrees) + logger.info(f"There are {len(syntrees)} synthetic trees.") + + # Split data + SPLIT_RATIO = [0.6, 0.2, 0.2] + + num_train = int(SPLIT_RATIO[0] * num_total) + num_valid = int(SPLIT_RATIO[1] * num_total) + num_test = num_total - num_train - num_valid + + data_train = syntrees[:num_train] + data_valid = syntrees[num_train : num_train + num_valid] + data_test = syntrees[num_train + num_valid :] + + # Save to local disk + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + logger.info(f"Saving training dataset. Number of syntrees: {len(data_train)}") + SyntheticTreeSet(data_train).save(out_dir / "synthetic-trees-filtered-train.json.gz") + + logger.info(f"Saving validation dataset. Number of syntrees: {len(data_valid)}") + SyntheticTreeSet(data_valid).save(out_dir / "synthetic-trees-filtered-valid.json.gz") + + logger.info(f"Saving testing dataset. Number of syntrees: {len(data_test)}") + SyntheticTreeSet(data_test).save(out_dir / "synthetic-trees-filtered-test.json.gz") + + logger.info(f"Completed.") diff --git a/scripts/06-featurize-syntrees.py b/scripts/06-featurize-syntrees.py new file mode 100644 index 00000000..d83c88a4 --- /dev/null +++ b/scripts/06-featurize-syntrees.py @@ -0,0 +1,116 @@ +"""Splits a synthetic tree into states and steps. 
+""" +import json +import logging +from pathlib import Path + +from scipy import sparse +from tqdm import tqdm + +from synnet.data_generation.syntrees import ( + IdentityIntEncoder, + MorganFingerprintEncoder, + SynTreeFeaturizer, +) +from synnet.utils.data_utils import SyntheticTreeSet + +logger = logging.getLogger(__file__) + +from synnet.config import MAX_PROCESSES + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-dir", + type=str, + help="Directory with `*{train,valid,test}*.json.gz`-data of synthetic trees", + ) + parser.add_argument( + "--output-dir", + type=str, + help="Directory for the splitted synthetic trees ({train,valid,test}_{steps,states}.npz", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +def _match_dataset_filename(path: str, dataset_type: str) -> Path: + """Helper to find the exact filename for {train,valid,test} file.""" + files = list(Path(path).glob(f"*{dataset_type}*.json.gz")) + if len(files) != 1: + raise ValueError(f"Can not find unique '{dataset_type} 'file, got {files}") + return files[0] + + +def featurize_data( + syntree_featurizer: SynTreeFeaturizer, input_dir: str, output_dir: Path, verbose: bool = False +): + """Wrapper method to featurize synthetic tree data.""" + + # Load syntree data + logger.info(f"Start loading {input_dir}") + syntree_collection = SyntheticTreeSet().load(input_dir) + logger.info(f"Successfully loaded synthetic trees.") + logger.info(f" Number of trees: {len(syntree_collection.sts)}") + + # Start splitting synthetic trees in states and steps + states = [] + steps = [] + unsuccessfuls = [] + it = tqdm(syntree_collection) if verbose else syntree_collection + for i, syntree in enumerate(it): + try: + state, step = syntree_featurizer.featurize(syntree) + except Exception as e: + logger.exception(e, exc_info=e) + unsuccessfuls += [i] + continue + states.append(state) + steps.append(step) + logger.info(f"Completed featurizing syntrees.") + if len(unsuccessfuls) > 0: + logger.warning(f"Unsuccessfully attempted to featurize syntrees: {unsuccessfuls}.") + + # Finally, save. + logger.info(f"Saving to directory {output_dir}") + states = sparse.vstack(states) + steps = sparse.vstack(steps) + sparse.save_npz(output_dir / f"{dataset_type}_states.npz", states) + sparse.save_npz(output_dir / f"{dataset_type}_steps.npz", steps) + logger.info("Save successful.") + return None + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + stfeat = SynTreeFeaturizer( + reactant_embedder=MorganFingerprintEncoder(2, 256), + mol_embedder=MorganFingerprintEncoder(2, 4096), + rxn_embedder=IdentityIntEncoder(), + action_embedder=IdentityIntEncoder(), + ) + + # Ensure output dir exists + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=1, exist_ok=1) + + for dataset_type in "train valid test".split(): + + input_file = _match_dataset_filename(args.input_dir, dataset_type) + featurize_data(stfeat, input_file, output_dir=output_dir, verbose=args.verbose) + + # Save information + (output_dir / "summary.txt").write_text(f"{stfeat}") # TODO: Parse as proper json? 
+ + logger.info("Completed.") diff --git a/scripts/07-split-data-for-networks.py b/scripts/07-split-data-for-networks.py new file mode 100644 index 00000000..c9220879 --- /dev/null +++ b/scripts/07-split-data-for-networks.py @@ -0,0 +1,48 @@ +"""Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks +""" +import json +import logging +from pathlib import Path + +from synnet.utils.prep_utils import split_data_into_Xy + +logger = logging.getLogger(__file__) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--input-dir", + type=str, + help="Input directory for the featurized synthetic trees (with {train,valid,test}-data).", + ) + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Split datasets for each MLP + logger.info("Start splitting data.") + num_rxn = 91 # Auxiliary var for indexing TODO: Dont hardcode + out_dim = 256 # Auxiliary var for indexing TODO: Dont hardcode + input_dir = Path(args.input_dir) + output_dir = input_dir / "Xy" + for dataset_type in "train valid test".split(): + logger.info(f"Split {dataset_type}-data...") + split_data_into_Xy( + dataset_type=dataset_type, + steps_file=input_dir / f"{dataset_type}_steps.npz", + states_file=input_dir / f"{dataset_type}_states.npz", + output_dir=input_dir / "Xy", + num_rxn=num_rxn, + out_dim=out_dim, + ) + + logger.info(f"Completed.") diff --git a/scripts/20-predict-targets.py b/scripts/20-predict-targets.py new file mode 100644 index 00000000..ed88f98c --- /dev/null +++ b/scripts/20-predict-targets.py @@ -0,0 +1,202 @@ +""" +Generate synthetic trees for a set of specified query molecules. Multiprocessing. 
+""" # TODO: Clean up + dont hardcode file paths +import json +import logging +import multiprocessing as mp +from pathlib import Path +from typing import Tuple + +import numpy as np +import pandas as pd + +from synnet.config import DATA_PREPROCESS_DIR, DATA_RESULT_DIR, MAX_PROCESSES +from synnet.data_generation.preprocessing import BuildingBlockFileHandler +from synnet.encoding.distances import cosine_distance +from synnet.models.common import find_best_model_ckpt, load_mlp_from_ckpt +from synnet.MolEmbedder import MolEmbedder +from synnet.utils.data_utils import ReactionSet, SyntheticTree, SyntheticTreeSet +from synnet.utils.predict_utils import mol_fp, synthetic_tree_decoder_greedy_search + +logger = logging.getLogger(__name__) + + +def _fetch_data_chembl(name: str) -> list[str]: + raise NotImplementedError + df = pd.read_csv(f"{DATA_DIR}/chembl_20k.csv") + smis_query = df.smiles.to_list() + return smis_query + + +def _fetch_data_from_file(name: str) -> list[str]: + with open(name, "rt") as f: + smis_query = [line.strip() for line in f] + return smis_query + + +def _fetch_data(name: str) -> list[str]: + if args.data in ["train", "valid", "test"]: + file = ( + Path(DATA_PREPROCESS_DIR) / "syntrees" / f"synthetic-trees-filtered-{args.data}.json.gz" + ) + logger.info(f"Reading data from {file}") + syntree_collection = SyntheticTreeSet().load(file) + smiles = [syntree.root.smiles for syntree in syntree_collection] + elif args.data in ["chembl"]: + smiles = _fetch_data_chembl(name) + else: # Hopefully got a filename instead + smiles = _fetch_data_from_file(name) + return smiles + + +def wrapper_decoder(smiles: str) -> Tuple[str, float, SyntheticTree]: + """Generate a synthetic tree for the input molecular embedding.""" + emb = mol_fp(smiles) + try: + smi, similarity, tree, action = synthetic_tree_decoder_greedy_search( + z_target=emb, + building_blocks=bblocks, + bb_dict=bblocks_dict, + reaction_templates=rxns, + mol_embedder=bblocks_molembedder.kdtree, # TODO: fix this, currently misused + action_net=act_net, + reactant1_net=rt1_net, + rxn_net=rxn_net, + reactant2_net=rt2_net, + bb_emb=bb_emb, + rxn_template="hb", # TODO: Do not hard code + n_bits=4096, # TODO: Do not hard code + beam_width=3, + max_step=15, + ) + except Exception as e: + logger.error(e, exc_info=e) + action = -1 + + if action != 3: # aka tree has not been properly ended + smi = None + similarity = 0 + tree = None + + return smi, similarity, tree + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxns-collection-file", + type=str, + help="Input file for the collection of reactions matched with building-blocks.", + ) + parser.add_argument( + "--embeddings-knn-file", + type=str, + help="Input file for the pre-computed embeddings (*.npy).", + ) + parser.add_argument( + "--ckpt-dir", type=str, help="Directory with checkpoints for {act,rt1,rxn,rt2}-model." + ) + parser.add_argument( + "--output-dir", type=str, default=DATA_RESULT_DIR, help="Directory to save output." 
+ ) + # Parameters + parser.add_argument("--num", type=int, default=-1, help="Number of molecules to predict.") + parser.add_argument( + "--data", + type=str, + default="test", + help="Choose from ['train', 'valid', 'test', 'chembl'] or provide a file with one SMILES per line.", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Load data ... + logger.info("Start loading data...") + # ... query molecules (i.e. molecules to decode) + targets = _fetch_data(args.data) + if args.num > 0: # Select only n queries + targets = targets[: args.num] + + # ... building blocks + bblocks = BuildingBlockFileHandler().load(args.building_blocks_file) + # A dict is used as lookup table for 2nd reactant during inference: + bblocks_dict = {block: i for i, block in enumerate(bblocks)} + logger.info(f"Successfully read {args.building_blocks_file}.") + + # ... reaction templates + rxns = ReactionSet().load(args.rxns_collection_file).rxns + logger.info(f"Successfully read {args.rxns_collection_file}.") + + # ... building block embedding + bblocks_molembedder = ( + MolEmbedder().load_precomputed(args.embeddings_knn_file).init_balltree(cosine_distance) + ) + bb_emb = bblocks_molembedder.get_embeddings() + logger.info(f"Successfully read {args.embeddings_knn_file} and initialized BallTree.") + logger.info("...loading data completed.") + + # ... models + logger.info("Start loading models from checkpoints...") + path = Path(args.ckpt_dir) + ckpt_files = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()] + act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files] + logger.info("...loading models completed.") + + # Decode queries, i.e. the target molecules. + logger.info(f"Start to decode {len(targets)} target molecules.") + if args.ncpu == 1: + results = [wrapper_decoder(smi) for smi in targets] + else: + with mp.Pool(processes=args.ncpu) as pool: + logger.info(f"Starting MP with ncpu={args.ncpu}") + results = pool.map(wrapper_decoder, targets) + logger.info("Finished decoding.") + + # Print some results from the prediction + # Note: If a syntree cannot be decoded within `max_step` steps (15), + # we will count it as unsuccessful. The similarity will be 0. + decoded = [smi for smi, _, _ in results] + similarities = [sim for _, sim, _ in results] + trees = [tree for _, _, tree in results] + + recovery_rate = (np.asfarray(similarities) == 1.0).sum() / len(similarities) + avg_similarity = np.mean(similarities) + n_successful = sum([syntree is not None for syntree in trees]) + logger.info(f"For {args.data}:") + logger.info(f" Total number of attempted reconstructions: {len(targets)}") + logger.info(f" Total number of successful reconstructions: {n_successful}") + logger.info(f" {recovery_rate=}") + logger.info(f" {avg_similarity=}") + + # Save to local dir + # 1. Dataframe with targets, decoded, similarities + # 2.
Synthetic trees of the decoded SMILES + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving results to {output_dir} ...") + + df = pd.DataFrame({"targets": targets, "decoded": decoded, "similarity": similarities}) + df.to_csv(f"{output_dir}/decoded_results.csv.gz", compression="gzip", index=False) + + synthetic_tree_set = SyntheticTreeSet(sts=trees) + synthetic_tree_set.save(f"{output_dir}/decoded_syntrees.json.gz") + + logger.info("Completed.") diff --git a/scripts/21-identify-similar-fps.py b/scripts/21-identify-similar-fps.py new file mode 100644 index 00000000..746fd812 --- /dev/null +++ b/scripts/21-identify-similar-fps.py @@ -0,0 +1,133 @@ +"""Computes the fingerprint similarity of molecules in {valid,test}-set to molecules in the training set. +""" # TODO: clean up, un-nest a couple of fcts +import json +import logging +import multiprocessing as mp +from functools import partial +from pathlib import Path +from typing import Tuple + +import numpy as np +import pandas as pd +from rdkit import Chem, DataStructs +from rdkit.Chem import AllChem + +from synnet.utils.data_utils import SyntheticTreeSet + +logger = logging.getLogger(__file__) + +from synnet.config import MAX_PROCESSES + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--input-dir", + type=str, + help="Directory with `*{train,valid,test}*.json.gz`-data of synthetic trees", + ) + parser.add_argument( + "--output-file", + type=str, + default=None, + help="File to save similarity-values for test,valid-synthetic trees. (*csv.gz)", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +def _match_dataset_filename( + path: str, dataset_type: str +) -> Path: # TODO: consolidate with code in scripts/06-* + """Helper to find the exact filename for {train,valid,test} file.""" + files = list(Path(path).glob(f"*{dataset_type}*.json.gz")) + if len(files) != 1: + raise ValueError(f"Cannot find unique '{dataset_type}' file, got {files}") + return files[0] + + +def find_similar_fp(fp: np.ndarray, fps_reference: np.ndarray): + """Finds most similar fingerprint in a reference set for `fp`. + Uses Tanimoto Similarity. + """ + dists = np.asarray(DataStructs.BulkTanimotoSimilarity(fp, fps_reference)) + similarity_score, idx = dists.max(), dists.argmax() + return similarity_score, idx + + +def _compute_fp_bitvector(smiles: list[str], radius: int = 2, nbits: int = 1024): + return [ + AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), radius, nBits=nbits) + for smi in smiles + ] + + +def get_smiles_and_fps(dataset: str) -> Tuple[list[str], list[np.ndarray]]: + file = _match_dataset_filename(args.input_dir, dataset) + syntree_collection = SyntheticTreeSet().load(file) + smiles = [st.root.smiles for st in syntree_collection] + fps = _compute_fp_bitvector(smiles) + return smiles, fps + + +def compute_most_similar_smiles( + split: str, + fps: np.ndarray, + smiles: list[str], + /, + fps_reference: np.ndarray, + smiles_reference: list[str], +) -> pd.DataFrame: + + func = partial(find_similar_fp, fps_reference=fps_reference) + with mp.Pool(processes=args.ncpu) as pool: + results = pool.map(func, fps) + + similarities, idx = zip(*results) + most_similiar_ref_smiles = np.asarray(smiles_reference)[np.asarray(idx, dtype=int)] + # ^ Use numpy for slicing...
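+ # (plain Python lists cannot be indexed with an index array; numpy arrays can)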
+ + df = pd.DataFrame( + { + "split": split, + "smiles": smiles, + "most_similar_smiles": most_similiar_ref_smiles, + "similarity": similarities, + } + ) + return df + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + # Load data + smiles_train, fps_train = get_smiles_and_fps("train") + smiles_valid, fps_valid = get_smiles_and_fps("valid") + smiles_test, fps_test = get_smiles_and_fps("test") + + # Compute (mp) + logger.info("Start computing most similar smiles...") + df_valid = compute_most_similar_smiles( + "valid", fps_valid, smiles_valid, fps_reference=fps_train, smiles_reference=smiles_train + ) + df_test = compute_most_similar_smiles( + "test", fps_test, smiles_test, fps_reference=fps_train, smiles_reference=smiles_train + ) + logger.info("Computed most similar smiles for {valid,test}-set.") + + # Save + Path(args.output_file).parent.mkdir(parents=True, exist_ok=True) + df = pd.concat([df_valid, df_test], axis=0, ignore_index=True) + df.to_csv(args.output_file, index=False, compression="gzip") + logger.info(f"Successfully saved output to {args.output_file}.") + + logger.info("Completed.") diff --git a/scripts/22-compute-mrr.py b/scripts/22-compute-mrr.py new file mode 100644 index 00000000..0e589738 --- /dev/null +++ b/scripts/22-compute-mrr.py @@ -0,0 +1,104 @@ +"""Compute the mean reciprocal ranking for reactant 1 +selection using the different distance metrics in the k-NN search. +""" +import json +import logging + +import numpy as np +from tqdm import tqdm + +from synnet.config import MAX_PROCESSES +from synnet.encoding.distances import ce_distance, cosine_distance +from synnet.models.common import xy_to_dataloader +from synnet.models.mlp import load_mlp_from_ckpt +from synnet.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__name__) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--ckpt-file", type=str, help="Checkpoint to load trained reactant 1 network." + ) + parser.add_argument( + "--embeddings-file", type=str, help="Pre-computed molecular embeddings for kNN search." + ) + parser.add_argument("--X-data-file", type=str, help="Featurized X data for network.") + parser.add_argument("--y-data-file", type=str, help="Featurized y data for network.") + parser.add_argument( + "--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint." 
+ ) + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--batch_size", type=int, default=64, help="Batch size") + parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference, e.g. 'cuda:0' or 'cpu'.") + parser.add_argument( + "--distance", + type=str, + default="euclidean", + choices=["euclidean", "manhattan", "chebyshev", "cross_entropy", "cosine"], + help="Distance function for `BallTree`.", + ) + parser.add_argument("--debug", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Init BallTree for kNN-search + if args.distance == "cross_entropy": + metric = ce_distance + elif args.distance == "cosine": + metric = cosine_distance + else: + metric = args.distance + + # Recall default: Morgan fingerprint with radius=2, nbits=256 + mol_embedder = MolEmbedder().load_precomputed(args.embeddings_file) + mol_embedder.init_balltree(metric=metric) + n, d = mol_embedder.embeddings.shape + + # Load data + dataloader = xy_to_dataloader( + X_file=args.X_data_file, + y_file=args.y_data_file, + n=None if not args.debug else 128, + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=False, + ) + + # Load MLP + rt1_net = load_mlp_from_ckpt(args.ckpt_file) + rt1_net.to(args.device) + + ranks = [] + for X, y in tqdm(dataloader): + X, y = X.to(args.device), y.to(args.device) + y_hat = rt1_net(X) # (batch_size,nbits) + + ind_true = mol_embedder.kdtree.query(y.detach().cpu().numpy(), k=1, return_distance=False) + ind = mol_embedder.kdtree.query(y_hat.detach().cpu().numpy(), k=n, return_distance=False) + + irows, icols = np.nonzero(ind == ind_true) # irows = range(batch_size), icols = ranks + ranks.append(icols) + + ranks = np.concatenate(ranks) # (nSamples,); the last batch can be smaller, so concatenate instead of stacking + rrs = 1 / (ranks + 1) # +1 for offset 0-based indexing + + # np.save("ranks_" + metric + ".npy", ranks) # TODO: do not hard code + + print(f"Result using metric: {metric}") + print(f"The mean reciprocal ranking is: {rrs.mean():.3f}") + TOP_N_RANKS = (1, 3, 5, 10, 15, 30) + for i in TOP_N_RANKS: + n_recovered = sum(ranks < i) + n = len(ranks) + print(f"The Top-{i:<2d} recovery rate is: {n_recovered/n:.3f} ({n_recovered}/{n})") diff --git a/scripts/23-evaluate-predictions.py b/scripts/23-evaluate-predictions.py new file mode 100644 index 00000000..415f00b6 --- /dev/null +++ b/scripts/23-evaluate-predictions.py @@ -0,0 +1,93 @@ +"""Evaluate a batch of predictions on different metrics. +The predictions are generated in `20-predict-targets.py`.
+""" +import json +import logging + +import numpy as np +import pandas as pd +from tdc import Evaluator + +from synnet.config import MAX_PROCESSES + +logger = logging.getLogger(__name__) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + "--input-file", + type=str, + help="Dataframe with target- and prediction smiles and similarities (*.csv.gz).", + ) + # Processing + parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus") + parser.add_argument("--verbose", default=False, action="store_true") + return parser.parse_args() + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + # Keep track of successfully and unsuccessfully recovered molecules in 2 df's + # NOTE: column names must match input dataframe... + recovered = pd.DataFrame({"targets": [], "decoded": [], "similarity": []}) + unrecovered = pd.DataFrame({"targets": [], "decoded": [], "similarity": []}) + + # load each file containing the predictions + similarity = [] + n_recovered = 0 + n_unrecovered = 0 + n_total = 0 + files = [args.input_file] # TODO: not sure why the loop but let's keep it for now + for file in files: + print(f"Evaluating file: {file}") + + result_df = pd.read_csv(file) + n_total += len(result_df["decoded"]) + + # Split smiles, discard NaNs + is_recovered = result_df["similarity"] == 1.0 + unrecovered = pd.concat([unrecovered, result_df[~is_recovered].dropna()]) + recovered = pd.concat([recovered, result_df[is_recovered].dropna()]) + + n_recovered += len(recovered) + n_unrecovered += len(unrecovered) + similarity += unrecovered["similarity"].tolist() + + # Print general info + print(f"N total {n_total}") + print(f"N recovered {n_recovered} ({n_recovered/n_total:.2f})") + print(f"N unrecovered {n_unrecovered} ({n_recovered/n_total:.2f})") + + n_finished = n_recovered + n_unrecovered + n_unfinished = n_total - n_finished + print(f"N finished tree {n_finished} ({n_finished/n_total:.2f})") + print(f"N unfinished trees (NaN) {n_unfinished} ({n_unfinished/n_total:.2f})") + print(f"Average similarity (unrecovered only) {np.mean(similarity)}") + + # Evaluate on TDC evaluators + for metric in "KL_divergence FCD_Distance Novelty Validity Uniqueness".split(): + evaluator = Evaluator(name=metric) + try: + score_recovered = evaluator(recovered["targets"], recovered["decoded"]) + score_unrecovered = evaluator(unrecovered["targets"], unrecovered["decoded"]) + except TypeError: + # Some evaluators only take 1 input args, try that. + score_recovered = evaluator(recovered["decoded"]) + score_unrecovered = evaluator(unrecovered["decoded"]) + except Exception as e: + logger.error(f"{e.__class__.__name__}: {str(e)}") + logger.error(e) + score_recovered, score_unrecovered = np.nan, np.nan + + print(f"Evaluation metric for {evaluator.name}:") + print(f" Recovered score: {score_recovered:.2f}") + print(f" Unrecovered score: {score_unrecovered:.2f}") diff --git a/scripts/_mp_decode.py b/scripts/_mp_decode.py deleted file mode 100644 index 2bb1029f..00000000 --- a/scripts/_mp_decode.py +++ /dev/null @@ -1,97 +0,0 @@ -""" -This file contains a function to decode a single synthetic tree. 
-""" -import pandas as pd -import numpy as np -from dgllife.model import load_pretrained -from syn_net.utils.data_utils import ReactionSet -from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity, load_modules_from_checkpoint - - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096_256' -ncpu = 16 - -# define model to use for molecular embedding -model_type = 'gin_supervised_contextpred' -device = 'cpu' -mol_embedder = load_pretrained(model_type).to(device) -mol_embedder.eval() - -# load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz' -path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz' - -# define paths to pretrained modules -param_path = f'/home/whgao/synth_net/synth_net/params/{param_dir}/' -path_to_act = f'{param_path}act.ckpt' -path_to_rt1 = f'{param_path}rt1.ckpt' -path_to_rxn = f'{param_path}rxn.ckpt' -path_to_rt2 = f'{param_path}rt2.ckpt' - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(emb): - """ - Generates the synthetic tree for the input molecular embedding. - - Args: - emb (np.ndarray): Molecular embedding to decode. - - Returns: - str: SMILES for the final chemical node in the tree. - SyntheticTree: The generated synthetic tree. - """ - emb = emb.reshape((1, -1)) - try: - tree, action = synthetic_tree_decoder(z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=nbits, - max_step=15) - except Exception as e: - print(e) - action = -1 - if action != 3: - return None, None - else: - scores = np.array( - tanimoto_similarity(emb, [node.smiles for node in tree.chemicals]) - ) - max_score_idx = np.where(scores == np.max(scores))[0][0] - return tree.chemicals[max_score_idx].smiles, tree diff --git a/scripts/_mp_predict.py b/scripts/_mp_predict.py deleted file mode 100644 index 3ef1a7f1..00000000 --- a/scripts/_mp_predict.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -This file contains a function to predict a single synthetic tree given a molecular SMILES. 
-""" -import pandas as pd -import numpy as np -from syn_net.utils.data_utils import ReactionSet -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import synthetic_tree_decoder, tanimoto_similarity, load_modules_from_checkpoint, mol_fp - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096_256' -ncpu = 32 - -# define model to use for molecular embedding -model_type = 'gin_supervised_contextpred' -device = 'cpu' -mol_embedder = load_pretrained(model_type).to(device) -mol_embedder.eval() - -# load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = '/pool001/whgao/data/synth_net/st_' + rxn_template + '/reactions_' + rxn_template + '.json.gz' -path_to_building_blocks = '/pool001/whgao/data/synth_net/st_' + rxn_template + '/enamine_us_matched.csv.gz' - -# define paths to pretrained modules -param_path = '/home/whgao/scGen/synth_net/synth_net/params/' + param_dir + '/' -path_to_act = param_path + 'act.ckpt' -path_to_rt1 = param_path + 'rt1.ckpt' -path_to_rxn = param_path + 'rxn.ckpt' -path_to_rt2 = param_path + 'rt2.ckpt' - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(smi): - """ - Generates the synthetic tree for the input SMILES. - - Args: - smi (str): Molecular to reconstruct. - - Returns: - str: Final product SMILES. - float: Score of the best final product. - SyntheticTree: The generated synthetic tree. - """ - emb = mol_fp(smi) - try: - tree, action = synthetic_tree_decoder(emb, building_blocks, bb_dict, rxns, mol_embedder, act_net, rt1_net, rxn_net, rt2_net, bb_emb, rxn_template=rxn_template, n_bits=nbits, max_step=15) - except Exception as e: - print(e) - action = -1 - - # tree, action = synthetic_tree_decoder(emb, building_blocks, bb_dict, rxns, mol_embedder, act_net, rt1_net, rxn_net, rt2_net, max_step=15) - - # import ipdb; ipdb.set_trace(context=9) - # tree._print() - # print(action) - # print(np.max(oracle(tree.get_state()))) - # print() - - if action != 3: - return None, 0, None - else: - scores = tanimoto_similarity(emb, tree.get_state()) - max_score_idx = np.where(scores == np.max(scores))[0][0] - return tree.get_state()[max_score_idx], np.max(scores), tree diff --git a/scripts/_mp_predict_beam.py b/scripts/_mp_predict_beam.py deleted file mode 100644 index e415079f..00000000 --- a/scripts/_mp_predict_beam.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -This file contains a function to decode a single synthetic tree. 
-""" -import pandas as pd -import numpy as np -from syn_net.utils.data_utils import ReactionSet -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import tanimoto_similarity, load_modules_from_checkpoint, mol_fp -from syn_net.utils.predict_beam_utils import synthetic_tree_decoder - - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 300 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096' -ncpu = 16 - -# define model to use for molecular embedding -model_type = 'gin_supervised_contextpred' -device = 'cpu' -mol_embedder = load_pretrained(model_type).to(device) -mol_embedder.eval() - -# load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz' -path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz' - -# define paths to pretrained modules -param_path = f'/home/whgao/scGen/synth_net/synth_net/params/{param_dir}/' -path_to_act = f'{param_path}act.ckpt' -path_to_rt1 = f'{param_path}rt1.ckpt' -path_to_rxn = f'{param_path}rxn.ckpt' -path_to_rt2 = f'{param_path}rt2.ckpt' - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(smi): - """ - Generates the synthetic tree for the input moleular string. - - Args: - smi (str): Molecule (SMILES) to decode. - - Returns: - np.ndarray or None: State of the generated synthetic tree. - float: The best score. - SyntheticTree: The generated synthetic tree. - """ - emb = mol_fp(smi) - try: - tree, action = synthetic_tree_decoder(z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - beam_width=10, - rxn_template=rxn_template, - n_bits=nbits, - max_step=15) - except Exception as e: - print(e) - action = -1 - - if action != 3: - return None, 0, None - else: - scores = tanimoto_similarity(emb, tree.get_state()) - max_score_idx = np.where(scores == np.max(scores))[0][0] - return tree.get_state()[max_score_idx], np.max(scores), tree diff --git a/scripts/_mp_predict_multireactant.py b/scripts/_mp_predict_multireactant.py deleted file mode 100644 index e3ab8f13..00000000 --- a/scripts/_mp_predict_multireactant.py +++ /dev/null @@ -1,91 +0,0 @@ -""" -This file contains a function to decode a single synthetic tree. 
-""" -import pandas as pd -import numpy as np -from syn_net.utils.data_utils import ReactionSet -from syn_net.utils.predict_utils import synthetic_tree_decoder_multireactant, load_modules_from_checkpoint, mol_fp - - -# define some constants (here, for the Hartenfeller-Button test set) -nbits = 4096 -out_dim = 256 -rxn_template = 'hb' -featurize = 'fp' -param_dir = 'hb_fp_2_4096_256' -ncpu = 1 - -# load the purchasable building block embeddings -bb_emb = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - -# define path to the reaction templates and purchasable building blocks -path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{rxn_template}/reactions_{rxn_template}.json.gz' -path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{rxn_template}/enamine_us_matched.csv.gz' - -# define paths to pretrained modules -param_path = f'/home/whgao/synth_net/synth_net/params/{param_dir}/' -path_to_act = f'{param_path}act.ckpt' -path_to_rt1 = f'{param_path}rt1.ckpt' -path_to_rxn = f'{param_path}rxn.ckpt' -path_to_rt2 = f'{param_path}rt2.ckpt' - -# load the purchasable building block SMILES to a dictionary -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - -# load the reaction templates as a ReactionSet object -rxn_set = ReactionSet() -rxn_set.load(path_to_reaction_file) -rxns = rxn_set.rxns - -# load the pre-trained modules -act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=featurize, - rxn_template=rxn_template, - out_dim=out_dim, - nbits=nbits, - ncpu=ncpu, -) - -def func(smi): - """ - Generates the synthetic tree for the input molecular embedding. - - Args: - smi (str): SMILES string corresponding to the molecule to decode. - - Returns: - smi (str): SMILES for the final chemical node in the tree. - similarity (float): Similarity measure between the final chemical node - and the input molecule. - tree (SyntheticTree): The generated synthetic tree. - """ - emb = mol_fp(smi) - try: - smi, similarity, tree, action = synthetic_tree_decoder_multireactant( - z_target=emb, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_fp, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=nbits, - beam_width=3, - max_step=15) - except Exception as e: - print(e) - action = -1 - - if action != 3: - return None, 0, None - else: - return smi, similarity, tree diff --git a/scripts/_mp_search_similar.py b/scripts/_mp_search_similar.py deleted file mode 100644 index 41baaa74..00000000 --- a/scripts/_mp_search_similar.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -This function is used to identify the most similar molecule in the training set -to a given molecular fingerprint. 
-""" -import numpy as np -from rdkit import Chem -from rdkit.Chem import AllChem -from rdkit import DataStructs -import pandas as pd -from syn_net.utils.data_utils import * - - -data_path = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' -st_set = SyntheticTreeSet() -st_set.load(data_path) -data = st_set.sts -data_train = [t.root.smiles for t in data] -fps_train = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_train] - - -def func(fp): - """ - Finds the most similar molecule in the training set to the input molecule - using the Tanimoto similarity. - - Args: - fp (np.ndarray): Morgan fingerprint to find similars to in the training set. - - Returns: - np.float: The maximum similarity found to the training set fingerprints. - np.ndarray: Fingerprint of the most similar training set molecule. - """ - dists = np.array([DataStructs.FingerprintSimilarity(fp, fp_, metric=DataStructs.TanimotoSimilarity) for fp_ in fps_train]) - return dists.max(), dists.argmax() diff --git a/scripts/_mp_sum.py b/scripts/_mp_sum.py deleted file mode 100644 index 9cb0591f..00000000 --- a/scripts/_mp_sum.py +++ /dev/null @@ -1,8 +0,0 @@ -""" -Computes the sum of a single molecular embedding. -""" -import numpy as np - - -def func(emb): - return np.sum(emb) diff --git a/scripts/compute_embedding.py b/scripts/compute_embedding.py deleted file mode 100644 index 3def3ffb..00000000 --- a/scripts/compute_embedding.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -This file contains functions for generating molecular embeddings from SMILES using GIN. -""" -import pandas as pd -import numpy as np -from tqdm import tqdm -from syn_net.utils.predict_utils import mol_embedding, fp_embedding, rdkit2d_embedding - - -def get_mol_embedding_func(feature): - """ - Returns the molecular embedding function. - - Args: - feature (str): Indicates the type of featurization to use (GIN or Morgan - fingerprint), and the size. - - Returns: - Callable: The embedding function. 
- """ - if feature == 'gin': - embedding_func = lambda smi: mol_embedding(smi, device='cpu') - elif feature == 'fp_4096': - embedding_func = lambda smi: fp_embedding(smi, _nBits=4096) - elif feature == 'fp_2048': - embedding_func = lambda smi: fp_embedding(smi, _nBits=2048) - elif feature == 'fp_1024': - embedding_func = lambda smi: fp_embedding(smi, _nBits=1024) - elif feature == 'fp_512': - embedding_func = lambda smi: fp_embedding(smi, _nBits=512) - elif feature == 'fp_256': - embedding_func = lambda smi: fp_embedding(smi, _nBits=256) - elif feature == 'rdkit2d': - embedding_func = rdkit2d_embedding - return embedding_func - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--feature", type=str, default="gin", - help="Objective function to optimize") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - args = parser.parse_args() - - path = '/pool001/whgao/data/synth_net/st_hb/' - ## path = './tests/data/' ## for debugging - data = pd.read_csv(path + 'enamine_us_matched.csv.gz', compression='gzip')['SMILES'].tolist() - ## data = pd.read_csv(path + 'building_blocks_matched.csv.gz', compression='gzip')['SMILES'].tolist() ## for debugging - print('Total data: ', len(data)) - - embeddings = [] - for smi in tqdm(data): - embeddings.append(mol_embedding(smi)) - - embedding = np.array(embeddings) - np.save(path + 'enamine_us_emb_' + args.feature + '.npy', embeddings) - - print('Finish!') diff --git a/scripts/compute_embedding_mp.py b/scripts/compute_embedding_mp.py deleted file mode 100644 index 421503ec..00000000 --- a/scripts/compute_embedding_mp.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Computes the molecular embeddings of the purchasable building blocks. -""" -import multiprocessing as mp -from scripts.compute_embedding import * -from rdkit import RDLogger -from syn_net.utils.predict_utils import mol_embedding, fp_4096, fp_2048, fp_1024, fp_512, fp_256, rdkit2d_embedding -RDLogger.DisableLog('*') - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--feature", type=str, default="gin", - help="Objective function to optimize") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - args = parser.parse_args() - - # define the path to which data will be saved - path = '/pool001/whgao/data/synth_net/st_hb/' - ## path = './tests/data/' ## for debugging - - # load the building blocks - data = pd.read_csv(path + 'enamine_us_matched.csv.gz', compression='gzip')['SMILES'].tolist() - ## data = pd.read_csv(path + 'building_blocks_matched.csv.gz', compression='gzip')['SMILES'].tolist() ## for debugging - print('Total data: ', len(data)) - - if args.feature == 'gin': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(mol_embedding, data) - elif args.feature == 'fp_4096': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_4096, data) - elif args.feature == 'fp_2048': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_2048, data) - elif args.feature == 'fp_1024': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_1024, data) - elif args.feature == 'fp_512': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_512, data) - elif args.feature == 'fp_256': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = pool.map(fp_256, data) - elif args.feature == 'rdkit2d': - with mp.Pool(processes=args.ncpu) as pool: - embeddings = 
pool.map(rdkit2d_embedding, data) - - embedding = np.array(embeddings) - - # import ipdb; ipdb.set_trace(context=9) - np.save(path + 'enamine_us_emb_' + args.feature + '.npy', embeddings) - - print('Finish!') diff --git a/scripts/evaluate_batch.py b/scripts/evaluate_batch.py deleted file mode 100644 index 85283da2..00000000 --- a/scripts/evaluate_batch.py +++ /dev/null @@ -1,28 +0,0 @@ -""" -This function evaluates a batch of predictions by computing the (1) novelty, (2) validity, -(3) uniqueness, (4) Fréchet ChemNet distance, and (5) KL divergence for the final -root molecules which correspond to *unrecovered* molecules in all the generated trees. -""" -from tdc import Evaluator -import pandas as pd - -kl_divergence = Evaluator(name = 'KL_Divergence') -fcd_distance = Evaluator(name = 'FCD_Distance') -novelty = Evaluator(name = 'Novelty') -validity = Evaluator(name = 'Validity') -uniqueness = Evaluator(name = 'Uniqueness') - -if __name__ == '__main__': - # load the final root molecules generated by a prediction run using a pre-trained model - result_train = pd.read_csv('../results/decode_result_test_processed_property.csv.gz', compression='gzip') - - # get the unrecovered molecules only - # result_test_unrecover = result_train[result_train['recovered sa'] != -1][result_train['similarity'] != 1.0] - result_test_unrecover = result_train[result_train['recovered sa'] != -1] - - # compute the following properties, using the TDC - print(f"Novelty: {novelty(result_test_unrecover['query SMILES'].tolist(), result_test_unrecover['decode SMILES'].tolist())}") - print(f"Validity: {validity(result_test_unrecover['decode SMILES'].tolist())}") - print(f"Uniqueness: {uniqueness(result_test_unrecover['decode SMILES'].tolist())}") - print(f"FCD: {fcd_distance(result_test_unrecover['query SMILES'].tolist(), result_test_unrecover['decode SMILES'].tolist())}") - print(f"KL: {kl_divergence(result_test_unrecover['query SMILES'].tolist()[:10000], result_test_unrecover['decode SMILES'].tolist()[:10000])}") diff --git a/scripts/evaluate_batch_recovery.py b/scripts/evaluate_batch_recovery.py deleted file mode 100644 index 1eff5127..00000000 --- a/scripts/evaluate_batch_recovery.py +++ /dev/null @@ -1,92 +0,0 @@ -""" -This function evaluates a batch of predictions by computing the (1) novelty, (2) validity, -(3) uniqueness, (4) Fréchet ChemNet distance, and (5) KL divergence for the final -root molecules which correspond to *unrecovered* molecules in all the generated trees. 
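The TDC evaluators used here are plain callables over SMILES lists; `None` entries (unfinished trees) must be filtered out first. A minimal usage sketch, assuming PyTDC is installed:

```python
from tdc import Evaluator

validity = Evaluator(name="Validity")
uniqueness = Evaluator(name="Uniqueness")

# drop unfinished decodes before scoring
decoded = [smi for smi in ("CCO", "CCO", "c1ccccc1", None) if smi is not None]
print("Validity:", validity(decoded))
print("Uniqueness:", uniqueness(decoded))
```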
-""" -from tdc import Evaluator -import pandas as pd -import glob -import numpy as np - -kl_divergence = Evaluator(name = 'KL_Divergence') -fcd_distance = Evaluator(name = 'FCD_Distance') -novelty = Evaluator(name = 'Novelty') -validity = Evaluator(name = 'Validity') -uniqueness = Evaluator(name = 'Uniqueness') - -if __name__ == '__main__': - # load the final root molecules generated by a prediction run using a - # pre-trained model, which were all saved to different files - generated_st_files = glob.glob('../../results_mp/pis_fp/decode_result_train*.csv.gz') - - # lists in which to collect all the results - recovered_molecules = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) # molecules successfully recovered from query - unrecovered_molecules = pd.DataFrame({'query SMILES': [], 'decode SMILES': [], 'similarity':[]}) # unsuccessfully recovered - recovered_novelty_all = [] - recovered_validity_decode_all = [] - recovered_uniqueness_decode_all = [] - recovered_fcd_distance_all = [] - recovered_kl_divergence_all = [] - unrecovered_novelty_all = [] - unrecovered_validity_decode_all = [] - unrecovered_uniqueness_decode_all = [] - unrecovered_fcd_distance_all = [] - unrecovered_kl_divergence_all = [] - - similarity = [] - - n_recovered = 0 - n_unrecovered = 0 - n_total = 0 - - # load each file containing the predictions - for generate_st_file in generated_st_files: - - print(f'File currently being evaluated: {generate_st_file}') - - result_df = pd.read_csv(generate_st_file, compression='gzip') - n_total += len(result_df['decode SMILES']) - - # get the recovered and unrecovered molecules only (no NaNs) - unrecovered_molecules = pd.concat([unrecovered_molecules, result_df[result_df['similarity'] != 1.0].dropna()]) - recovered_molecules = pd.concat([recovered_molecules, result_df[result_df['similarity'] == 1.0].dropna()]) - - n_recovered += len(recovered_molecules['decode SMILES']) - n_unrecovered += len(unrecovered_molecules['decode SMILES']) - similarity += unrecovered_molecules['similarity'].tolist() - - # compute the following properties, using the TDC, for the succesfully recovered molecules - recovered_novelty_all = novelty(recovered_molecules['query SMILES'].tolist(), recovered_molecules['decode SMILES'].tolist()) - recovered_validity_decode_all = validity(recovered_molecules['decode SMILES'].tolist()) - recovered_uniqueness_decode_all = uniqueness(recovered_molecules['decode SMILES'].tolist()) - recovered_fcd_distance_all = fcd_distance(recovered_molecules['query SMILES'].tolist(), recovered_molecules['decode SMILES'].tolist()) - recovered_kl_divergence_all = kl_divergence(recovered_molecules['query SMILES'].tolist(), recovered_molecules['decode SMILES'].tolist()) - - # compute the following properties, using the TDC, for the unrecovered molecules - unrecovered_novelty_all = novelty(unrecovered_molecules['query SMILES'].tolist(), unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_validity_decode_all = validity(unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_uniqueness_decode_all = uniqueness(unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_fcd_distance_all = fcd_distance(unrecovered_molecules['query SMILES'].tolist(), unrecovered_molecules['decode SMILES'].tolist()) - unrecovered_kl_divergence_all = kl_divergence(unrecovered_molecules['query SMILES'].tolist(), unrecovered_molecules['decode SMILES'].tolist()) - - print('N recovered, N unrecovered, N total (% recovered):', n_recovered, ',', n_unrecovered, ',', n_total, ', (', 
100*n_recovered/n_total, '%)') - n_finished = n_recovered + n_unrecovered - n_unfinished = n_total - n_finished - print('N finished trees (%):', n_finished, '(', 100*n_finished/n_total,'%)') - print('N unfinished trees (NaN) (%):', n_unfinished, '(', 100*n_unfinished/n_total,'%)') - print('Average similarity (unrecovered only)', np.mean(similarity)) - - print('Novelty, recovered:', recovered_novelty_all) - print('Novelty, unrecovered:', unrecovered_novelty_all) - - print('Validity, decode molecules, recovered:', recovered_validity_decode_all) - print('Validity, decode molecules, unrecovered:', unrecovered_validity_decode_all) - - print('Uniqueness, decode molecules, recovered:', recovered_uniqueness_decode_all) - print('Uniqueness, decode molecules, unrecovered:', unrecovered_uniqueness_decode_all) - - print('FCD distance, recovered:', recovered_fcd_distance_all) - print('FCD distance, unrecovered:', unrecovered_fcd_distance_all) - - print('KL divergence, recovered:', recovered_kl_divergence_all) - print('KL divergence, unrecovered:', unrecovered_kl_divergence_all) diff --git a/scripts/mrr.py b/scripts/mrr.py deleted file mode 100644 index c77fd4a1..00000000 --- a/scripts/mrr.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -This function is used to compute the mean reciprocal ranking for reactant 1 -selection using the different distance metrics in the k-NN search. -""" -from syn_net.models.mlp import MLP, load_array -from scipy import sparse -import numpy as np -from sklearn.neighbors import BallTree -import torch -from syn_net.utils.predict_utils import cosine_distance, ce_distance - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--param_dir", type=str, default='hb_fp_2_4096_256', - help="") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--device", type=str, default="cuda:0", - help="") - parser.add_argument("--distance", type=str, default="euclidean", - help="Choose from ['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine']") - args = parser.parse_args() - - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - - main_dir = '/pool001/whgao/data/synth_net/' + args.rxn_template + '_' + args.featurize + '_' + str(args.radius) + '_' + str(args.nbits) + '_' + validation_option[12:] + '/' - path_to_rt1 = '/home/whgao/scGen/synth_net/synth_net/params/' + args.param_dir + '/' + 'rt1.ckpt' - batch_size = args.batch_size - ncpu = args.ncpu - - # X = sparse.load_npz(main_dir + 'X_rt1_train.npz') - # y = sparse.load_npz(main_dir + 'y_rt1_train.npz') - # X = torch.Tensor(X.A) - # y = torch.Tensor(y.A) - # _idx = np.random.choice(list(range(X.shape[0])), 
size=int(X.shape[0]/100), replace=False) - # train_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - - # X = sparse.load_npz(main_dir + 'X_rt1_valid.npz') - # y = sparse.load_npz(main_dir + 'y_rt1_valid.npz') - # X = torch.Tensor(X.A) - # y = torch.Tensor(y.A) - # _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - # valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - - X = sparse.load_npz(main_dir + 'X_rt1_test.npz') - y = sparse.load_npz(main_dir + 'y_rt1_test.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - test_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - data_iter = test_data_iter - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(3 * args.nbits), - output_dim=args.out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - rt1_net.eval() - rt1_net.to(args.device) - - bb_emb_fp_256 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - - # for kw_metric_ in ['euclidean', 'manhattan', 'chebyshev', 'cross_entropy', 'cosine']: - kw_metric_ = args.distance - - if kw_metric_ == 'cross_entropy': - kw_metric = ce_distance - elif kw_metric_ == 'cosine': - kw_metric = cosine_distance - else: - kw_metric = kw_metric_ - - kdtree_fp_256 = BallTree(bb_emb_fp_256, metric=kw_metric) - - ranks = [] - for X, y in data_iter: - X, y = X.to(args.device), y.to(args.device) - y_hat = rt1_net(X) - dist_true, ind_true = kdtree_fp_256.query(y.detach().cpu().numpy(), k=1) - dist, ind = kdtree_fp_256.query(y_hat.detach().cpu().numpy(), k=bb_emb_fp_256.shape[0]) - ranks = ranks + [np.where(ind[i] == ind_true[i])[0][0] for i in range(len(ind_true))] - - ranks = np.array(ranks) - rrs = 1 / (ranks + 1) - np.save('ranks_' + kw_metric_ + '.npy', ranks) - print(f"Result using metric: {kw_metric_}") - print(f"The mean reciprocal ranking is: {rrs.mean():.3f}") - print(f"The Top-1 recovery rate is: {sum(ranks < 1) / len(ranks) :.3f}, {sum(ranks < 1)} / {len(ranks)}") - print(f"The Top-3 recovery rate is: {sum(ranks < 3) / len(ranks) :.3f}, {sum(ranks < 3)} / {len(ranks)}") - print(f"The Top-5 recovery rate is: {sum(ranks < 5) / len(ranks) :.3f}, {sum(ranks < 5)} / {len(ranks)}") - print(f"The Top-10 recovery rate is: {sum(ranks < 10) / len(ranks) :.3f}, {sum(ranks < 10)} / {len(ranks)}") - print(f"The Top-15 recovery rate is: {sum(ranks < 15) / len(ranks) :.3f}, {sum(ranks < 15)} / {len(ranks)}") - print(f"The Top-30 recovery rate is: {sum(ranks < 30) / len(ranks) :.3f}, {sum(ranks < 30)} / {len(ranks)}") - print() diff --git a/scripts/optimize_ga.py b/scripts/optimize_ga.py index 0aad875a..76f5271d 100644 --- a/scripts/optimize_ga.py +++ b/scripts/optimize_ga.py @@ -1,18 +1,84 @@ """ Generates synthetic trees where the root molecule optimizes for a specific objective -based on Therapeutic Data Commons (TDC) oracle functions. Uses a genetic algorithm -to optimize embeddings before decoding. -""" -from syn_net.utils.ga_utils import crossover, mutation +based on Therapeutics Data Commons (TDC) oracle functions. +Uses a genetic algorithm to optimize embeddings before decoding. 
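The oracle objectives referenced in this script are plain callables too; a minimal sketch of the TDC interface it builds on (assumes PyTDC is installed):

```python
from tdc import Oracle

qed = Oracle(name="QED")
print(qed("CCO"))  # a single SMILES in, a float score out
```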
+""" # TODO: Refactor/Consolidate with generic inference script +import json import multiprocessing as mp +import time +from pathlib import Path + import numpy as np import pandas as pd -import time -import json -import scripts._mp_decode as decode -from syn_net.utils.predict_utils import mol_fp from tdc import Oracle +from synnet.config import MAX_PROCESSES +from synnet.data_generation.preprocessing import BuildingBlockFileHandler +from synnet.encoding.distances import cosine_distance +from synnet.models.common import find_best_model_ckpt, load_mlp_from_ckpt +from synnet.MolEmbedder import MolEmbedder +from synnet.utils.data_utils import ReactionSet +from synnet.utils.ga_utils import crossover, mutation +from synnet.utils.predict_utils import mol_fp, synthetic_tree_decoder, tanimoto_similarity + + +def _fetch_gin_molembedder(): + from dgllife.model import load_pretrained + + # define model to use for molecular embedding + model_type = "gin_supervised_contextpred" + device = "cpu" + mol_embedder = load_pretrained(model_type).to(device) + return mol_embedder.eval() + + +def _fetch_molembedder(featurize: str): + """Fetch molembedder.""" + if featurize == "fp": + return None # not in use + else: + raise NotImplementedError + return _fetch_gin_molembedder() + + +def func(emb): + """ + Generates the synthetic tree for the input molecular embedding. + + Args: + emb (np.ndarray): Molecular embedding to decode. + + Returns: + str: SMILES for the final chemical node in the tree. + SyntheticTree: The generated synthetic tree. + """ + emb = emb.reshape((1, -1)) + try: + tree, action = synthetic_tree_decoder( + z_target=emb, + building_blocks=building_blocks, + bb_dict=bb_dict, + reaction_templates=rxns, + mol_embedder=bblocks_molembedder.kdtree, # TODO: fix this, currently misused, + action_net=act_net, + reactant1_net=rt1_net, + rxn_net=rxn_net, + reactant2_net=rt2_net, + bb_emb=bb_emb, + rxn_template=rxn_template, + n_bits=nbits, + max_step=15, + ) + except Exception as e: + print(e) + action = -1 + if action != 3: + return None, None + else: + scores = np.array(tanimoto_similarity(emb, [node.smiles for node in tree.chemicals])) + max_score_idx = np.where(scores == np.max(scores))[0][0] + return tree.chemicals[max_score_idx].smiles, tree + def dock_drd3(smi): """ @@ -25,16 +91,17 @@ def dock_drd3(smi): float: Predicted docking score against the DRD3 target. """ # define the oracle function from the TDC - _drd3 = Oracle(name = 'drd3_docking') + _drd3 = Oracle(name="drd3_docking") if smi is None: return 0.0 else: try: - return - _drd3(smi) + return -_drd3(smi) except: return 0.0 + def dock_7l11(smi): """ Returns the docking score for the 7L11 target. @@ -46,12 +113,12 @@ def dock_7l11(smi): float: Predicted docking score against the 7L11 target. """ # define the oracle function from the TDC - _7l11 = Oracle(name = '7l11_docking') + _7l11 = Oracle(name="7l11_docking") if smi is None: return 0.0 else: try: - return - _7l11(smi) + return -_7l11(smi) except: return 0.0 @@ -77,36 +144,36 @@ def fitness(embs, _pool, obj): trees (list): Contains the synthetic trees generated from the input embeddings. 
""" - results = _pool.map(decode.func, embs) - smiles = [r[0] for r in results] - trees = [r[1] for r in results] + results = _pool.map(func, embs) + smiles = [r[0] for r in results] + trees = [r[1] for r in results] - if obj == 'qed': + if obj == "qed": # define the oracle function from the TDC - qed = Oracle(name = 'QED') + qed = Oracle(name="QED") scores = [qed(smi) for smi in smiles] - elif obj == 'logp': + elif obj == "logp": # define the oracle function from the TDC - logp = Oracle(name = 'LogP') + logp = Oracle(name="LogP") scores = [logp(smi) for smi in smiles] - elif obj == 'jnk': + elif obj == "jnk": # define the oracle function from the TDC - jnk = Oracle(name = 'JNK3') + jnk = Oracle(name="JNK3") scores = [jnk(smi) if smi is not None else 0.0 for smi in smiles] - elif obj == 'gsk': + elif obj == "gsk": # define the oracle function from the TDC - gsk = Oracle(name = 'GSK3B') + gsk = Oracle(name="GSK3B") scores = [gsk(smi) if smi is not None else 0.0 for smi in smiles] - elif obj == 'drd2': + elif obj == "drd2": # define the oracle function from the TDC - drd2 = Oracle(name = 'DRD2') + drd2 = Oracle(name="DRD2") scores = [drd2(smi) if smi is not None else 0.0 for smi in smiles] - elif obj == '7l11': + elif obj == "7l11": scores = [dock_7l11(smi) for smi in smiles] - elif obj == 'drd3': + elif obj == "drd3": scores = [dock_drd3(smi) for smi in smiles] else: - raise ValueError('Objective function not implemneted') + raise ValueError("Objective function not implemneted") return scores, smiles, trees @@ -122,10 +189,11 @@ def distribution_schedule(n, total): Returns: str: Describes a type of probability distribution. """ - if n < 4 * total/5: - return 'linear' + if n < 4 * total / 5: + return "linear" else: - return 'softmax_linear' + return "softmax_linear" + def num_mut_per_ele_scheduler(n, total): """ @@ -145,6 +213,7 @@ def num_mut_per_ele_scheduler(n, total): # return 512 return 24 + def mut_probability_scheduler(n, total): """ Determines the probability of mutating a vector, based on the number of elapsed @@ -157,42 +226,77 @@ def mut_probability_scheduler(n, total): Returns: float: The probability of mutation. 
""" - if n < total/2: + if n < total / 2: return 0.5 else: return 0.5 -if __name__ == '__main__': +def get_args(): import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input_file", type=str, default=None, - help="A file contains the starting mating pool.") - parser.add_argument("--objective", type=str, default="qed", - help="Objective function to optimize") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--num_population", type=int, default=100, - help="Number of parents sets to keep.") - parser.add_argument("--num_offspring", type=int, default=300, - help="Number of offsprings to generate each iteration.") - parser.add_argument("--num_gen", type=int, default=30, - help="Number of generations to proceed.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--mut_probability", type=float, default=0.5, - help="Probability to mutate for one offspring.") - parser.add_argument("--num_mut_per_ele", type=int, default=1, - help="Number of bits to mutate in one fingerprint.") - parser.add_argument('--restart', action='store_true') - parser.add_argument("--seed", type=int, default=1, - help="Random seed.") - args = parser.parse_args() - - np.random.seed(args.seed) + parser = argparse.ArgumentParser() + # File I/O + parser.add_argument( + "--building-blocks-file", + type=str, + help="Input file with SMILES strings (First row `SMILES`, then one per line).", + ) + parser.add_argument( + "--rxns-collection-file", + type=str, + help="Input file for the collection of reactions matched with building-blocks.", + ) + parser.add_argument( + "--embeddings-knn-file", + type=str, + help="Input file for the pre-computed embeddings (*.npy).", + ) + parser.add_argument( + "--ckpt-dir", type=str, help="Directory with checkpoints for {act,rt1,rxn,rt2}-model." + ) + parser.add_argument( + "--input-file", + type=str, + default=None, + help="A file contains the starting mating pool.", + ) + parser.add_argument( + "--objective", type=str, default="qed", help="Objective function to optimize" + ) + parser.add_argument("--radius", type=int, default=2, help="Radius for Morgan fingerprint.") + parser.add_argument( + "--nbits", type=int, default=4096, help="Number of Bits for Morgan fingerprint." + ) + parser.add_argument( + "--num_population", type=int, default=100, help="Number of parents sets to keep." 
+    )
+    parser.add_argument(
+        "--num_offspring",
+        type=int,
+        default=300,
+        help="Number of offspring to generate each iteration.",
+    )
+    parser.add_argument("--num_gen", type=int, default=30, help="Number of generations to run.")
+    parser.add_argument("--ncpu", type=int, default=MAX_PROCESSES, help="Number of cpus")
+    parser.add_argument(
+        "--mut_probability",
+        type=float,
+        default=0.5,
+        help="Probability to mutate for one offspring.",
+    )
+    parser.add_argument(
+        "--num_mut_per_ele",
+        type=int,
+        default=1,
+        help="Number of bits to mutate in one fingerprint.",
+    )
+    parser.add_argument("--restart", action="store_true")
+    parser.add_argument("--seed", type=int, default=1, help="Random seed.")
+    return parser.parse_args()
+
+
+def fetch_population(args) -> np.ndarray:
     if args.restart:
         population = np.load(args.input_file)
         print(f"Starting with {len(population)} fps from {args.input_file}")
@@ -202,41 +306,79 @@ def mut_probability_scheduler(n, total):
             print(f"Starting with {args.num_population} fps with {args.nbits} bits")
     else:
         starting_smiles = pd.read_csv(args.input_file).sample(args.num_population)
-        starting_smiles = starting_smiles['smiles'].tolist()
-        population = np.array(
-            [mol_fp(smi, args.radius, args.nbits) for smi in starting_smiles]
-        )
+        starting_smiles = starting_smiles["smiles"].tolist()
+        population = np.array([mol_fp(smi, args.radius, args.nbits) for smi in starting_smiles])
         print(f"Starting with {len(starting_smiles)} fps from {args.input_file}")
+    return population
+
+
+if __name__ == "__main__":
+
+    args = get_args()
+    np.random.seed(args.seed)
+    # define some constants (here, for the Hartenfeller-Button test set)
+    nbits = 4096
+    out_dim = 256
+    rxn_template = "hb"
+    featurize = "fp"
+    param_dir = "hb_fp_2_4096_256"
+
+    # Load data
+    mol_embedder = _fetch_molembedder(featurize)
+
+    # load the purchasable building block embeddings
+    bblocks_molembedder = (
+        MolEmbedder().load_precomputed(args.embeddings_knn_file).init_balltree(cosine_distance)
+    )
+    bb_emb = bblocks_molembedder.get_embeddings()
+
+    # load the purchasable building block SMILES to a dictionary
+    building_blocks = BuildingBlockFileHandler().load(args.building_blocks_file)
+    # A dict is used as lookup table for 2nd reactant during inference:
+    bb_dict = {block: i for i, block in enumerate(building_blocks)}
+
+    # load the reaction templates as a ReactionSet object
+    rxns = ReactionSet().load(args.rxns_collection_file).rxns
+
+    # load the pre-trained modules
+    path = Path(args.ckpt_dir)
+    ckpt_files = [find_best_model_ckpt(path / model) for model in "act rt1 rxn rt2".split()]
+    act_net, rt1_net, rxn_net, rt2_net = [load_mlp_from_ckpt(file) for file in ckpt_files]
+
+    # Get initial population
+    population = fetch_population(args)
+
+    # Evaluate the initial population
     with mp.Pool(processes=args.ncpu) as pool:
-        scores, mols, trees = fitness(embs=population,
-                                      _pool=pool,
-                                      obj=args.objective)
-    scores = np.array(scores)
-    score_x = np.argsort(scores)
+        scores, mols, trees = fitness(embs=population, _pool=pool, obj=args.objective)
+
+    scores = np.array(scores)
+    score_x = np.argsort(scores)
     population = population[score_x[::-1]]
-    mols = [mols[i] for i in score_x[::-1]]
-    scores = scores[score_x[::-1]]
+    mols = [mols[i] for i in score_x[::-1]]
+    scores = scores[score_x[::-1]]
     print(f"Initial: {scores.mean():.3f} +/- {scores.std():.3f}")
     print(f"Scores: {scores}")
     print(f"Top-3 Smiles: {mols[:3]}")
 
+    # Genetic Algorithm: loop over generations
     recent_scores = []
-
     for n in range(args.num_gen):
-        t = 
time.time() - dist_ = distribution_schedule(n, args.num_gen) + dist_ = distribution_schedule(n, args.num_gen) num_mut_per_ele_ = num_mut_per_ele_scheduler(n, args.num_gen) mut_probability_ = mut_probability_scheduler(n, args.num_gen) - offspring = crossover(parents=population, - offspring_size=args.num_offspring, - distribution=dist_) - offspring = mutation(offspring_crossover=offspring, - num_mut_per_ele=num_mut_per_ele_, - mut_probability=mut_probability_) + offspring = crossover( + parents=population, offspring_size=args.num_offspring, distribution=dist_ + ) + offspring = mutation( + offspring_crossover=offspring, + num_mut_per_ele=num_mut_per_ele_, + mut_probability=mut_probability_, + ) new_population = np.unique(np.concatenate([population, offspring], axis=0), axis=0) with mp.Pool(processes=args.ncpu) as pool: new_scores, new_mols, trees = fitness(new_population, pool, args.objective) @@ -272,28 +414,33 @@ def mut_probability_scheduler(n, total): if len(recent_scores) > 10: del recent_scores[0] - np.save('population_' + args.objective + '_' + str(n+1) + '.npy', population) - - data = {'objective': args.objective, - 'top1' : np.mean(scores[:1]), - 'top10' : np.mean(scores[:10]), - 'top100' : np.mean(scores[:100]), - 'smiles' : mols, - 'scores' : scores.tolist()} - with open('opt_' + args.objective + '.json', 'w') as f: + np.save("population_" + args.objective + "_" + str(n + 1) + ".npy", population) + + data = { + "objective": args.objective, + "top1": np.mean(scores[:1]), + "top10": np.mean(scores[:10]), + "top100": np.mean(scores[:100]), + "smiles": mols, + "scores": scores.tolist(), + } + with open("opt_" + args.objective + ".json", "w") as f: json.dump(data, f) if n > 30 and recent_scores[-1] - recent_scores[0] < 0.01: print("Early Stop!") break - data = {'objective': args.objective, - 'top1' : np.mean(scores[:1]), - 'top10' : np.mean(scores[:10]), - 'top100' : np.mean(scores[:100]), - 'smiles' : mols, - 'scores' : scores.tolist()} - with open('opt_' + args.objective + '.json', 'w') as f: + # Save results + data = { + "objective": args.objective, + "top1": np.mean(scores[:1]), + "top10": np.mean(scores[:10]), + "top100": np.mean(scores[:100]), + "smiles": mols, + "scores": scores.tolist(), + } + with open("opt_" + args.objective + ".json", "w") as f: json.dump(data, f) - np.save('population_' + args.objective + '.npy', population) + np.save("population_" + args.objective + ".npy", population) diff --git a/scripts/predict-beam-fullTree.py b/scripts/predict-beam-fullTree.py deleted file mode 100644 index 7bf001c8..00000000 --- a/scripts/predict-beam-fullTree.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -This file contains the code to decode synthetic trees using beam search at every -sampling step after the action network (i.e. reactant 1, reaction, and reactant 2 -sampling). 
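The decoding scripts that follow all score a decoded root molecule against its query with an RDKit fingerprint similarity; the check, in isolation:

```python
from rdkit import Chem, DataStructs

query, decoded = "CCO", "CCN"  # toy pair
ms = [Chem.MolFromSmiles(smi) for smi in (query, decoded)]
fps = [Chem.RDKFingerprint(m) for m in ms]
print(DataStructs.FingerprintSimilarity(fps[0], fps[1]))
```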
-""" -import os -import pandas as pd -import numpy as np -from tqdm import tqdm -from rdkit import Chem -from rdkit import DataStructs - -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet - -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import mol_fp, get_mol_embedding -from syn_net.utils.predict_beam_utils import synthetic_tree_decoder_fullbeam, load_modules_from_checkpoint - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=1024, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--beam_width", type=int, default=5, - help="Beam width to use for Reactant1 search") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to decode.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # define model to use for molecular embedding - readout = AvgPooling() - model_type = 'gin_supervised_contextpred' - device = 'cuda:0' - mol_embedder = load_pretrained(model_type).to(device) - mol_embedder.eval() - - # load the purchasable building block embeddings - bb_emb = np.load(f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/enamine_us_emb.npy') - - # define path to the reaction templates and purchasable building blocks - path_to_reaction_file = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/reactions_{args.rxn_template}.json.gz' - path_to_building_blocks = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/enamine_us_matched.csv.gz' - - # define paths to pretrained modules - param_path = f'/home/rociomer/SynthNet/pre-trained-models/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/' - path_to_act = f'{param_path}act.ckpt' - path_to_rt1 = f'{param_path}rt1.ckpt' - path_to_rxn = f'{param_path}rxn.ckpt' - path_to_rt2 = f'{param_path}rt2.ckpt' - - np.random.seed(6) - - # load the purchasable building block SMILES to a dictionary - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - - # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) - rxns = rxn_set.rxns - - # load the pre-trained modules - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=args.out_dim, - nbits=args.nbits, - ncpu=args.ncpu, - ) - - def decode_one_molecule(query_smi): - """ - Generate a synthetic tree from a given query SMILES. 
- - Args: - query_smi (str): SMILES for molecule to decode. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - if args.featurize == 'fp': - z_target = mol_fp(query_smi, nBits=args.nbits) - elif args.featurize == 'gin': - z_target = get_mol_embedding(query_smi) - tree, action = synthetic_tree_decoder_fullbeam(z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - beam_width=args.beam_width, - rxn_template=args.rxn_template, - n_bits=args.nbits, - max_step=15) - return tree, action - - # load the purchasable building blocks to decode - path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - query_smis = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - query_smis = query_smis[:args.num] - - output_smis = [] - similaritys = [] - trees = [] - num_finish = 0 - num_unfinish = 0 - - print('Start to decode!') - for smi in tqdm(query_smis): - - try: - tree, action = decode_one_molecule(smi) - except Exception as e: - print(e) - action = 1 - tree = None - - if action != 3: - num_unfinish += 1 - output_smis.append(None) - similaritys.append(None) - trees.append(None) - else: - num_finish += 1 - output_smis.append(tree.root.smiles) - ms = [Chem.MolFromSmiles(sm) for sm in [smi, tree.root.smiles]] - fps = [Chem.RDKFingerprint(x) for x in ms] - similaritys.append(DataStructs.FingerprintSimilarity(fps[0],fps[1])) - trees.append(tree) - - print('Saving ......') - save_path = '../results/' + args.rxn_template + '_' + args.featurize + '/' - if not os.path.exists(save_path): - os.makedirs(save_path) - df = pd.DataFrame({'query SMILES': query_smis, 'decode SMILES': output_smis, 'similarity': similaritys}) - print("mean similarities", df['similarity'].mean(), df['similarity'].std()) - print("NAs", df.isna().sum()) - df.to_csv(f'{save_path}decode_result_{args.data}_bw_{args.beam_width}.csv.gz', - compression='gzip', - index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}decoded_st_bw_{args.beam_width}_{args.data}.json.gz') - - print('Finish!') diff --git a/scripts/predict-beam-reactantOnly.py b/scripts/predict-beam-reactantOnly.py deleted file mode 100644 index daf9b3c0..00000000 --- a/scripts/predict-beam-reactantOnly.py +++ /dev/null @@ -1,183 +0,0 @@ -""" -This file contains the code to decode synthetic trees using beam search at the -first reactant sampling step (after the action network). 
-""" -import os -import pandas as pd -import numpy as np -from tqdm import tqdm -from rdkit import Chem -from rdkit import DataStructs -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from syn_net.models.mlp import MLP -from syn_net.utils.predict_utils import mol_fp, get_mol_embedding -from syn_net.utils.predict_beam_utils import synthetic_tree_decoder, load_modules_from_checkpoint - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=1024, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--beam_width", type=int, default=5, - help="Beam width to use for Reactant1 search") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to decode.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # define model to use for molecular embedding - readout = AvgPooling() - model_type = 'gin_supervised_contextpred' - device = 'cuda:0' - mol_embedder = load_pretrained(model_type).to(device) - mol_embedder.eval() - - # load the purchasable building block embeddings - bb_emb = np.load('/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/enamine_us_emb.npy') - - # define path to the reaction templates and purchasable building blocks - path_to_reaction_file = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/reactions_' + args.rxn_template + '.json.gz') - path_to_building_blocks = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/enamine_us_matched.csv.gz') - - # define paths to pretrained modules - param_path = (f"/home/rociomer/SynthNet/pre-trained-models/{args.rxn_template}" - f"_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/") - path_to_act = param_path + 'act.ckpt' - path_to_rt1 = param_path + 'rt1.ckpt' - path_to_rxn = param_path + 'rxn.ckpt' - path_to_rt2 = param_path + 'rt2.ckpt' - - np.random.seed(6) - - # load the purchasable building block SMILES to a dictionary - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - - # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) - rxns = rxn_set.rxns - - # load the pre-trained modules - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=args.out_dim, - nbits=args.nbits, - ncpu=args.ncpu, - ) - - def decode_one_molecule(query_smi): - """ - Generate a synthetic tree from a given query 
SMILES. - - Args: - query_smi (str): SMILES for molecule to decode. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - if args.featurize == 'fp': - z_target = mol_fp(query_smi, args.radius, args.nbits) - elif args.featurize == 'gin': - z_target = get_mol_embedding(query_smi) - tree, action = synthetic_tree_decoder(z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=rxns, - mol_embedder=mol_embedder, - action_net=act_net, - reactant1_net=rt1_net, - rxn_net=rxn_net, - reactant2_net=rt2_net, - bb_emb=bb_emb, - beam_width=args.beam_width, - rxn_template=args.rxn_template, - n_bits=args.nbits, - max_step=15) - return tree, action - - path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - query_smis = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - query_smis = query_smis[:args.num] - - output_smis = [] - similaritys = [] - trees = [] - num_finish = 0 - num_unfinish = 0 - - print('Start to decode!') - for smi in tqdm(query_smis): - - try: - tree, action = decode_one_molecule(smi) - except Exception as e: - print(e) - action = 1 - tree = None - - if action != 3: - num_unfinish += 1 - output_smis.append(None) - similaritys.append(None) - trees.append(None) - else: - num_finish += 1 - output_smis.append(tree.root.smiles) - ms = [Chem.MolFromSmiles(sm) for sm in [smi, tree.root.smiles]] - fps = [Chem.RDKFingerprint(x) for x in ms] - similaritys.append(DataStructs.FingerprintSimilarity(fps[0],fps[1])) - trees.append(tree) - - print('Saving ......') - save_path = f'../results/{args.rxn_template}_{args.featurize}/' - if not os.path.exists(save_path): - os.makedirs(save_path) - df = pd.DataFrame( - {'query SMILES' : query_smis, - 'decode SMILES': output_smis, - 'similarity' : similaritys} - ) - print("mean similarities", df['similarity'].mean(), df['similarity'].std()) - print("NAs", df.isna().sum()) - df.to_csv(f'{save_path}decode_result_{args.data}_robw_{str(args.beam_width)}.csv.gz', - compression='gzip', - index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}decoded_st_robw_{str(args.beam_width)}_{args.data}.json.gz') - - print('Finish!') diff --git a/scripts/predict.py b/scripts/predict.py deleted file mode 100644 index 2c40c83e..00000000 --- a/scripts/predict.py +++ /dev/null @@ -1,174 +0,0 @@ -""" -This file contains the code to decode synthetic trees using a greedy search at -every sampling step. 
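A similarity of exactly 1.0 counts as a recovered target molecule; the aggregate metrics these prediction scripts print reduce to:

```python
import numpy as np

similarities = np.array([1.0, 0.42, 1.0, 0.77])  # toy values
print("Recovery rate:", np.sum(similarities == 1.0) / len(similarities))
print("Average similarity:", similarities.mean())
```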
-""" -import os -import pandas as pd -import numpy as np -from tqdm import tqdm -from rdkit import Chem -from rdkit import DataStructs -from syn_net.utils.data_utils import ReactionSet, SyntheticTreeSet -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from syn_net.utils.predict_utils import mol_fp, get_mol_embedding -from syn_net.utils.predict_utils import synthetic_tree_decoder, load_modules_from_checkpoint - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("-v", "--version", type=int, default=0, - help="Version") - parser.add_argument("--param", type=str, default='hb_fp_2_4096', - help="Name of directory with parameters in it.") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to decode.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # define model to use for molecular embedding - readout = AvgPooling() - model_type = 'gin_supervised_contextpred' - device = 'cuda:0' - mol_embedder = load_pretrained(model_type).to(device) - mol_embedder.eval() - - # load the purchasable building block embeddings - bb_emb = np.load('/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/enamine_us_emb.npy') - - # define path to the reaction templates and purchasable building blocks - path_to_reaction_file = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/reactions_' + args.rxn_template + '.json.gz') - path_to_building_blocks = ('/pool001/whgao/data/synth_net/st_' + args.rxn_template - + '/enamine_us_matched.csv.gz') - - # define paths to pretrained modules - param_path = '/home/whgao/scGen/synth_net/synth_net/params/' + args.param + '/' - path_to_act = param_path + 'act.ckpt' - path_to_rt1 = param_path + 'rt1.ckpt' - path_to_rxn = param_path + 'rxn.ckpt' - path_to_rt2 = param_path + 'rt2.ckpt' - - np.random.seed(6) - - # load the purchasable building block SMILES to a dictionary - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))} - - # load the reaction templates as a ReactionSet object - rxn_set = ReactionSet() - rxn_set.load(path_to_reaction_file) - rxns = rxn_set.rxns - - # load the pre-trained modules - act_net, rt1_net, rxn_net, rt2_net = load_modules_from_checkpoint( - path_to_act=path_to_act, - path_to_rt1=path_to_rt1, - path_to_rxn=path_to_rxn, - path_to_rt2=path_to_rt2, - featurize=args.featurize, - rxn_template=args.rxn_template, - out_dim=args.out_dim, - nbits=args.nbits, - ncpu=args.ncpu, - ) - - def decode_one_molecule(query_smi): - """ - Generate a synthetic tree from a given query SMILES. - - Args: - query_smi (str): SMILES for molecule to decode. 
- - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - if args.featurize == 'fp': - z_target = mol_fp(query_smi, args.radius, args.nbits) - elif args.featurize == 'gin': - z_target = get_mol_embedding(query_smi) - tree, action = synthetic_tree_decoder(z_target, - building_blocks, - bb_dict, - rxns, - mol_embedder, - act_net, - rt1_net, - rxn_net, - rt2_net, - bb_emb=bb_emb, - rxn_template=args.rxn_template, - n_bits=args.nbits, - max_step=15) - return tree, action - - - path_to_data = '/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/st_' + args.data +'.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - query_smis = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - query_smis = query_smis[:args.num] - - output_smis = [] - similaritys = [] - trees = [] - num_finish = 0 - num_unfinish = 0 - - print('Start to decode!') - for smi in tqdm(query_smis): - - try: - tree, action = decode_one_molecule(smi) - except Exception as e: - print(e) - action = -1 - tree = None - - if action != 3: - num_unfinish += 1 - output_smis.append(None) - similaritys.append(None) - trees.append(None) - else: - num_finish += 1 - output_smis.append(tree.root.smiles) - ms = [Chem.MolFromSmiles(sm) for sm in [smi, tree.root.smiles]] - fps = [Chem.RDKFingerprint(x) for x in ms] - similaritys.append(DataStructs.FingerprintSimilarity(fps[0],fps[1])) - trees.append(tree) - - print('Saving ......') - save_path = '../results/' + args.rxn_template + '_' + args.featurize + '/' - if not os.path.exists(save_path): - os.makedirs(save_path) - df = pd.DataFrame({'query SMILES': query_smis, 'decode SMILES': output_smis, 'similarity': similaritys}) - print("mean similarities", df['similarity'].mean(), df['similarity'].std()) - print("NAs", df.isna().sum()) - df.to_csv(save_path + 'decode_result_' + args.data + '.csv.gz', compression='gzip', index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(save_path + 'decoded_st_' + args.data + '.json.gz') - - print('Finish!') diff --git a/scripts/predict_mp.py b/scripts/predict_mp.py deleted file mode 100644 index fe521d48..00000000 --- a/scripts/predict_mp.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Generate synthetic trees for a set of specified query molecules. Multiprocessing. -""" -import multiprocessing as mp -import numpy as np -import pandas as pd -import scripts._mp_predict as predict -from syn_net.utils.data_utils import SyntheticTreeSet - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to predict.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test']") - args = parser.parse_args() - - # load the query molecules (i.e. 
molecules to decode) - path_to_data = '/pool001/whgao/data/synth_net/st_' + args.rxn_template + '/st_' + args.data +'.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - smis_query = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - smis_query = smis_query[:args.num] - - print('Start to decode!') - with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(predict.func, smis_query) - - smis_decoded = [r[0] for r in results] - similaritys = [r[1] for r in results] - trees = [r[2] for r in results] - - print("Finish decoding") - print(f"Recovery rate {args.data}: {np.sum(np.array(similaritys) == 1.0) / len(similaritys)}") - print(f"Average similarity {args.data}: {np.mean(np.array(similaritys))}") - - print('Saving ......') - save_path = '../results/' + args.rxn_template + '_' + args.featurize + '/' - df = pd.DataFrame({'query SMILES': smis_query, 'decode SMILES': smis_decoded, 'similarity': similaritys}) - df.to_csv(save_path + 'decode_result_' + args.data + '.csv.gz', compression='gzip', index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(save_path + 'decoded_st_' + args.data + '.json.gz') - - print('Finish!') - - diff --git a/scripts/predict_multireactant_mp.py b/scripts/predict_multireactant_mp.py deleted file mode 100644 index f8ee0da9..00000000 --- a/scripts/predict_multireactant_mp.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Generate synthetic trees for a set of specified query molecules. Multiprocessing. -""" -import multiprocessing as mp -import numpy as np -import pandas as pd -import _mp_predict_multireactant as predict -from syn_net.utils.data_utils import SyntheticTreeSet - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--ncpu", type=int, default=16, - help="Number of cpus") - parser.add_argument("-n", "--num", type=int, default=-1, - help="Number of molecules to predict.") - parser.add_argument("-d", "--data", type=str, default='test', - help="Choose from ['train', 'valid', 'test', 'chembl']") - args = parser.parse_args() - - # load the query molecules (i.e. 
molecules to decode) - if args.data != 'chembl': - path_to_data = f'/pool001/whgao/data/synth_net/st_{args.rxn_template}/st_{args.data}.json.gz' - print('Reading data from ', path_to_data) - sts = SyntheticTreeSet() - sts.load(path_to_data) - smis_query = [st.root.smiles for st in sts.sts] - if args.num == -1: - pass - else: - smis_query = smis_query[:args.num] - else: - df = pd.read_csv('/home/whgao/synth_net/chembl_20k.csv') - smis_query = df.smiles.to_list() - - print('Start to decode!') - with mp.Pool(processes=args.ncpu) as pool: - results = pool.map(predict.func, smis_query) - - smis_decoded = [r[0] for r in results] - similarities = [r[1] for r in results] - trees = [r[2] for r in results] - - print('Finish decoding') - print(f'Recovery rate {args.data}: {np.sum(np.array(similarities) == 1.0) / len(similarities)}') - print(f'Average similarity {args.data}: {np.mean(np.array(similarities))}') - - print('Saving ......') - save_path = '../results/' - df = pd.DataFrame({'query SMILES' : smis_query, - 'decode SMILES': smis_decoded, - 'similarity' : similarities}) - df.to_csv(f'{save_path}decode_result_{args.data}.csv.gz', - compression='gzip', - index=False) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save(f'{save_path}decoded_st_{args.data}.json.gz') - - print('Finish!') diff --git a/scripts/read_st_data.py b/scripts/read_st_data.py deleted file mode 100644 index 95807fbd..00000000 --- a/scripts/read_st_data.py +++ /dev/null @@ -1,20 +0,0 @@ -""" -Reads synthetic tree data and prints the first five trees. -""" -from syn_net.utils.data_utils import * - - -if __name__ == "__main__": - - st_set = SyntheticTreeSet() - path_to_data = '/pool001/whgao/data/synth_net/st_pis/st_data.json.gz' - - print('Reading data from ', path_to_data) - st_set.load(path_to_data) - data = st_set.sts - - for t in data[:5]: - t._print() - - print(len(data)) - print("Finish!") diff --git a/scripts/sample_from_original.py b/scripts/sample_from_original.py deleted file mode 100644 index 01f79e78..00000000 --- a/scripts/sample_from_original.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -Filters the synthetic trees by the QEDs of the root molecules. -""" -from tdc import Oracle -qed = Oracle(name='qed') -import numpy as np -import pandas as pd -from syn_net.utils.data_utils import * - -def is_valid(smi): - """ - Checks if a SMILES string is valid. - - Args: - smi (str): Molecular SMILES string. - - Returns: - False or str: False if the SMILES is not valid, or the reconverted - SMILES string. 
- """ - mol = Chem.MolFromSmiles(smi) - if mol is None: - return False - else: - return Chem.MolToSmiles(mol, isomericSmiles=False) - -if __name__ == '__main__': - - data_path = '/pool001/whgao/data/synth_net/st_pis/st_data.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - print(f'Finish reading, in total {len(data)} synthetic trees.') - - filtered_data = [] - original_qed = [] - qeds = [] - generated_smiles = [] - - threshold = 0.5 - - for t in tqdm(data): - try: - valid_smiles = is_valid(t.root.smiles) - if valid_smiles: - if valid_smiles in generated_smiles: - pass - else: - qed_value = qed(valid_smiles) - original_qed.append(qed_value) - - # filter the trees based on their QEDs - if qed_value > threshold or np.random.random() < (qed_value/threshold): - generated_smiles.append(valid_smiles) - filtered_data.append(t) - qeds.append(qed_value) - else: - pass - else: - pass - except: - pass - - print(f'Finish sampling, remaining {len(filtered_data)} synthetic trees.') - - st_set = SyntheticTreeSet(filtered_data) - st_set.save('/pool001/whgao/data/synth_net/st_pis/st_data_filtered.json.gz') - - df = pd.DataFrame({'SMILES': generated_smiles, 'qed': qeds}) - df.to_csv('/pool001/whgao/data/synth_net/st_pis/filtered_smiles.csv.gz', compression='gzip', index=False) - - print('Finish!') diff --git a/scripts/search_similar.py b/scripts/search_similar.py deleted file mode 100644 index ec0f2c02..00000000 --- a/scripts/search_similar.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Computes the fingerprint similarity of molecules in the validation and test set to -molecules in the training set. -""" -import numpy as np -import pandas as pd -from syn_net.utils.data_utils import * -from rdkit import Chem -from rdkit.Chem import AllChem -import multiprocessing as mp -from scripts._mp_search_similar import func - - -if __name__ == '__main__': - - ncpu = 64 - - data_path = '/pool001/whgao/data/synth_net/st_hb/st_train.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - data_train = [t.root.smiles for t in data] - - data_path = '/pool001/whgao/data/synth_net/st_hb/st_test.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - data_test = [t.root.smiles for t in data] - - data_path = '/pool001/whgao/data/synth_net/st_hb/st_valid.json.gz' - st_set = SyntheticTreeSet() - st_set.load(data_path) - data = st_set.sts - data_valid = [t.root.smiles for t in data] - - fps_valid = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_valid] - fps_test = [AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smi), 2, nBits=1024) for smi in data_test] - - with mp.Pool(processes=ncpu) as pool: - results = pool.map(func, fps_valid) - similaritys = [r[0] for r in results] - indices = [data_train[r[1]] for r in results] - df1 = pd.DataFrame({'smiles': data_valid, 'split': 'valid', 'most similar': indices, 'similarity': similaritys}) - - with mp.Pool(processes=ncpu) as pool: - results = pool.map(func, fps_test) - similaritys = [r[0] for r in results] - indices = [data_train[r[1]] for r in results] - df2 = pd.DataFrame({'smiles': data_test, 'split': 'test', 'most similar': indices, 'similarity': similaritys}) - - df = pd.concat([df1, df2], axis=0, ignore_index=True) - df.to_csv('data_similarity.csv', index=False) - print('Finish!') diff --git a/scripts/sketch-synthetic-trees.py b/scripts/sketch-synthetic-trees.py deleted file mode 100644 index 138aaaaa..00000000 --- 
a/scripts/sketch-synthetic-trees.py +++ /dev/null @@ -1,243 +0,0 @@ -""" -Sketches the synthetic trees in a specified file. -""" -from syn_net.utils.data_utils import * -import argparse -from typing import Tuple -from rdkit.Chem import MolFromSmiles -from rdkit.Chem.Draw import MolToImage -import networkx as nx -import matplotlib.pyplot as plt -from matplotlib.patches import Rectangle - - -# define some color maps for plotting -edges_cmap = { - 0 : "tab:brown", # Add - 1 : "tab:pink", # Expand - 2 : "tab:gray", # Merge - #3 : "tab:olive", # End # not currently plotting -} -nodes_cmap = { - 0 : "tab:blue", # most recent mol - 1 : "tab:orange", # other root mol - 2 : "tab:green", # product -} - - -def get_states_and_steps(synthetic_tree : "SyntheticTree") -> Tuple[list, list]: - """ - Gets the different nodes of the input synthetic tree, and the "action type" - that was used to get to those nodes. - - Args: - synthetic_tree (SyntheticTree): - - Returns: - Tuple[list, list]: Contains lists of the states and steps (actions) from - the Synthetic Tree. - """ - states = [] - steps = [] - - target = synthetic_tree.root.smiles - most_recent_mol = None - other_root_mol = None - - for i, action in enumerate(st.actions): - - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - if action != 3: - r = synthetic_tree.reactions[i] - mol1 = r.child[0] - if len(r.child) == 2: - mol2 = r.child[1] - else: - mol2 = None - state = [mol1, mol2, r.parent] - else: - state = [most_recent_mol, other_root_mol, target] - - if action == 2: - most_recent_mol = r.parent - other_root_mol = None - - elif action == 1: - most_recent_mol = r.parent - - elif action == 0: - other_root_mol = most_recent_mol - most_recent_mol = r.parent - - states.append(state) - steps.append(action) - - return states, steps - -def draw_tree(states : list, steps : list, tree_name : str) -> None: - """ - Draws the synthetic tree based on the input list of states (reactant/product - nodes) and steps (actions). - - Args: - states (list): Molecular nodes (i.e. reactants and products). - steps (list): Action types (e.g. "Add" and "Merge"). - tree_name (str): Name of tree to use for file saving purposes. 
- """ - G = nx.Graph() - pos_dict = {} # sets the position of the nodes, for plotting below - edge_color_dict = {} # sets the color of the edges based on the action - node_color_dict = {} # sets the color of the box around the node during plotting - - node_idx =0 - prev_target_idx = None - merge_correction = 0.0 - for state_idx, state in enumerate(states): - - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - step = steps[state_idx] - if step == 3: - break - - skip_mrm = False - skip_orm = False - for smiles_idx, smiles in enumerate(state): - - if smiles is None and smiles_idx == 0: - skip_mrm = True # mrm == 'most recent mol' - continue - elif smiles is None and smiles_idx == 1: - skip_orm = True # orm == 'other root molecule' - continue - elif smiles is None and smiles_idx == 2: - continue - elif step == 1 and smiles_idx == 0: - merge_correction -= 0.5 - skip_mrm = True # mrm == 'most recent mol' - continue - - # draw the molecules (creates a PIL image) - img = MolToImage(mol=MolFromSmiles(smiles), fitImage=False) - G.add_node(str(node_idx), image=img) - node_color_dict[str(node_idx)] = nodes_cmap[smiles_idx] - if smiles_idx != 2: - pos_dict[str(node_idx)] = [state_idx + merge_correction, smiles_idx/2 + 0.01] - else: - pos_dict[str(node_idx)] = [state_idx + 0.5 + merge_correction, 0.01] # 0.01 important to not plot edge under axis label, even if later axis label is turned off (weird behavior) - if smiles_idx == 2: - if not skip_mrm: - G.add_edge(str(node_idx - 2 + int(skip_orm)), str(node_idx)) # connect most recent mol to target - edge_color_dict[(str(node_idx - 2 + int(skip_orm)), str(node_idx))] = edges_cmap[step] - if not skip_orm: - G.add_edge(str(node_idx - 1), str(node_idx)) # connect other root mol to target - edge_color_dict[(str(node_idx - 1), str(node_idx))] = edges_cmap[step] - node_idx += 1 - - if prev_target_idx and not step == 1: - mrm_idx = node_idx - 3 + int(skip_orm) - G.add_edge(str(prev_target_idx), str(mrm_idx)) # connect the previous target to the current most recent mol - edge_color_dict[(str(prev_target_idx), str(mrm_idx))] = edges_cmap[step] - elif prev_target_idx and step == 1: - new_target_idx = node_idx - 1 - G.add_edge(str(prev_target_idx), str(new_target_idx)) # connect the previous target to the current most recent mol - edge_color_dict[(str(prev_target_idx), str(new_target_idx))] = edges_cmap[step] - prev_target_idx = node_idx - 1 - - # sketch the tree - fig, ax = plt.subplots() - - nx.draw_networkx_edges( - G, - pos=pos_dict, - ax=ax, - arrows=True, - edgelist=[edge for edge in G.edges], - edge_color=[edge_color_dict[edge] for edge in G.edges], - arrowstyle="-", # suppresses arrowheads - width=2.0, - alpha=0.9, - min_source_margin=15, - min_target_margin=15, - ) - - # Transform from data coordinates (scaled between xlim and ylim) to display coordinates - tr_figure = ax.transData.transform - # Transform from display to figure coordinates - tr_axes = fig.transFigure.inverted().transform - - # Select the size of the image (relative to the X axis) - x = 0 - for positions in pos_dict.values(): - if positions[0] > x: - x = positions[0] - - _, _ = ax.set_xlim(0, x) - _, _ = ax.set_ylim(0, 0.6) - icon_size = 0.2 - icon_center = icon_size / 2.0 - - # add a legend for the edge colors - markers_edges = [plt.Line2D([0,0],[0,0],color=color, linewidth=4, marker='_', linestyle='') for color in edges_cmap.values()] - markers_nodes = [plt.Line2D([0,0],[0,0],color=color, linewidth=2, marker='s', linestyle='') for color in nodes_cmap.values()] - markers_labels = ["Add", 
"Reactant 1", "Expand", "Reactant 2", "Merge", "Product"] - markers =[markers_edges[0], markers_nodes[0], markers_edges[1], markers_nodes[1], markers_edges[2], markers_nodes[2]] - plt.legend(markers, markers_labels, loc='upper center', - bbox_to_anchor=(0.5, 1.15), ncol=3, fancybox=True, shadow=True) - - # Add the respective image to each node - for n in G.nodes: - xf, yf = tr_figure(pos_dict[n]) - xa, ya = tr_axes((xf, yf)) - # get overlapped axes and plot icon - a = plt.axes([xa - icon_center, ya - icon_center, icon_size, icon_size]) - a.imshow(G.nodes[n]["image"]) - # add colored boxes around each node: - plt.gca().add_patch(Rectangle((0,0),295,295, linewidth=2, edgecolor=node_color_dict[n], facecolor="none")) - a.axis("off") - - ax.axis("off") - - # save the figure - plt.savefig(f"{tree_name}.png", dpi=100) - print(f"-- Tree saved in {tree_name}.png", flush=True) - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument("--file", type=str, default='/pool001/rociomer/test-data/synth_net/st_hb_test-plot-tests.json.gz', - help="Path/filename to the synthetic trees.") - parser.add_argument("--saveto", type=str, default='/pool001/rociomer/test-data/synth_net/images/', - help="Path to save the sketched synthetic trees.") - parser.add_argument("--nsketches", type=int, default=-1, - help="How many trees to sketch. Default -1 means to sketch all.") - parser.add_argument("--actions", type=int, default=-1, - help="How many actions the tree must have in order to sketch it (useful for testing).") - args = parser.parse_args() - - st_set = SyntheticTreeSet() - st_set.load(args.file) - data = st_set.sts - - trees_sketched = 0 - for st_idx, st in enumerate(data): - if len(st.actions) <= args.actions: - # don't sketch trees with fewer than n = `args.actions` actions - continue - try: - print("* Getting states and steps...") - states, steps = get_states_and_steps(synthetic_tree=st) - - print("* Sketching tree...") - draw_tree(states=states, steps=steps, tree_name=f"{args.saveto}tree{st_idx}") - - trees_sketched += 1 - - except Exception as e: - print(e) - continue - - if not (args.nsketches == -1) and trees_sketched > args.nsketches: - break - - print("Done!") diff --git a/scripts/st2steps.py b/scripts/st2steps.py deleted file mode 100644 index 892ae30e..00000000 --- a/scripts/st2steps.py +++ /dev/null @@ -1,86 +0,0 @@ -""" -Splits a synthetic tree into states and steps. 
-""" -import os -from tqdm import tqdm -from scipy import sparse -from syn_net.utils.data_utils import * -from syn_net.utils.prep_utils import organize - - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-n", "--numbersave", type=int, default=999999999999, - help="Save number") - parser.add_argument("-v", "--verbose", action="store_true", default=False, - help="Increase output verbosity") - parser.add_argument("-e", "--targetembedding", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-o", "--outputembedding", type=str, default='gin', - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") - parser.add_argument("-r", "--radius", type=int, default=2, - help="Radius for Morgan Fingerprint") - parser.add_argument("-b", "--nbits", type=int, default=4096, - help="Number of Bits for Morgan Fingerprint") - parser.add_argument("-d", "--datasettype", type=str, default='train', - help="Choose from ['train', 'valid', 'test']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - args = parser.parse_args() - - dataset_type = args.datasettype - embedding = args.targetembedding - path_st = '/pool001/whgao/data/synth_net/st_hb/st_' + dataset_type + '.json.gz' - save_dir = '/pool001/whgao/data/synth_net/hb_' + embedding + '_' + str(args.radius) + '_' + str(args.nbits) + '_' + str(args.outputembedding) + '/' - - st_set = SyntheticTreeSet() - st_set.load(path_st) - print('Original length: ', len(st_set.sts)) - data = st_set.sts - del st_set - print('Working length: ', len(data)) - - states = [] - steps = [] - - num_save = args.numbersave - idx = 0 - save_idx = 0 - for st in tqdm(data): - try: - state, step = organize(st, target_embedding=embedding, radius=args.radius, nBits=args.nbits, output_embedding=args.outputembedding) - except Exception as e: - print(e) - continue - states.append(state) - steps.append(step) - idx += 1 - if idx % num_save == 0: - print('Saving......') - states = sparse.vstack(states) - steps = sparse.vstack(steps) - sparse.save_npz(save_dir + 'states_' + str(save_idx) + '_' + dataset_type + '.npz', states) - sparse.save_npz(save_dir + 'steps_' + str(save_idx) + '_' + dataset_type + '.npz', steps) - save_idx += 1 - del states - del steps - states = [] - steps = [] - - del data - - if len(steps) != 0: - states = sparse.vstack(states) - steps = sparse.vstack(steps) - - print('Saving......') - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - sparse.save_npz(save_dir + 'states_' + str(save_idx) + '_' + dataset_type + '.npz', states) - sparse.save_npz(save_dir + 'steps_' + str(save_idx) + '_' + dataset_type + '.npz', steps) - - print('Finish!') diff --git a/scripts/st_split.py b/scripts/st_split.py deleted file mode 100644 index 60497304..00000000 --- a/scripts/st_split.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Reads synthetic tree data and splits it into training, validation and testing sets. 
-""" -from syn_net.utils.data_utils import * - - -if __name__ == "__main__": - - st_set = SyntheticTreeSet() - path_to_data = '/pool001/whgao/data/synth_net/st_pis/st_data_filtered.json.gz' - print('Reading data from ', path_to_data) - st_set.load(path_to_data) - data = st_set.sts - del st_set - num_total = len(data) - print("In total we have: ", num_total, "paths.") - - split_ratio = [0.6, 0.2, 0.2] - - num_train = int(split_ratio[0] * num_total) - num_valid = int(split_ratio[1] * num_total) - num_test = num_total - num_train - num_valid - - data_train = data[:num_train] - data_valid = data[num_train: num_train + num_valid] - data_test = data[num_train + num_valid: ] - - print("Saving training dataset: ", len(data_train)) - tree_set = SyntheticTreeSet(data_train) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_train.json.gz') - - print("Saving validation dataset: ", len(data_valid)) - tree_set = SyntheticTreeSet(data_valid) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_valid.json.gz') - - print("Saving testing dataset: ", len(data_test)) - tree_set = SyntheticTreeSet(data_test) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_test.json.gz') - - print("Finish!") diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..f8480447 --- /dev/null +++ b/setup.py @@ -0,0 +1,20 @@ +import setuptools + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setuptools.setup( + name="synnet", + version="0.1.0", + description="Synthetic tree generation using neural networks.", + long_description=long_description, + long_description_content_type="text/markdown", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + package_dir={"": "src"}, + packages=setuptools.find_packages(where="src",exclude=["src/synnet/encoding/gins.py"]), + python_requires=">=3.9", +) \ No newline at end of file diff --git a/src/synnet/MolEmbedder.py b/src/synnet/MolEmbedder.py new file mode 100644 index 00000000..105962c3 --- /dev/null +++ b/src/synnet/MolEmbedder.py @@ -0,0 +1,87 @@ +import logging +from pathlib import Path +from typing import Callable, Union + +import numpy as np +from sklearn.neighbors import BallTree + +from synnet.config import MAX_PROCESSES + +logger = logging.getLogger(__name__) + + +class MolEmbedder: + def __init__(self, processes: int = MAX_PROCESSES) -> None: + self.processes = processes + self.func: Callable + self.building_blocks: Union[list[str], np.ndarray] + self.embeddings: np.ndarray + self.kdtree: BallTree + self.kdtree_metric: str + + def get_embeddings(self) -> np.ndarray: + """Returns `self.embeddings` as 2d-array.""" + return np.atleast_2d(self.embeddings) + + def _compute_mp(self, data): + from pathos import multiprocessing as mp + + with mp.Pool(processes=self.processes) as pool: + embeddings = pool.map(self.func, data) + return embeddings + + def compute_embeddings(self, func: Callable, building_blocks: list[str]): + logger.info(f"Will compute embedding with {self.processes} processes.") + self.func = func + if self.processes == 1: + embeddings = list(map(self.func, building_blocks)) + else: + embeddings = self._compute_mp(building_blocks) + logger.info(f"Computed embeddings.") + self.embeddings = embeddings + return self + + def _save_npy(self, file: str): + if self.embeddings is None: + raise ValueError("Must have computed embeddings to save.") + + embeddings = np.asarray(self.embeddings) # assume at least 2d + 
np.save(file, embeddings)
+        logger.info(f"Successfully saved data (shape={embeddings.shape}) to {file}.")
+        return self
+
+    def save_precomputed(self, file: str):
+        """Saves pre-computed molecule embeddings to `*.npy`"""
+        file = Path(file)
+        file.parent.mkdir(parents=True, exist_ok=True)
+        if file.suffixes == [".npy"]:
+            self._save_npy(file)
+        else:
+            raise NotImplementedError(f"File must have 'npy' extension, not {file.suffixes}")
+        return self
+
+    def _load_npy(self, file: Path):
+        return np.load(file)
+
+    def load_precomputed(self, file: str):
+        """Loads pre-computed molecule embeddings from `*.npy`"""
+        file = Path(file)
+        if file.suffixes == [".npy"]:
+            self.embeddings = self._load_npy(file)
+            self.kdtree = None
+        else:
+            raise NotImplementedError
+        return self
+
+    def init_balltree(self, metric: Union[Callable, str]):
+        """Initializes a `BallTree`.
+
+        Note:
+            Can take a couple of minutes."""
+        if self.embeddings is None:
+            raise ValueError("Need embeddings to compute the BallTree.")
+        X = self.embeddings
+        self.kdtree_metric = metric.__name__ if not isinstance(metric, str) else metric
+        self.kdtree = BallTree(X, metric=metric)
+
+        return self
diff --git a/src/synnet/__init__.py b/src/synnet/__init__.py
new file mode 100644
index 00000000..2d371afe
--- /dev/null
+++ b/src/synnet/__init__.py
@@ -0,0 +1,11 @@
+import logging
+
+logging.basicConfig(
+    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
+    datefmt="%H:%M:%S",
+    handlers=[logging.StreamHandler()],
+    # handlers=[logging.FileHandler(".log"),logging.StreamHandler()],
+    level=logging.INFO,
+)
+
+logger = logging.getLogger(__name__)
diff --git a/src/synnet/config.py b/src/synnet/config.py
new file mode 100644
index 00000000..977f72b0
--- /dev/null
+++ b/src/synnet/config.py
@@ -0,0 +1,19 @@
+"""Central place for all configuration, paths, and parameters."""
+import multiprocessing
+
+# Multiprocessing
+MAX_PROCESSES = min(32, multiprocessing.cpu_count()) - 1
+
+# TODO: Remove these paths bit by bit
+
+# Pre-processed data
+DATA_PREPROCESS_DIR = "data/pre-process"
+
+# Prepared data
+DATA_FEATURIZED_DIR = "data/featurized"
+
+# Results
+DATA_RESULT_DIR = "results"
+
+# Checkpoints (& pre-trained weights)
+CHECKPOINTS_DIR = "checkpoints"
diff --git a/syn_net/__init__.py b/src/synnet/data_generation/__init__.py
similarity index 100%
rename from syn_net/__init__.py
rename to src/synnet/data_generation/__init__.py
diff --git a/syn_net/data_generation/check_all_template.py b/src/synnet/data_generation/check_all_template.py
similarity index 77%
rename from syn_net/data_generation/check_all_template.py
rename to src/synnet/data_generation/check_all_template.py
index 03d21e37..5542701a 100644
--- a/syn_net/data_generation/check_all_template.py
+++ b/src/synnet/data_generation/check_all_template.py
@@ -3,13 +3,12 @@
 templates. Originally written by Jake. Wenhao edited.
 """
 import rdkit.Chem as Chem
-from rdkit.Chem import AllChem
-from rdkit.Chem import rdChemReactions
 from rdkit import RDLogger
+from rdkit.Chem import AllChem, rdChemReactions
 
 
 def split_rxn_parts(rxn):
-    '''
+    """
     Given SMILES reaction, splits into reactants, agents, and products
 
     Args:
@@ -17,11 +16,11 @@ def split_rxn_parts(rxn):
 
     Returns:
         list: Contains sets of reactants, agents, and products as RDKit molecules.
- ''' - rxn_parts = rxn.strip().split('>') - rxn_reactants = set(rxn_parts[0].split('.')) - rxn_agents = None if not rxn_parts[1] else set(rxn_parts[1].split('.')) - rxn_products = set(rxn_parts[2].split('.')) + """ + rxn_parts = rxn.strip().split(">") + rxn_reactants = set(rxn_parts[0].split(".")) + rxn_agents = None if not rxn_parts[1] else set(rxn_parts[1].split(".")) + rxn_products = set(rxn_parts[2].split(".")) reactants, agents, products = set(), set(), set() @@ -42,7 +41,7 @@ def split_rxn_parts(rxn): def rxn_template(rxn_smiles, templates): - ''' + """ Given a reaction, checks whether it matches any templates. Args: @@ -51,7 +50,7 @@ def rxn_template(rxn_smiles, templates): Returns: str: Matching template name. If no templates matched, returns None. - ''' + """ rxn_parts = split_rxn_parts(rxn_smiles) reactants, agents, products = rxn_parts[0], rxn_parts[1], rxn_parts[2] temp_match = None @@ -92,7 +91,7 @@ def rxn_template(rxn_smiles, templates): def route_templates(route, templates): - ''' + """ Given synthesis route, checks whether all reaction steps are in template list Args: @@ -102,7 +101,7 @@ def route_templates(route, templates): Returns: List of matching template names (as strings). If no templates matched, returns empty list. - ''' + """ synth_route = [] tree_match = True for rxn_step in route: @@ -116,32 +115,33 @@ def route_templates(route, templates): return synth_route -if __name__ == '__main__': + +if __name__ == "__main__": disable_RDLogger = True # disables RDKit warnings if disable_RDLogger: - RDLogger.DisableLog('rdApp.*') + RDLogger.DisableLog("rdApp.*") - rxn_set_path = '/path/to/rxn_set.txt' + rxn_set_path = "/path/to/rxn_set.txt" - rxn_set = open(rxn_set_path, 'r') + rxn_set = open(rxn_set_path, "r") templates = {} for rxn in rxn_set: - rxn_name = rxn.split('|')[0] - template = rxn.split('|')[1].strip() + rxn_name = rxn.split("|")[0] + template = rxn.split("|")[1].strip() rdkit_rxn = AllChem.ReactionFromSmarts(template) rdChemReactions.ChemicalReaction.Initialize(rdkit_rxn) templates[rdkit_rxn] = rxn_name - rxn_smiles = 'ClCC1CO1.NC(=O)Cc1ccc(O)cc1>>NC(=O)Cc1ccc(OCC2CO2)cc1' + rxn_smiles = "ClCC1CO1.NC(=O)Cc1ccc(O)cc1>>NC(=O)Cc1ccc(OCC2CO2)cc1" print(rxn_smiles) print(rxn_template(rxn_smiles, templates)) - print('------------------------------------------------------') + print("------------------------------------------------------") synthesis_route = [ - 'C(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1>>CC(CCc1ccccc1)NCC(O)c1ccc(O)c(C(N)=O)c1', - 'CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1>>CC(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1', - 'CC(CCc1ccccc1)NCc1ccccc1.NC(=O)c1cc(C(=O)CBr)ccc1O>>CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1' + "C(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1>>CC(CCc1ccccc1)NCC(O)c1ccc(O)c(C(N)=O)c1", + "CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1>>CC(CCc1ccccc1)N(Cc1ccccc1)CC(O)c1ccc(O)c(C(N)=O)c1", + "CC(CCc1ccccc1)NCc1ccccc1.NC(=O)c1cc(C(=O)CBr)ccc1O>>CC(CCc1ccccc1)N(CC(=O)c1ccc(O)c(C(N)=O)c1)Cc1ccccc1", ] print(synthesis_route) print(route_templates(synthesis_route, templates)) diff --git a/src/synnet/data_generation/preprocessing.py b/src/synnet/data_generation/preprocessing.py new file mode 100644 index 00000000..e800a749 --- /dev/null +++ b/src/synnet/data_generation/preprocessing.py @@ -0,0 +1,136 @@ +from pathlib import Path + +from tqdm import tqdm + +from synnet.config import MAX_PROCESSES +from synnet.utils.data_utils import Reaction + + +class BuildingBlockFilter: + """Filter building 
blocks.""" + + building_blocks: list[str] + building_blocks_filtered: list[str] + rxn_templates: list[str] + rxns: list[Reaction] + rxns_initialised: bool + + def __init__( + self, + *, + building_blocks: list[str], + rxn_templates: list[str], + processes: int = MAX_PROCESSES, + verbose: bool = False + ) -> None: + self.building_blocks = building_blocks + self.rxn_templates = rxn_templates + + # Init reactions + self.rxns = [Reaction(template=template) for template in self.rxn_templates] + # Init other stuff + self.processes = processes + self.verbose = verbose + self.rxns_initialised = False + + def _match_mp(self): + from functools import partial + + from pathos import multiprocessing as mp + + def __match(bblocks: list[str], _rxn: Reaction): + return _rxn.set_available_reactants(bblocks) + + func = partial(__match, self.building_blocks) + with mp.Pool(processes=self.processes) as pool: + self.rxns = pool.map(func, self.rxns) + return self + + def _init_rxns_with_reactants(self): + """Initializes a `Reaction` with a list of possible reactants. + + Info: This can take a while for lots of possible reactants.""" + self.rxns = tqdm(self.rxns) if self.verbose else self.rxns + if self.processes == 1: + self.rxns = [rxn.set_available_reactants(self.building_blocks) for rxn in self.rxns] + else: + self._match_mp() + + self.rxns_initialised = True + return self + + def filter(self): + """Filters out building blocks which do not match a reaction template.""" + if not self.rxns_initialised: + self = self._init_rxns_with_reactants() + matched_bblocks = {x for rxn in self.rxns for x in rxn.get_available_reactants} + self.building_blocks_filtered = list(matched_bblocks) + return self + + +class BuildingBlockFileHandler: + def _load_csv(self, file: str) -> list[str]: + """Load building blocks as smiles from `*.csv` or `*.csv.gz`.""" + import pandas as pd + + return pd.read_csv(file)["SMILES"].to_list() + + def load(self, file: str) -> list[str]: + """Load building blocks from file.""" + file = Path(file) + if ".csv" in file.suffixes: + return self._load_csv(file) + else: + raise NotImplementedError + + def _save_csv(self, file: Path, building_blocks: list[str]): + """Save building blocks to `*.csv.gz`""" + import pandas as pd + + # remove possible 1 or more extensions, i.e. + # .csv OR .csv.gz --> + file_no_ext = file.parent / file.stem.split(".")[0] + file = (file_no_ext).with_suffix(".csv.gz") + # Save + df = pd.DataFrame({"SMILES": building_blocks}) + df.to_csv(file, compression="gzip") + return None + + def save(self, file: str, building_blocks: list[str]): + """Save building blocks to file.""" + file = Path(file) + file.parent.mkdir(parents=True, exist_ok=True) + if ".csv" in file.suffixes: + self._save_csv(file, building_blocks) + else: + raise NotImplementedError + + +class ReactionTemplateFileHandler: + def load(self, file: str) -> list[str]: + """Load reaction templates from file.""" + with open(file, "rt") as f: + rxn_templates = f.readlines() + + rxn_templates = [tmplt.strip() for tmplt in rxn_templates] + + if not all([self._validate(t)] for t in rxn_templates): + raise ValueError("Not all reaction templates are valid.") + + return rxn_templates + + def _validate(self, rxn_template: str) -> bool: + """Validate reaction templates. 
diff --git a/src/synnet/data_generation/syntrees.py b/src/synnet/data_generation/syntrees.py
new file mode 100644
index 00000000..5e071aea
--- /dev/null
+++ b/src/synnet/data_generation/syntrees.py
@@ -0,0 +1,487 @@
+"""Generators for synthetic trees."""
+import logging
+from typing import Tuple, Union
+
+import numpy as np
+from rdkit import Chem
+from scipy import sparse
+from tqdm import tqdm
+
+from synnet.config import MAX_PROCESSES
+
+logger = logging.getLogger(__name__)
+
+from synnet.utils.data_utils import Reaction, SyntheticTree
+
+
+class NoReactantAvailableError(Exception):
+    """No second reactant available for the bimolecular reaction."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class NoReactionAvailableError(Exception):
+    """Reactant does not match any reaction template."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class NoBiReactionAvailableError(Exception):
+    """Reactants do not match any reaction template."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class NoReactionPossibleError(Exception):
+    """`rdkit` can not yield a valid reaction product."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class MaxDepthError(Exception):
+    """Synthetic Tree has exceeded its maximum depth."""
+
+    def __init__(self, message):
+        super().__init__(message)
+
+
+class SynTreeGenerator:
+
+    building_blocks: list[str]
+    rxn_templates: list[Reaction]
+    rxns: list[Reaction]
+    IDX_RXNS: np.ndarray  # (nReactions,)
+    ACTIONS: dict[int, str] = {i: action for i, action in enumerate("add expand merge end".split())}
+    verbose: bool
+
+    def __init__(
+        self,
+        *,
+        building_blocks: list[str],
+        rxn_templates: list[str],
+        rng=np.random.default_rng(),  # TODO: Think about this...
+        processes: int = MAX_PROCESSES,
+        verbose: bool = False,
+    ) -> None:
+        self.building_blocks = building_blocks
+        self.rxn_templates = rxn_templates
+        self.rxns = [Reaction(template=tmplt) for tmplt in rxn_templates]
+        self.rng = rng
+        self.IDX_RXNS = np.arange(len(self.rxns))
+        self.processes = processes
+        self.verbose = verbose
+        if not verbose:
+            logger.setLevel("CRITICAL")  # don't show error msgs
+
+        # Time intensive tasks
+        self._init_rxns_with_reactants()
+
+    def __match_mp(self):
+        # TODO: refactor / merge with `BuildingBlockFilter`
+        # TODO: Rename `ReactionSet` -> `ReactionCollection` (same for `SyntheticTreeSet`)
+        #       `Reaction` as "datacls", `*Collection` as cls that encompasses operations on "data"?
+        #       Third class simply for file I/O or include somewhere?
+        from functools import partial
+
+        from pathos import multiprocessing as mp
+
+        def __match(bblocks: list[str], _rxn: Reaction):
+            return _rxn.set_available_reactants(bblocks)
+
+        func = partial(__match, self.building_blocks)
+        with mp.Pool(processes=self.processes) as pool:
+            rxns = pool.map(func, self.rxns)
+
+        self.rxns = rxns
+        return self
+
+    def _init_rxns_with_reactants(self):
+        """Initializes a `Reaction` with a list of possible reactants.
+ + Info: This can take a while for lots of possible reactants.""" + self.rxns = tqdm(self.rxns) if self.verbose else self.rxns + if self.processes == 1: + self.rxns = [rxn.set_available_reactants(self.building_blocks) for rxn in self.rxns] + else: + self.__match_mp() + + self.rxns_initialised = True + return self + + def _sample_molecule(self) -> str: + """Sample a molecule.""" + idx = self.rng.choice(len(self.building_blocks)) + smiles = self.building_blocks[idx] + logger.debug(f" Sampled molecule: {smiles}") + return smiles + + def _base_case(self) -> str: + return self._sample_molecule() + + def _find_rxn_candidates(self, smiles: str, raise_exc: bool = True) -> list[bool]: + """Tests which reactions have `mol` as reactant.""" + mol = Chem.MolFromSmiles(smiles) + rxn_mask = [rxn.is_reactant(mol) for rxn in self.rxns] + if raise_exc and not any(rxn_mask): # Do not raise exc when checking if two mols can react + raise NoReactionAvailableError(f"No reaction available for: {smiles}.") + return rxn_mask + + def _sample_rxn(self, mask: np.ndarray = None) -> Tuple[Reaction, int]: + """Sample a reaction by index.""" + if mask is None: + irxn_mask = self.IDX_RXNS # All reactions are possible + else: + mask = np.asarray(mask) + irxn_mask = self.IDX_RXNS[mask] + idx = self.rng.choice(irxn_mask) + logger.debug( + f"Sampled reaction with index: {idx} (nreactants: {self.rxns[idx].num_reactant})" + ) + return self.rxns[idx], idx + + def _expand(self, reactant_1: str) -> Tuple[str, str, str, np.int64]: + """Expand a sub-tree from one molecule. + This can result in uni- or bimolecular reaction.""" + + # Identify applicable reactions + rxn_mask = self._find_rxn_candidates(reactant_1) + + # Sample reaction (by index) + rxn, idx_rxn = self._sample_rxn(mask=rxn_mask) + + # Sample 2nd reactant + if rxn.num_reactant == 1: + reactant_2 = None + else: + # Sample a molecule from the available reactants of this reaction + # That is, for a reaction A + B -> C, + # - determine if we have "A" or "B" + # - then sample "B" (or "A") + idx = 1 if rxn.is_reactant_first(reactant_1) else 0 + available_reactants = rxn.available_reactants[idx] + nPossibleReactants = len(available_reactants) + if nPossibleReactants == 0: + raise NoReactantAvailableError( + f"Unable to find reactant {idx+1} for bimolecular reaction (ID: {idx_rxn}) and reactant {reactant_1}." + ) + # TODO: 2 bi-molecular rxn templates have no matching bblock + # TODO: use numpy array to avoid type conversion or stick to sampling idx? + idx = self.rng.choice(nPossibleReactants) + reactant_2 = available_reactants[idx] + + # Run reaction + reactants = (reactant_1, reactant_2) + product = rxn.run_reaction(reactants) + return *reactants, product, idx_rxn + + def _get_action_mask(self, syntree: SyntheticTree): + """Get a mask of possible action for a SyntheticTree""" + # Recall: (Add, Expand, Merge, and End) + canAdd = False + canMerge = False + canExpand = False + canEnd = False + + state = syntree.get_state() + nTrees = len(state) + if nTrees == 0: + canAdd = True + elif nTrees == 1: + canAdd = True + canExpand = True + canEnd = True # TODO: When syntree has reached max depth, only allow to end it. 
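+            # With exactly one sub-tree, every action except "merge" is possible;
+            # merging requires two sub-tree roots as reactants.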
+ elif nTrees == 2: + canExpand = True + canMerge = any(self._get_rxn_mask(tuple(state))) + else: + raise ValueError + + return np.array((canAdd, canExpand, canMerge, canEnd), dtype=bool) + + def _get_rxn_mask(self, reactants: tuple[str, str]) -> list[bool]: + """Get a mask of possible reactions for the two reactants.""" + masks = [self._find_rxn_candidates(r, raise_exc=False) for r in reactants] + # TODO: We do not check if the two reactants are 1st and 2nd reactants in a given reaction. + # It is possible that both are only applicable as 1st reactant, + # and then the reaction is not possible, although the mask returns true. + # Alternative: Run the reaction and check if the product is valid. + mask = [rxn1 and rxn2 for rxn1, rxn2 in zip(*masks)] + if not any(mask): + raise NoBiReactionAvailableError(f"No reaction available for {reactants}.") + return mask + + def generate(self, max_depth: int = 15, retries: int = 3): + """Generate a syntree by random sampling.""" + + # Init + logger.debug(f"Starting synthetic tree generation with {max_depth=} ") + syntree = SyntheticTree() + recent_mol = self._sample_molecule() # root of the current tree + + for i in range(max_depth): + logger.debug(f"Iteration {i}") + + # State of syntree + state = syntree.get_state() + + # Sample action + p_action = self.rng.random((1, 4)) # (1,4) + action_mask = self._get_action_mask(syntree) # (1,4) + act = np.argmax(p_action * action_mask) # (1,) + action = self.ACTIONS[act] + logger.debug(f" Sampled action: {action}") + + if action == "end": + r1, r2, p, idx_rxn = None, None, None, -1 + elif action == "expand": + for j in range(retries): + logger.debug(f" Try {j}") + r1, r2, p, idx_rxn = self._expand(recent_mol) + if p is not None: + break + if p is None: + # TODO: move to rxn.run_reaction? + raise NoReactionPossibleError( + f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." + ) + elif action == "add": + mol = self._sample_molecule() + r1, r2, p, idx_rxn = self._expand(mol) + # Expand this subtree: reactant, reaction, reactant2 + elif action == "merge": + # merge two subtrees: sample reaction, run it. + + # Identify suitable rxn + r1, r2 = syntree.get_state() + rxn_mask = self._get_rxn_mask(tuple((r1, r2))) + # Sample reaction + rxn, idx_rxn = self._sample_rxn(mask=rxn_mask) + # Run reaction + p = rxn.run_reaction((r1, r2)) + if p is None: + # TODO: move to rxn.run_reaction? + raise NoReactionPossibleError( + f"Reaction (ID: {idx_rxn}) not possible with: {r1} + {r2}." 
+ ) + else: + raise ValueError(f"Invalid action {action}") + + # Prepare next iteration + logger.debug(f" Ran reaction {r1} + {r2} -> {p}") + + recent_mol = p + + # Update tree + syntree.update(act, rxn_id=int(idx_rxn), mol1=r1, mol2=r2, mol_product=p) + logger.debug(f"SynTree updated.") + if action == "end": + break + + if i == max_depth - 1 and not action == "end": + raise MaxDepthError("Maximum depth {max_depth} exceeded.") + logger.debug(f"🙌 SynTree completed.") + return syntree + + +def wraps_syntreegenerator_generate( + stgen: SynTreeGenerator, +) -> Tuple[Union[SyntheticTree, None], Union[Exception, None]]: + """Wrapper for `SynTreeGenerator().generate` that catches all Exceptions.""" + try: + st = stgen.generate() + except NoReactantAvailableError as e: + logger.error(e) + return None, e + except NoReactionAvailableError as e: + logger.error(e) + return None, e + except NoBiReactionAvailableError as e: + logger.error(e) + return None, e + except NoReactionPossibleError as e: + logger.error(e) + return None, e + except TypeError as e: + # When converting an invalid molecule from SMILES to rdkit Molecule. + # This happens if the reaction template/rdkit produces an invalid product. + logger.error(e) + return None, e + except Exception as e: + logger.error(e, exc_info=e, stack_info=False) + return None, e + else: + return st, None + + +def load_syntreegenerator(file: str) -> SynTreeGenerator: + import pickle + + with open(file, "rb") as f: + syntreegenerator = pickle.load(f) + return syntreegenerator + + +def save_syntreegenerator(syntreegenerator: SynTreeGenerator, file: str) -> None: + import pickle + + with open(file, "wb") as f: + pickle.dump(syntreegenerator, f) + + +# TODO: Move all these encoders to "from syn_net.encoding/" +# TODO: Evaluate if One-Hot-Encoder can be replaced with encoder from sklearn + +from abc import ABC, abstractmethod + + +class Encoder(ABC): + @abstractmethod + def encode(self, *args, **kwargs): + ... + + def __repr__(self) -> str: + return f"'{self.__class__.__name__}': {self.__dict__}" + + +class OneHotEncoder(Encoder): + def __init__(self, d: int) -> None: + self.d = d + + def encode(self, ind: int, datatype: np.dtype = np.float64) -> np.ndarray: + """Returns a (1,d)-array with zeros and a 1 at index `ind`.""" + onehot = np.zeros((1, self.d), dtype=datatype) # (1,d) + onehot[0, ind] = 1.0 + return onehot # (1,d) + + +class MorganFingerprintEncoder(Encoder): + def __init__(self, radius: int, nbits: int) -> None: + self.radius = radius + self.nbits = nbits + + def encode(self, smi: str) -> np.ndarray: + if smi is None: + fp = np.zeros((1, self.nbits)) # (1,d) + else: + mol = Chem.MolFromSmiles(smi) # TODO: sanity check mol here or use datmol? 
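+            # NOTE: `MolFromSmiles` returns None for an unparseable SMILES, in which
+            # case the fingerprint call below raises; inputs are assumed to be valid.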
+ bv = Chem.AllChem.GetMorganFingerprintAsBitVect(mol, self.radius, self.nbits) + fp = np.empty(self.nbits) + Chem.DataStructs.ConvertToNumpyArray(bv, fp) + fp = fp[None, :] + return fp # (1,d) + + +class IdentityIntEncoder(Encoder): + def __init__(self) -> None: + pass + + def encode(self, number: int): + return np.atleast_2d(number) + + +class SynTreeFeaturizer: + def __init__( + self, + *, + reactant_embedder: Encoder, + mol_embedder: Encoder, + rxn_embedder: Encoder, + action_embedder: Encoder, + ) -> None: + # Embedders + self.reactant_embedder = reactant_embedder + self.mol_embedder = mol_embedder + self.rxn_embedder = rxn_embedder + self.action_embedder = action_embedder + + def __repr__(self) -> str: + return f"{self.__dict__}" + + def featurize(self, syntree: SyntheticTree): + """Featurize a synthetic tree at every state. + + Note: + - At each iteration of the syntree growth, an action is chosen + - Every action (except "end") comes with a reaction. + - For every action, we compute: + - a "state" + - a "step", a vector that encompasses all info we need for training the neural nets. + This step is: [action, z_rt1, reaction_id, z_rt2, z_root_mol_1] + """ + + states, steps = [], [] + + target_mol = syntree.root.smiles + z_target_mol = self.mol_embedder.encode(target_mol) + + # Recall: We can have at most 2 sub-trees, each with a root node. + root_mol_1 = None + root_mol_2 = None + for i, action in enumerate(syntree.actions): + + # 1. Encode "state" + z_root_mol_1 = self.mol_embedder.encode(root_mol_1) + z_root_mol_2 = self.mol_embedder.encode(root_mol_2) + state = np.concatenate((z_root_mol_1, z_root_mol_2, z_target_mol), axis=1) # (1,3d) + + # 2. Encode "super"-step + if action == 3: # end + step = np.concatenate( + ( + self.action_embedder.encode(action), + self.reactant_embedder.encode(mol1), + self.rxn_embedder.encode(rxn_node.rxn_id), + self.reactant_embedder.encode(mol2), + self.mol_embedder.encode(mol1), + ), + axis=1, + ) + else: + rxn_node = syntree.reactions[i] + + if len(rxn_node.child) == 1: + mol1 = rxn_node.child[0] + mol2 = None + elif len(rxn_node.child) == 2: + mol1 = rxn_node.child[0] + mol2 = rxn_node.child[1] + else: # TODO: Change `child` is stored in reaction node so we can just unpack via * + raise ValueError() + + step = np.concatenate( + ( + self.action_embedder.encode(action), + self.reactant_embedder.encode(mol1), + self.rxn_embedder.encode(rxn_node.rxn_id), + self.reactant_embedder.encode(mol2), + self.mol_embedder.encode(mol1), + ), + axis=1, + ) + + # 3. Prepare next iteration + if action == 2: # merge + root_mol_1 = rxn_node.parent + root_mol_2 = None + + elif action == 1: # expand + root_mol_1 = rxn_node.parent + + elif action == 0: # add + root_mol_2 = root_mol_1 + root_mol_1 = rxn_node.parent + + # 4. Keep track of data + states.append(state) + steps.append(step) + + # Some housekeeping on dimensions + states = np.atleast_2d(np.asarray(states).squeeze()) + steps = np.atleast_2d(np.asarray(steps).squeeze()) + + return sparse.csc_matrix(states), sparse.csc_matrix(steps) diff --git a/src/synnet/encoding/distances.py b/src/synnet/encoding/distances.py new file mode 100644 index 00000000..fd5b5e92 --- /dev/null +++ b/src/synnet/encoding/distances.py @@ -0,0 +1,59 @@ +import numpy as np + +from synnet.encoding.fingerprints import mol_fp + + +def cosine_distance(v1, v2): + """Compute the cosine distance between two 1d-vectors. 
+
+    Note:
+        cosine_similarity = x'y / (||x|| ||y||) in [-1,1]
+        cosine_distance = 1 - cosine_similarity in [0,2]
+    """
+    return max(0, min(1 - np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)), 2))
+
+
+def ce_distance(y, y_pred, eps=1e-15):
+    """Computes the cross-entropy between two vectors.
+
+    Args:
+        y (np.ndarray): First vector.
+        y_pred (np.ndarray): Second vector.
+        eps (float, optional): Small value, for numerical stability. Defaults
+            to 1e-15.
+
+    Returns:
+        float: The cross-entropy.
+    """
+    y_pred = np.clip(y_pred, eps, 1 - eps)
+    return -np.sum((y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred)))
+
+
+def _tanimoto_similarity(fp1: np.ndarray, fp2: np.ndarray):
+    """
+    Returns the Tanimoto similarity between two molecular fingerprints.
+
+    Args:
+        fp1 (np.ndarray): Molecular fingerprint 1.
+        fp2 (np.ndarray): Molecular fingerprint 2.
+
+    Returns:
+        float: Tanimoto similarity.
+    """
+    return np.sum(fp1 * fp2) / (np.sum(fp1) + np.sum(fp2) - np.sum(fp1 * fp2))
+
+
+def tanimoto_similarity(target_fp: np.ndarray, smis: list[str]):
+    """
+    Returns the Tanimoto similarities between a target fingerprint and molecules
+    in an input list of SMILES.
+
+    Args:
+        target_fp (np.ndarray): Contains the reference (target) fingerprint.
+        smis (list of str): Contains SMILES to compute similarity to.
+
+    Returns:
+        list of float: Contains Tanimoto similarities.
+    """
+    fps = [mol_fp(smi, 2, 4096) for smi in smis]
+    return [_tanimoto_similarity(target_fp, fp) for fp in fps]
diff --git a/src/synnet/encoding/fingerprints.py b/src/synnet/encoding/fingerprints.py
new file mode 100644
index 00000000..66501af4
--- /dev/null
+++ b/src/synnet/encoding/fingerprints.py
@@ -0,0 +1,69 @@
+import numpy as np
+from rdkit import Chem
+from rdkit.Chem import AllChem, DataStructs
+
+
+## Morgan fingerprints
+def mol_fp(smi, _radius=2, _nBits=4096) -> np.ndarray:  # dtype=int64
+    """
+    Computes the Morgan fingerprint for the input SMILES.
+
+    Args:
+        smi (str): SMILES for molecule to compute fingerprint for.
+        _radius (int, optional): Fingerprint radius to use. Defaults to 2.
+        _nBits (int, optional): Length of fingerprint. Defaults to 4096.
+
+    Returns:
+        features (np.ndarray): For valid SMILES, this is the fingerprint.
+            Otherwise, if the input SMILES is bad, this will be a zero vector.
+    """
+    if smi is None:
+        return np.zeros(_nBits)
+    else:
+        mol = Chem.MolFromSmiles(smi)
+        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits)
+        return np.array(
+            features_vec
+        )  # TODO: much slower (20x?) than `DataStructs.ConvertToNumpyArray`; consider deprecating
+
+
+def fp_embedding(smi, _radius=2, _nBits=4096) -> list[float]:
+    """
+    General function for building variable-size & -radius Morgan fingerprints.
+
+    Args:
+        smi (str): The SMILES to encode.
+        _radius (int, optional): Morgan fingerprint radius. Defaults to 2.
+        _nBits (int, optional): Morgan fingerprint length. Defaults to 4096.
+
+    Returns:
+        list of float: A Morgan fingerprint generated using the specified parameters.
+    """
+    if smi is None:
+        return np.zeros(_nBits).reshape((-1,)).tolist()
+    else:
+        mol = Chem.MolFromSmiles(smi)
+        features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits)
+        features = np.zeros((1,))
+        DataStructs.ConvertToNumpyArray(features_vec, features)
+        return features.reshape((-1,)).tolist()
+
+
+def fp_4096(smi):
+    return fp_embedding(smi, _radius=2, _nBits=4096)
+
+
+def fp_2048(smi):
+    return fp_embedding(smi, _radius=2, _nBits=2048)
+
+
+def fp_1024(smi):
+    return fp_embedding(smi, _radius=2, _nBits=1024)
+
+
+def fp_512(smi):
+    return fp_embedding(smi, _radius=2, _nBits=512)
+
+
+def fp_256(smi):
+    return fp_embedding(smi, _radius=2, _nBits=256)
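A brief usage sketch of the two encoding modules above; the SMILES strings are arbitrary examples:

```python
import numpy as np

from synnet.encoding.distances import cosine_distance, tanimoto_similarity
from synnet.encoding.fingerprints import fp_256, mol_fp

# Tanimoto similarity of two molecules to a 4096-bit Morgan fingerprint target
target_fp = mol_fp("CC(=O)Oc1ccccc1C(=O)O")  # aspirin; radius 2, 4096 bits
sims = tanimoto_similarity(target_fp, ["c1ccccc1O", "CCO"])  # floats in [0, 1]

# Cosine distance between two shorter fingerprints
v1 = np.asarray(fp_256("c1ccccc1O"))
v2 = np.asarray(fp_256("CCO"))
dist = cosine_distance(v1, v2)  # clipped to [0, 2]
print(sims, dist)
```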
+ """ + if smi is None: + return np.zeros(_nBits).reshape((-1,)).tolist() + else: + mol = Chem.MolFromSmiles(smi) + features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) + features = np.zeros((1,)) + DataStructs.ConvertToNumpyArray(features_vec, features) + return features.reshape((-1,)).tolist() + + +def fp_4096(smi): + return fp_embedding(smi, _radius=2, _nBits=4096) + + +def fp_2048(smi): + return fp_embedding(smi, _radius=2, _nBits=2048) + + +def fp_1024(smi): + return fp_embedding(smi, _radius=2, _nBits=1024) + + +def fp_512(smi): + return fp_embedding(smi, _radius=2, _nBits=512) + + +def fp_256(smi): + return fp_embedding(smi, _radius=2, _nBits=256) diff --git a/src/synnet/encoding/gins.py b/src/synnet/encoding/gins.py new file mode 100644 index 00000000..a95a1d1e --- /dev/null +++ b/src/synnet/encoding/gins.py @@ -0,0 +1,169 @@ +import functools + +import numpy as np +import torch +import tqdm +from dgl.nn.pytorch.glob import AvgPooling +from dgllife.model import load_pretrained +from dgllife.utils import PretrainAtomFeaturizer, PretrainBondFeaturizer, mol_to_bigraph +from rdkit import Chem + + +@functools.lru_cache(1) +def _fetch_gin_pretrained_model(model_name: str): + """Get a GIN pretrained model to use for creating molecular embeddings""" + device = "cpu" + model = load_pretrained(model_name).to(device) # used to learn embedding + model.eval() + return model + + +def graph_construction_and_featurization(smiles): + """ + Constructs graphs from SMILES and featurizes them. + + Args: + smiles (list of str): Contains SMILES of molecules to embed. + + Returns: + graphs (list of DGLGraph): List of graphs constructed and featurized. + success (list of bool): Indicators for whether the SMILES string can be + parsed by RDKit. + """ + graphs = [] + success = [] + for smi in tqdm(smiles): + try: + mol = Chem.MolFromSmiles(smi) + if mol is None: + success.append(False) + continue + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) + graphs.append(g) + success.append(True) + except: + success.append(False) + + return graphs, success + + +def mol_embedding(smi, device="cpu", readout=AvgPooling()): + """ + Constructs a graph embedding using the GIN network for an input SMILES. + + Args: + smi (str): A SMILES string. + device (str): Indicates the device to run on ('cpu' or 'cuda:0'). Default 'cpu'. + + Returns: + np.ndarray: Either a zeros array or the graph embedding. + """ + name = "gin_supervised_contextpred" + gin_pretrained_model = _fetch_gin_pretrained_model(name) + + # get the embedding + if smi is None: + return np.zeros(300) + else: + mol = Chem.MolFromSmiles(smi) + # convert RDKit.Mol into featurized bi-directed DGLGraph + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) + bg = g.to(device) + nfeats = [ + bg.ndata.pop("atomic_number").to(device), + bg.ndata.pop("chirality_type").to(device), + ] + efeats = [ + bg.edata.pop("bond_type").to(device), + bg.edata.pop("bond_direction_type").to(device), + ] + with torch.no_grad(): + node_repr = gin_pretrained_model(bg, nfeats, efeats) + return ( + readout(bg, node_repr) + .detach() + .cpu() + .numpy() + .reshape( + -1, + ) + .tolist() + ) + + +def get_mol_embedding(smi, model, device="cpu", readout=AvgPooling()): + """ + Computes the molecular graph embedding for the input SMILES. 
+ + Args: + smi (str): SMILES of molecule to embed. + model (dgllife.model, optional): Pre-trained NN model to use for + computing the embedding. + device (str, optional): Indicates the device to run on. Defaults to 'cpu'. + readout (dgl.nn.pytorch.glob, optional): Readout function to use for + computing the graph embedding. Defaults to readout. + + Returns: + torch.Tensor: Learned embedding for the input molecule. + """ + mol = Chem.MolFromSmiles(smi) + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) + bg = g.to(device) + nfeats = [bg.ndata.pop("atomic_number").to(device), bg.ndata.pop("chirality_type").to(device)] + efeats = [bg.edata.pop("bond_type").to(device), bg.edata.pop("bond_direction_type").to(device)] + with torch.no_grad(): + node_repr = model(bg, nfeats, efeats) + return readout(bg, node_repr).detach().cpu().numpy()[0] + + +def graph_construction_and_featurization(smiles): + """ + Constructs graphs from SMILES and featurizes them. + + Args: + smiles (list of str): SMILES of molecules, for embedding computation. + + Returns: + graphs (list of DGLGraph): List of graphs constructed and featurized. + success (list of bool): Indicators for whether the SMILES string can be + parsed by RDKit. + """ + graphs = [] + success = [] + for smi in tqdm(smiles): + try: + mol = Chem.MolFromSmiles(smi) + if mol is None: + success.append(False) + continue + g = mol_to_bigraph( + mol, + add_self_loop=True, + node_featurizer=PretrainAtomFeaturizer(), + edge_featurizer=PretrainBondFeaturizer(), + canonical_atom_order=False, + ) + graphs.append(g) + success.append(True) + except: + success.append(False) + + return graphs, success diff --git a/src/synnet/encoding/utils.py b/src/synnet/encoding/utils.py new file mode 100644 index 00000000..e5ffc995 --- /dev/null +++ b/src/synnet/encoding/utils.py @@ -0,0 +1,18 @@ +import numpy as np + + +def one_hot_encoder(dim, space): + """ + Create a one-hot encoded vector of length=`space`, with a non-zero element + at the index given by `dim`. + + Args: + dim (int): Non-zero bit in one-hot vector. + space (int): Length of one-hot encoded vector. + + Returns: + vec (np.ndarray): One-hot encoded vector. + """ + vec = np.zeros((1, space)) + vec[0, dim] = 1 + return vec diff --git a/src/synnet/models/act.py b/src/synnet/models/act.py new file mode 100644 index 00000000..fbf6da3d --- /dev/null +++ b/src/synnet/models/act.py @@ -0,0 +1,104 @@ +"""Action network. 
+""" +import json +import logging +from pathlib import Path + +import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar + +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP + +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + pl.seed_everything(0) + + # Set up dataloaders + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + task="classification", + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + task="classification", + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + logger.info(f"Set up dataloaders.") + + INPUT_DIMS = { + "fp": int(3 * args.nbits), + "gin": int(2 * args.nbits + args.out_dim), + } # somewhat constant... + + input_dims = INPUT_DIMS[args.featurize] + + mlp = MLP( + input_dim=input_dims, + output_dim=4, + hidden_dim=1000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=3e-4, + val_freq=10, + ncpu=args.ncpu, + ) + + # Set up Trainer + save_dir = Path("results/logs/") / MODEL_ID + save_dir.mkdir(exist_ok=True, parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath=tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, + ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) + + max_epochs = args.epoch if not args.debug else 100 + # Create trainer + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + callbacks=[checkpoint_callback, tqdm_callback], + logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, + ) + + logger.info(f"Start training") + trainer.fit(mlp, train_dataloader, valid_dataloader) + logger.info(f"Training completed.") diff --git a/src/synnet/models/common.py b/src/synnet/models/common.py new file mode 100644 index 00000000..301be528 --- /dev/null +++ b/src/synnet/models/common.py @@ -0,0 +1,161 @@ +"""Common methods and params shared by all models. +""" + +from pathlib import Path +from typing import Union + +import numpy as np +import torch +from scipy import sparse + +from synnet.models.mlp import MLP + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-dir", type=str, default="data/featurized/Xy", help="Directory with X,y data." 
+    )
+    parser.add_argument(
+        "-f", "--featurize", type=str, default="fp", help="Choose from ['fp', 'gin']"
+    )
+    parser.add_argument(
+        "-r", "--rxn_template", type=str, default="hb", help="Choose from ['hb', 'pis']"
+    )
+    parser.add_argument("--radius", type=int, default=2, help="Radius for Morgan fingerprint.")
+    parser.add_argument(
+        "--nbits", type=int, default=4096, help="Number of bits for Morgan fingerprint."
+    )
+    parser.add_argument("--out_dim", type=int, default=256, help="Output dimension.")
+    parser.add_argument("--ncpu", type=int, default=16, help="Number of cpus")
+    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
+    parser.add_argument("--epoch", type=int, default=2000, help="Maximum number of epochs.")
+    parser.add_argument(
+        "--ckpt-file",
+        type=str,
+        default=None,
+        help="Checkpoint file. If provided, load and resume training.",
+    )
+    parser.add_argument("-v", "--version", type=int, default=1, help="Version")
+    parser.add_argument("--debug", default=False, action="store_true")
+    parser.add_argument("--fast-dev-run", default=False, action="store_true")
+    return parser.parse_args()
+
+
+def xy_to_dataloader(
+    X_file: str, y_file: str, task: str = "regression", n: Union[int, float] = 1.0, **kwargs
+):
+    """Loads featurized X,y `*.npz`-data into a `DataLoader`"""
+    X = sparse.load_npz(X_file)
+    y = sparse.load_npz(y_file)
+    # Optionally subsample the dataset (int: absolute size, float: fraction).
+    if isinstance(n, int):
+        n = min(n, X.shape[0])  # ensure n does not exceed size of dataset
+        X = X[:n]
+        y = y[:n]
+    elif isinstance(n, float) and n < 1.0:
+        n = int(X.shape[0] * n)  # cast to int, as sparse matrices cannot be sliced by floats
+        X = X[:n]
+        y = y[:n]
+    else:
+        pass  # use the full dataset
+    X = np.atleast_2d(np.asarray(X.todense()))
+    y = (
+        np.atleast_2d(np.asarray(y.todense()))
+        if task == "regression"
+        else np.asarray(y.todense()).squeeze()
+    )
+    dataset = torch.utils.data.TensorDataset(
+        torch.Tensor(X),
+        torch.Tensor(y),
+    )
+    return torch.utils.data.DataLoader(dataset, **kwargs)
+
+
+def load_mlp_from_ckpt(ckpt_file: str):
+    """Load a model from a checkpoint for inference."""
+    try:
+        model = MLP.load_from_checkpoint(ckpt_file)
+    except TypeError:
+        model = _load_mlp_from_iclr_ckpt(ckpt_file)
+    return model.eval()
+
+
+def find_best_model_ckpt(path: str) -> Union[Path, None]:
+    """Find checkpoint with lowest val_loss.
+
+    Poor man's regex:
+    somepath/act/ckpts.epoch=70-val_loss=0.03.ckpt
+                                         ^^^^--extract this as float
+    """
+    ckpts = Path(path).rglob("*.ckpt")
+    best_model_ckpt = None
+    lowest_loss = float("inf")
+    for file in ckpts:
+        stem = file.stem
+        val_loss = float(stem.split("val_loss=")[-1])
+        if val_loss < lowest_loss:
+            best_model_ckpt = file
+            lowest_loss = val_loss
+    return best_model_ckpt
+
+
+def _load_mlp_from_iclr_ckpt(ckpt_file: str):
+    """Load a model from a checkpoint for inference.
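Together, `find_best_model_ckpt` and `load_mlp_from_ckpt` give a two-line path from a log directory to an inference-ready model; a usage sketch (the `results/logs/act` path mirrors the `save_dir` convention used by the training scripts and is an assumption here):

```python
from synnet.models.common import find_best_model_ckpt, load_mlp_from_ckpt

ckpt = find_best_model_ckpt("results/logs/act")  # lowest val_loss, parsed from filename
if ckpt is not None:
    model = load_mlp_from_ckpt(str(ckpt))  # returned in eval() mode
```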
+ Info: hparams were not saved, so we specify the ones needed for inference again.""" + model = Path(ckpt_file).parent.name # assume "//.ckpt" + if model == "act": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=3 * 4096, + output_dim=4, + hidden_dim=1000, + num_layers=5, + task="classification", + dropout=0.5, + ) + elif model == "rt1": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=3 * 4096, + output_dim=256, + hidden_dim=1200, + num_layers=5, + task="regression", + dropout=0.5, + ) + elif model == "rxn": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=4 * 4096, + output_dim=91, + hidden_dim=3000, + num_layers=5, + task="classification", + dropout=0.5, + ) + elif model == "rt2": + model = MLP.load_from_checkpoint( + ckpt_file, + input_dim=4 * 4096 + 91, + output_dim=256, + hidden_dim=3000, + num_layers=5, + task="regression", + dropout=0.5, + ) + + else: + raise ValueError + return model.eval() + + +if __name__ == "__main__": + import json + + args = get_args() + print("Default Arguments are:") + print(json.dumps(args.__dict__, indent=2)) diff --git a/src/synnet/models/mlp.py b/src/synnet/models/mlp.py new file mode 100644 index 00000000..a946e2a9 --- /dev/null +++ b/src/synnet/models/mlp.py @@ -0,0 +1,137 @@ +""" +Multi-layer perceptron (MLP) class. +""" +import logging +from pathlib import Path + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from torch import nn + +from synnet.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__name__) + + +class MLP(pl.LightningModule): + def __init__( + self, + input_dim: int, + output_dim: int, + hidden_dim: int, + num_layers: int, + dropout: float, + num_dropout_layers: int = 1, + task: str = "classification", + loss: str = "cross_entropy", + valid_loss: str = "accuracy", + optimizer: str = "adam", + learning_rate: float = 1e-4, + val_freq: int = 10, + ncpu: int = 16, + molembedder: MolEmbedder = None, + ): + super().__init__() + self.save_hyperparameters(ignore="molembedder") + self.loss = loss + self.valid_loss = valid_loss + self.optimizer = optimizer + self.learning_rate = learning_rate + self.ncpu = ncpu + self.val_freq = val_freq + self.molembedder = molembedder + + modules = [] + modules.append(nn.Linear(input_dim, hidden_dim)) + modules.append(nn.BatchNorm1d(hidden_dim)) + modules.append(nn.ReLU()) + + for i in range(num_layers - 2): + modules.append(nn.Linear(hidden_dim, hidden_dim)) + modules.append(nn.BatchNorm1d(hidden_dim)) + modules.append(nn.ReLU()) + if i > num_layers - 3 - num_dropout_layers: + modules.append(nn.Dropout(dropout)) + + modules.append(nn.Linear(hidden_dim, output_dim)) + + self.layers = nn.Sequential(*modules) + + def forward(self, x): + """Forward step for inference only.""" + y_hat = self.layers(x) + if ( + self.hparams.task == "classification" + ): # during training, `cross_entropy` loss expects raw logits + y_hat = F.softmax(y_hat, dim=-1) + return y_hat + + def training_step(self, batch, batch_idx): + """The complete training loop.""" + x, y = batch + y_hat = self.layers(x) + if self.loss == "cross_entropy": + loss = F.cross_entropy(y_hat, y.long()) + elif self.loss == "mse": + loss = F.mse_loss(y_hat, y) + elif self.loss == "l1": + loss = F.l1_loss(y_hat, y) + elif self.loss == "huber": + loss = F.huber_loss(y_hat, y) + else: + raise ValueError("Unsupported loss function '%s'" % self.loss) + self.log(f"train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) + return loss + + def 
validation_step(self, batch, batch_idx):
+        """The complete validation loop."""
+        if self.trainer.current_epoch % self.val_freq != 0:
+            return None
+
+        x, y = batch
+        y_hat = self.layers(x)
+        if self.valid_loss == "cross_entropy":
+            loss = F.cross_entropy(y_hat, y.long())
+        elif self.valid_loss == "accuracy":
+            y_hat = torch.argmax(y_hat, dim=1)
+            accuracy = (y_hat == y).sum() / len(y)
+            loss = 1 - accuracy
+        elif self.valid_loss.startswith("nn_accuracy"):
+            # NOTE: Very slow!
+            # Performing the knn-search can easily take a couple of minutes,
+            # even for small datasets.
+            kdtree = self.molembedder.kdtree
+            y = nn_search_list(y.detach().cpu().numpy(), kdtree)
+            y_hat = nn_search_list(y_hat.detach().cpu().numpy(), kdtree)
+
+            accuracy = (y_hat == y).sum() / len(y)
+            loss = 1 - accuracy
+        elif self.valid_loss == "mse":
+            loss = F.mse_loss(y_hat, y)
+        elif self.valid_loss == "l1":
+            loss = F.l1_loss(y_hat, y)
+        elif self.valid_loss == "huber":
+            loss = F.huber_loss(y_hat, y)
+        else:
+            raise ValueError("Unsupported loss function '%s'" % self.valid_loss)
+        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+
+    def configure_optimizers(self):
+        """Define optimizers and LR schedulers."""
+        if self.optimizer == "adam":
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+        elif self.optimizer == "sgd":
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+        else:
+            raise ValueError("Unsupported optimizer '%s'" % self.optimizer)
+        return optimizer
+
+
+def nn_search_list(y, kdtree):
+    y = np.atleast_2d(y)  # (n_samples, n_features)
+    ind = kdtree.query(y, k=1, return_distance=False)  # (n_samples, 1)
+    return ind
+
+
+if __name__ == "__main__":
+    pass
diff --git a/src/synnet/models/rt1.py b/src/synnet/models/rt1.py
new file mode 100644
index 00000000..8bb3e9b9
--- /dev/null
+++ b/src/synnet/models/rt1.py
@@ -0,0 +1,118 @@
+"""Reactant1 network (for predicting 1st reactant).
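A shape-level smoke test for the `MLP` above (the hyperparameters here are illustrative, not the trained settings):

```python
import torch
from synnet.models.mlp import MLP

net = MLP(
    input_dim=32, output_dim=4, hidden_dim=64, num_layers=3,
    dropout=0.5, task="classification",
).eval()

x = torch.randn(8, 32)  # batch of 8 feature vectors
with torch.no_grad():
    probs = net(x)  # forward() applies softmax in classification mode
assert probs.shape == (8, 4)
assert torch.allclose(probs.sum(dim=-1), torch.ones(8), atol=1e-5)
```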
+""" +import json +import logging +from pathlib import Path + +import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar + +from synnet.encoding.distances import cosine_distance +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP +from synnet.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + + +def _fetch_molembedder(): + file = args.mol_embedder_file + logger.info(f"Try to load precomputed MolEmbedder from {file}.") + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + logger.info(f"Loaded MolEmbedder from {file}.") + return molembedder + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + pl.seed_everything(0) + + # Set up dataloaders + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + + logger.info(f"Set up dataloaders.") + + # Fetch Molembedder and init BallTree + molembedder = None # _fetch_molembedder() + + INPUT_DIMS = { + "fp": int(3 * args.nbits), + "gin": int(2 * args.nbits + args.out_dim), + } # somewhat constant... 
+ + input_dims = INPUT_DIMS[args.featurize] + + mlp = MLP( + input_dim=input_dims, + output_dim=args.out_dim, + hidden_dim=1200, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=3e-4, + val_freq=10, + molembedder=molembedder, + ncpu=args.ncpu, + ) + + # Set up Trainer + save_dir = Path("results/logs/") / MODEL_ID + save_dir.mkdir(exist_ok=True, parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath=tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, + ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) + + max_epochs = args.epoch if not args.debug else 100 + # Create trainer + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + callbacks=[checkpoint_callback, tqdm_callback], + logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, + ) + + logger.info(f"Start training") + trainer.fit(mlp, train_dataloader, valid_dataloader) + logger.info(f"Training completed.") diff --git a/src/synnet/models/rt2.py b/src/synnet/models/rt2.py new file mode 100644 index 00000000..2ea69453 --- /dev/null +++ b/src/synnet/models/rt2.py @@ -0,0 +1,123 @@ +"""Reactant2 network (for predicting 2nd reactant). +""" +import json +import logging +from pathlib import Path + +import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar + +from synnet.encoding.distances import cosine_distance +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP +from synnet.MolEmbedder import MolEmbedder + +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + + +def _fetch_molembedder(): + file = args.mol_embedder_file + logger.info(f"Try to load precomputed MolEmbedder from {file}.") + molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric=cosine_distance) + logger.info(f"Loaded MolEmbedder from {file}.") + return molembedder + + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + pl.seed_everything(0) + + # Set up dataloaders + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + + logger.info(f"Set up dataloaders.") + + # Fetch Molembedder and init BallTree + molembedder = None # _fetch_molembedder() + + INPUT_DIMS = { + "fp": { + "hb": int(4 * args.nbits 
+ 91), + "gin": int(4 * args.nbits + 4700), + }, + "gin": { + "hb": int(3 * args.nbits + args.out_dim + 91), + "gin": int(3 * args.nbits + args.out_dim + 4700), + }, + } # somewhat constant... + input_dims = INPUT_DIMS[args.featurize][args.rxn_template] + + mlp = MLP( + input_dim=input_dims, + output_dim=args.out_dim, + hidden_dim=3000, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="regression", + loss="mse", + valid_loss="mse", + optimizer="adam", + learning_rate=3e-4, + val_freq=10, + molembedder=molembedder, + ncpu=args.ncpu, + ) + + # Set up Trainer + save_dir = Path("results/logs/") / MODEL_ID + save_dir.mkdir(exist_ok=True, parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath=tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, + ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) + + max_epochs = args.epoch if not args.debug else 100 + # Create trainer + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + callbacks=[checkpoint_callback, tqdm_callback], + logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, + ) + + logger.info(f"Start training") + trainer.fit(mlp, train_dataloader, valid_dataloader) + logger.info(f"Training completed.") diff --git a/src/synnet/models/rxn.py b/src/synnet/models/rxn.py new file mode 100644 index 00000000..d4ded03c --- /dev/null +++ b/src/synnet/models/rxn.py @@ -0,0 +1,132 @@ +""" +Reaction network. +""" +import json +import logging +from pathlib import Path + +import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.callbacks.progress import TQDMProgressBar + +from synnet.config import CHECKPOINTS_DIR +from synnet.models.common import get_args, xy_to_dataloader +from synnet.models.mlp import MLP + +logger = logging.getLogger(__name__) +MODEL_ID = Path(__file__).stem + +if __name__ == "__main__": + logger.info("Start.") + + # Parse input args + args = get_args() + logger.info(f"Arguments: {json.dumps(vars(args),indent=2)}") + + pl.seed_everything(0) + + # Set up dataloaders + dataset = "train" + train_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + task="classification", + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + + dataset = "valid" + valid_dataloader = xy_to_dataloader( + X_file=Path(args.data_dir) / f"X_{MODEL_ID}_{dataset}.npz", + y_file=Path(args.data_dir) / f"y_{MODEL_ID}_{dataset}.npz", + n=None if not args.debug else 128, + task="classification", + batch_size=args.batch_size, + num_workers=args.ncpu, + shuffle=True if dataset == "train" else False, + ) + logger.info(f"Set up dataloaders.") + + INPUT_DIMS = { + "fp": { + "hb": int(4 * args.nbits), + "gin": int(4 * args.nbits), + }, + "gin": { + "hb": int(3 * args.nbits + args.out_dim), + "gin": int(3 * args.nbits + args.out_dim), + }, + } # somewhat constant... 
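For the reactant-2 network, the `fp`/`hb` input concatenates the three state blocks, the first reactant's fingerprint, and a one-hot vector over the 91 Hartenfeller-Button templates; a worked check with the default `nbits=4096`:

```python
nbits, n_hb_templates = 4096, 91

rt2_input_dim = 4 * nbits + n_hb_templates  # 3 state blocks + 1 reactant fp + one-hot rxn id
assert rt2_input_dim == 16475
```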
+ input_dim = INPUT_DIMS[args.featurize][args.rxn_template] + + HIDDEN_DIMS = { + "fp": { + "hb": 3000, + "gin": 4500, + }, + "gin": { + "hb": 3000, + "gin": 3000, + }, + } + hidden_dim = HIDDEN_DIMS[args.featurize][args.rxn_template] + + OUTPUT_DIMS = { + "hb": 91, + "gin": 4700, + } + output_dim = OUTPUT_DIMS[args.rxn_template] + + ckpt_path = args.ckpt_file # TODO: Unify for all networks + mlp = MLP( + input_dim=input_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + num_layers=5, + dropout=0.5, + num_dropout_layers=1, + task="classification", + loss="cross_entropy", + valid_loss="accuracy", + optimizer="adam", + learning_rate=3e-4, + val_freq=10, + ncpu=args.ncpu, + ) + + # Set up Trainer + save_dir = Path("results/logs/") / MODEL_ID + save_dir.mkdir(exist_ok=True, parents=True) + + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + csv_logger = pl_loggers.CSVLogger(tb_logger.log_dir, name="", version="") + logger.info(f"Log dir set to: {tb_logger.log_dir}") + + tb_logger = pl_loggers.TensorBoardLogger(save_dir, name="") + + checkpoint_callback = ModelCheckpoint( + monitor="val_loss", + dirpath=tb_logger.log_dir, + filename="ckpts.{epoch}-{val_loss:.2f}", + save_weights_only=False, + ) + earlystop_callback = EarlyStopping(monitor="val_loss", patience=3) + tqdm_callback = TQDMProgressBar(refresh_rate=int(len(train_dataloader) * 0.05)) + + max_epochs = args.epoch if not args.debug else 100 + # Create trainer + trainer = pl.Trainer( + gpus=[0], + max_epochs=max_epochs, + callbacks=[checkpoint_callback, tqdm_callback], + logger=[tb_logger, csv_logger], + fast_dev_run=args.fast_dev_run, + ) + + logger.info(f"Start training") + trainer.fit(mlp, train_dataloader, valid_dataloader, ckpt_path=ckpt_path) + logger.info(f"Training completed.") diff --git a/syn_net/data_generation/__init__.py b/src/synnet/utils/__init__.py similarity index 100% rename from syn_net/data_generation/__init__.py rename to src/synnet/utils/__init__.py diff --git a/src/synnet/utils/data_utils.py b/src/synnet/utils/data_utils.py new file mode 100644 index 00000000..bc043653 --- /dev/null +++ b/src/synnet/utils/data_utils.py @@ -0,0 +1,743 @@ +""" +Here we define the following classes for working with synthetic tree data: +* `Reaction` +* `ReactionSet` +* `NodeChemical` +* `NodeRxn` +* `SyntheticTree` +* `SyntheticTreeSet` +""" +import functools +import gzip +import itertools +import json +from typing import Any, Optional, Set, Tuple, Union + +from rdkit import Chem +from rdkit.Chem import AllChem, Draw, rdChemReactions +from tqdm import tqdm + + +# the definition of reaction classes below +class Reaction: + """ + This class models a chemical reaction based on a SMARTS transformation. + + Args: + template (str): SMARTS string representing a chemical reaction. + rxnname (str): The name of the reaction for downstream analysis. + smiles: (str): A reaction SMILES string that macthes the SMARTS pattern. + reference (str): Reference information for the reaction. 
+ """ + + smirks: str # SMARTS pattern + rxn: Chem.rdChemReactions.ChemicalReaction + num_reactant: int + num_agent: int + num_product: int + reactant_template: Tuple[str, str] + product_template: str + agent_template: str + available_reactants: Tuple[list[str], Optional[list[str]]] + rxnname: str + smiles: Any + reference: Any + + def __init__(self, template=None, rxnname=None, smiles=None, reference=None): + + if template is not None: + # define a few attributes based on the input + self.smirks = template.strip() + self.rxnname = rxnname + self.smiles = smiles + self.reference = reference + + # compute a few additional attributes + self.rxn = self.__init_reaction(self.smirks) + + # Extract number of ... + self.num_reactant = self.rxn.GetNumReactantTemplates() + if self.num_reactant not in (1, 2): + raise ValueError("Reaction is neither uni- nor bi-molecular.") + self.num_agent = self.rxn.GetNumAgentTemplates() + self.num_product = self.rxn.GetNumProductTemplates() + + # Extract reactants, agents, products + reactants, agents, products = self.smirks.split(">") + + if self.num_reactant == 1: + self.reactant_template = list((reactants,)) + else: + self.reactant_template = list(reactants.split(".")) + self.product_template = products + self.agent_template = agents + else: + self.smirks = None + + def __init_reaction(self, smirks: str) -> Chem.rdChemReactions.ChemicalReaction: + """Initializes a reaction by converting the SMARTS-pattern to an `rdkit` object.""" + rxn = AllChem.ReactionFromSmarts(smirks) + rdChemReactions.ChemicalReaction.Initialize(rxn) + return rxn + + def load( + self, + smirks, + num_reactant, + num_agent, + num_product, + reactant_template, + product_template, + agent_template, + available_reactants, + rxnname, + smiles, + reference, + ): + """ + This function loads a set of elements and reconstructs a `Reaction` object. + """ + self.smirks = smirks + self.num_reactant = num_reactant + self.num_agent = num_agent + self.num_product = num_product + self.reactant_template = list(reactant_template) + self.product_template = product_template + self.agent_template = agent_template + self.available_reactants = list(available_reactants) # TODO: use Tuple[list,list] here + self.rxnname = rxnname + self.smiles = smiles + self.reference = reference + self.rxn = self.__init_reaction(self.smirks) + return self + + @functools.lru_cache(maxsize=20) + def get_mol(self, smi: Union[str, Chem.Mol]) -> Chem.Mol: + """ + A internal function that returns an `RDKit.Chem.Mol` object. + + Args: + smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES + string or an `RDKit.Chem.Mol` object. + + Returns: + RDKit.Chem.Mol + """ + if isinstance(smi, str): + return Chem.MolFromSmiles(smi) + elif isinstance(smi, Chem.Mol): + return smi + else: + raise TypeError(f"{type(smi)} not supported, only `str` or `rdkit.Chem.Mol`") + + def visualize(self, name="./reaction1_highlight.o.png"): + """ + A function that plots the chemical translation into a PNG figure. + One can use "from IPython.display import Image ; Image(name)" to see it + in a Python notebook. + + Args: + name (str): The path to the figure. + + Returns: + name (str): The path to the figure. 
+ """ + rxn = AllChem.ReactionFromSmarts(self.smirks) + d2d = Draw.MolDraw2DCairo(800, 300) + d2d.DrawReaction(rxn, highlightByReactant=True) + png = d2d.GetDrawingText() + open(name, "wb+").write(png) + del rxn + return name + + def is_reactant(self, smi: Union[str, Chem.Mol]) -> bool: + """Checks if `smi` is a reactant of this reaction.""" + smi = self.get_mol(smi) + return self.rxn.IsMoleculeReactant(smi) + + def is_agent(self, smi: Union[str, Chem.Mol]) -> bool: + """Checks if `smi` is an agent of this reaction.""" + smi = self.get_mol(smi) + return self.rxn.IsMoleculeAgent(smi) + + def is_product(self, smi): + """Checks if `smi` is a product of this reaction.""" + smi = self.get_mol(smi) + return self.rxn.IsMoleculeProduct(smi) + + def is_reactant_first(self, smi: Union[str, Chem.Mol]) -> bool: + """Check if `smi` is the first reactant in this reaction""" + mol = self.get_mol(smi) + pattern = Chem.MolFromSmarts(self.reactant_template[0]) + return mol.HasSubstructMatch(pattern) + + def is_reactant_second(self, smi: Union[str, Chem.Mol]) -> bool: + """Check if `smi` the second reactant in this reaction""" + mol = self.get_mol(smi) + pattern = Chem.MolFromSmarts(self.reactant_template[1]) + return mol.HasSubstructMatch(pattern) + + def run_reaction( + self, reactants: Tuple[Union[str, Chem.Mol, None]], keep_main: bool = True + ) -> Union[str, None]: + """Run this reactions with reactants and return corresponding product. + + Args: + reactants (tuple): Contains SMILES strings for the reactants. + keep_main (bool): Return main product only or all possibel products. Defaults to True. + + Returns: + uniqps: SMILES string representing the product or `None` if not reaction possible + """ + # Input validation. + if not isinstance(reactants, tuple): + raise TypeError(f"Unsupported type '{type(reactants)}' for `reactants`.") + if not len(reactants) in (1, 2): + raise ValueError(f"Can only run reactions with 1 or 2 reactants, not {len(reactants)}.") + + rxn = self.rxn # TODO: investigate if this is necessary (if not, delete "delete rxn below") + + # Convert all reactants to `Chem.Mol` + r: Tuple = tuple(self.get_mol(smiles) for smiles in reactants if smiles is not None) + + if self.num_reactant == 1: + if len(r) == 2: # Provided two reactants for unimolecular reaction -> no rxn possible + return None + if not self.is_reactant(r[0]): + return None + elif self.num_reactant == 2: + # Match reactant order with reaction template + if self.is_reactant_first(r[0]) and self.is_reactant_second(r[1]): + pass + elif self.is_reactant_first(r[1]) and self.is_reactant_second(r[0]): + r = tuple(reversed(r)) + else: # No reaction possible + return None + else: + raise ValueError("This reaction is neither uni- nor bi-molecular.") + + # Run reaction with rdkit magic + ps = rxn.RunReactants(r) + + # Filter for unique products (less magic) + # Note: Use chain() to flatten the tuple of tuples + uniqps = list({Chem.MolToSmiles(p) for p in itertools.chain(*ps)}) + + # Sanity check + if not len(uniqps) >= 1: + # TODO: Raise (custom) exception? + raise ValueError("Reaction did not yield any products.") + + del rxn + + if keep_main: + uniqps = uniqps[:1] + # >>> TODO: Always return list[str] (currently depends on "keep_main") + uniqps = uniqps[0] + # <<< ^ delete this line if resolved. + return uniqps + + def _filter_reactants( + self, smiles: list[str], verbose: bool = False + ) -> Tuple[list[str], list[str]]: + """ + Filters reactants which do not match the reaction. 
+ + Args: + smiles: Possible reactants for this reaction. + + Returns: + :lists of SMILES which match either the first + reactant, or, if applicable, the second reactant. + + Raises: + ValueError: If `self` is not a uni- or bi-molecular reaction. + """ + smiles = tqdm(smiles) if verbose else smiles + + if self.num_reactant == 1: # uni-molecular reaction + reactants_1 = [smi for smi in smiles if self.is_reactant_first(smi)] + return (reactants_1,) + + elif self.num_reactant == 2: # bi-molecular reaction + reactants_1 = [smi for smi in smiles if self.is_reactant_first(smi)] + reactants_2 = [smi for smi in smiles if self.is_reactant_second(smi)] + + return (reactants_1, reactants_2) + else: + raise ValueError("This reaction is neither uni- nor bi-molecular.") + + def set_available_reactants(self, building_blocks: list[str], verbose: bool = False): + """ + Finds applicable reactants from a list of building blocks. + Sets `self.available_reactants`. + + Args: + building_blocks: Building blocks as SMILES strings. + """ + self.available_reactants = self._filter_reactants(building_blocks, verbose=verbose) + return self + + @property + def get_available_reactants(self) -> Set[str]: + return {x for reactants in self.available_reactants for x in reactants} + + def asdict(self) -> dict(): + """Returns serializable fields as new dictionary mapping. + *Excludes* Not-easily-serializable `self.rxn: rdkit.Chem.ChemicalReaction`.""" + import copy + + out = copy.deepcopy(self.__dict__) # TODO: + _ = out.pop("rxn") + return out + + +class ReactionSet: + """Represents a collection of reactions, for saving and loading purposes.""" + + def __init__(self, rxns: Optional[list[Reaction]] = None): + self.rxns = rxns if rxns is not None else [] + + def load(self, file: str): + """Load a collection of reactions from a `*.json.gz` file.""" + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" + with gzip.open(file, "r") as f: + data = json.loads(f.read().decode("utf-8")) + + for r in data["reactions"]: + rxn = Reaction().load( + **r + ) # TODO: `load()` relies on postional args, hence we cannot load a reaction that has no `available_reactants` for extample (or no template) + self.rxns.append(rxn) + return self + + def save(self, file: str) -> None: + """Save a collection of reactions to a `*.json.gz` file.""" + + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" + + r_list = {"reactions": [r.asdict() for r in self.rxns]} + with gzip.open(file, "w") as f: + f.write(json.dumps(r_list).encode("utf-8")) + + def __len__(self): + return len(self.rxns) + + def _print(self, x=3): + # For debugging + for i, r in enumerate(self.rxns): + if i >= x: + break + print(json.dumps(r.asdict(), indent=2)) + + +# the definition of classes for defining synthetic trees below +class NodeChemical: + """Represents a chemical node in a synthetic tree. + + Args: + smiles: Molecule represented as SMILES string. + parent: Parent molecule represented as SMILES string (i.e. the result of a reaction) + child: Index of the reaction this object participates in. + is_leaf: Is this a leaf node in a synthetic tree? + is_root: Is this a root node in a synthetic tree? + depth: Depth this node is in tree (+1 for an action, +.5 for a reaction) + index: Incremental index for all chemical nodes in the tree. 
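`ReactionSet` exists purely for (de)serialization; a roundtrip sketch reusing the illustrative template from above (the file name is arbitrary):

```python
from synnet.utils.data_utils import Reaction, ReactionSet

rxn = Reaction("[C:1](=[O:2])[OH].[N;H2:3]>>[C:1](=[O:2])[N:3]")
rxn.set_available_reactants(["CC(=O)O", "NCc1ccccc1"])  # needed so load(**d) has all fields

rs = ReactionSet([rxn])
rs.save("rxns-demo.json.gz")  # gzipped JSON via asdict(); the rdkit rxn object is excluded
rs2 = ReactionSet().load("rxns-demo.json.gz")
assert len(rs2) == 1
```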
+ """ + + def __init__( + self, + smiles: Union[str, None] = None, + parent: Union[int, None] = None, + child: Union[int, None] = None, + is_leaf: bool = False, + is_root: bool = False, + depth: float = 0, + index: int = 0, + ): + self.smiles = smiles + self.parent = parent + self.child = child + self.is_leaf = is_leaf + self.is_root = is_root + self.depth = depth + self.index = index + + +class NodeRxn: + """Represents a chemical reaction in a synthetic tree. + + + Args: + rxn_id (None or int): Index corresponding to reaction in a one-hot vector + of reaction templates. + rtype (None or int): Indicates if uni- (1) or bi-molecular (2) reaction. + parent (None or list): + child (None or list): Contains SMILES strings of reactants which lead to + the specified reaction. + depth (float): + index (int): Indicates the order of this reaction node in the tree. + """ + + def __init__( + self, + rxn_id: Union[int, None] = None, + rtype: Union[int, None] = None, + parent: Union[list, None] = [], + child: Union[list, None] = None, + depth: float = 0, + index: int = 0, + ): + self.rxn_id = rxn_id + self.rtype = rtype + self.parent = parent + self.child = child + self.depth = depth + self.index = index + + +class SyntheticTree: + """ + A class representing a synthetic tree. + + Args: + chemicals (list): A list of chemical nodes, in order of addition. + reactions (list): A list of reaction nodes, in order of addition. + actions (list): A list of actions, in order of addition. + root (NodeChemical): The root node. + depth (int): The depth of the tree. + rxn_id2type (dict): A dictionary that maps reaction indices to reaction + type (uni- or bi-molecular). + """ + + def __init__(self, tree=None): + self.chemicals: list[NodeChemical] = [] + self.reactions: list[NodeRxn] = [] + self.root = None + self.depth: float = 0 + self.actions = [] + self.rxn_id2type = None + + if tree is not None: + self.read(tree) + + def read(self, data): + """ + A function that loads a dictionary from synthetic tree data. + + Args: + data (dict): A dictionary representing a synthetic tree. + """ + self.root = NodeChemical(**data["root"]) + self.depth = data["depth"] + self.actions = data["actions"] + self.rxn_id2type = data["rxn_id2type"] + + for r_dict in data["reactions"]: + r = NodeRxn(**r_dict) + self.reactions.append(r) + + for m_dict in data["chemicals"]: + r = NodeChemical(**m_dict) + self.chemicals.append(r) + + def output_dict(self): + """ + A function that exports dictionary-formatted synthetic tree data. + + Returns: + data (dict): A dictionary representing a synthetic tree. + """ + return { + "reactions": [r.__dict__ for r in self.reactions], + "chemicals": [m.__dict__ for m in self.chemicals], + "root": self.root.__dict__, + "depth": self.depth, + "actions": self.actions, + "rxn_id2type": self.rxn_id2type, + } + + def _print(self): + """ + A function that prints the contents of the synthetic tree. + """ + print("===============Stored Molecules===============") + for node in self.chemicals: + print(node.smiles, node.is_root) + print("===============Stored Reactions===============") + for node in self.reactions: + print(node.rxn_id, node.rtype) + print("===============Followed Actions===============") + print(self.actions) + + def get_node_index(self, smi): + """ + Returns the index of the node matching the input SMILES. + + Args: + smi (str): A SMILES string that represents the query molecule. + + Returns: + index (int): Index of chemical node corresponding to the query + molecule. 
If the query moleucle is not in the tree, return None. + """ + for node in self.chemicals: + if smi == node.smiles: + return node.index + return None + + def get_state(self) -> list[str]: + """Get the state of this synthetic tree. + The most recent root node has 0 as its index. + + Returns: + state (list): A list contains all root node molecules. + """ + state = [node.smiles for node in self.chemicals if node.is_root] + return state[::-1] + + def update(self, action: int, rxn_id: int, mol1: str, mol2: str, mol_product: str): + """Update this synthetic tree by adding a reaction step. + + Args: + action (int): Action index, where the indices (0, 1, 2, 3) represent + (Add, Expand, Merge, and End), respectively. + rxn_id (int): Index of the reaction occured, where the index can be + anything in the range [0, len(template_list)-1]. + mol1 (str): SMILES string representing the first reactant. + mol2 (str): SMILES string representing the second reactant. + mol_product (str): SMILES string representing the product. + """ + self.actions.append(int(action)) + + if action == 3: # End + self.root = self.chemicals[-1] + self.depth = self.root.depth + + elif action == 2: # Merge (with bi-mol rxn) + node_mol1 = self.chemicals[self.get_node_index(mol1)] + node_mol2 = self.chemicals[self.get_node_index(mol2)] + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=2, + parent=None, + child=[node_mol1.smiles, node_mol2.smiles], + depth=max(node_mol1.depth, node_mol2.depth) + 0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=node_rxn.depth + 0.5, + index=len(self.chemicals), + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + node_mol2.parent = node_rxn.rxn_id + node_mol1.is_root = False + node_mol2.is_root = False + + self.chemicals.append(node_product) + self.reactions.append(node_rxn) + + elif action == 1 and mol2 is None: # Expand with uni-mol rxn + node_mol1 = self.chemicals[self.get_node_index(mol1)] + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=1, + parent=None, + child=[node_mol1.smiles], + depth=node_mol1.depth + 0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=node_rxn.depth + 0.5, + index=len(self.chemicals), + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + node_mol1.is_root = False + + self.chemicals.append(node_product) + self.reactions.append(node_rxn) + + elif action == 1 and mol2 is not None: # Expand with bi-mol rxn + node_mol1 = self.chemicals[self.get_node_index(mol1)] + node_mol2 = NodeChemical( + smiles=mol2, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals), + ) + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=2, + parent=None, + child=[node_mol1.smiles, node_mol2.smiles], + depth=max(node_mol1.depth, node_mol2.depth) + 0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=node_rxn.depth + 0.5, + index=len(self.chemicals) + 1, + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + node_mol2.parent = node_rxn.rxn_id + node_mol1.is_root = False + + self.chemicals.append(node_mol2) + self.chemicals.append(node_product) + self.reactions.append(node_rxn) + + elif action == 0 and mol2 is None: # Add 
with uni-mol rxn + node_mol1 = NodeChemical( + smiles=mol1, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals), + ) + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=1, + parent=None, + child=[node_mol1.smiles], + depth=0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=1, + index=len(self.chemicals) + 1, + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + + self.chemicals.append(node_mol1) + self.chemicals.append(node_product) + self.reactions.append(node_rxn) + + elif action == 0 and mol2 is not None: # Add with bi-mol rxn + node_mol1 = NodeChemical( + smiles=mol1, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals), + ) + node_mol2 = NodeChemical( + smiles=mol2, + parent=None, + child=None, + is_leaf=True, + is_root=False, + depth=0, + index=len(self.chemicals) + 1, + ) + node_rxn = NodeRxn( + rxn_id=rxn_id, + rtype=2, + parent=None, + child=[node_mol1.smiles, node_mol2.smiles], + depth=0.5, + index=len(self.reactions), + ) + node_product = NodeChemical( + smiles=mol_product, + parent=None, + child=node_rxn.rxn_id, + is_leaf=False, + is_root=True, + depth=1, + index=len(self.chemicals) + 2, + ) + + node_rxn.parent = node_product.smiles + node_mol1.parent = node_rxn.rxn_id + node_mol2.parent = node_rxn.rxn_id + + self.chemicals.append(node_mol1) + self.chemicals.append(node_mol2) + self.chemicals.append(node_product) + self.reactions.append(node_rxn) + + else: + raise ValueError("Check input") + + return None + + +class SyntheticTreeSet: + """Represents a collection of synthetic trees, for saving and loading purposes.""" + + def __init__(self, sts: Optional[list[SyntheticTree]] = None): + self.sts = sts if sts is not None else [] + + def __len__(self): + return len(self.sts) + + def __getitem__(self, index): + if self.sts is None: + raise IndexError("No Synthetic Trees.") + return self.sts[index] + + def load(self, file: str): + """Load a collection of synthetic trees from a `*.json.gz` file.""" + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" + + with gzip.open(file, "rt") as f: + data = json.loads(f.read()) + + for st in data["trees"]: + st = SyntheticTree(st) if st is not None else None + self.sts.append(st) + + return self + + def save(self, file: str) -> None: + """Save a collection of synthetic trees to a `*.json.gz` file.""" + assert str(file).endswith(".json.gz"), f"Incompatible file extension for file {file}" + + st_list = {"trees": [st.output_dict() for st in self.sts if st is not None]} + with gzip.open(file, "wt") as f: + f.write(json.dumps(st_list)) + + def _print(self, x=3): + """Helper function for debugging.""" + for i, r in enumerate(self.sts): + if i >= x: + break + print(r.output_dict()) + + +if __name__ == "__main__": + pass diff --git a/syn_net/utils/ga_utils.py b/src/synnet/utils/ga_utils.py similarity index 53% rename from syn_net/utils/ga_utils.py rename to src/synnet/utils/ga_utils.py index 7e6cc11d..9ee2f4c4 100644 --- a/syn_net/utils/ga_utils.py +++ b/src/synnet/utils/ga_utils.py @@ -5,7 +5,7 @@ import scipy -def crossover(parents, offspring_size, distribution='even'): +def crossover(parents, offspring_size, distribution="even"): """ A function that samples an offspring set through a crossover from a mating pool. 
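Tying the tree bookkeeping together: one Add step followed by End yields a two-leaf, one-reaction tree. A minimal sketch (action codes per `update`: Add=0, Expand=1, Merge=2, End=3; the SMILES and `rxn_id` are illustrative):

```python
from synnet.utils.data_utils import SyntheticTree

st = SyntheticTree()
# Add (action 0) with a bi-molecular reaction: two leaf chemicals plus one product.
st.update(0, rxn_id=12, mol1="CC(=O)O", mol2="NCc1ccccc1", mol_product="CC(=O)NCc1ccccc1")
assert st.get_state() == ["CC(=O)NCc1ccccc1"]  # the product is the only root

st.update(3, None, None, None, None)  # End (action 3) fixes root and depth
assert st.root.smiles == "CC(=O)NCc1ccccc1" and st.depth == 1
```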
@@ -24,54 +24,57 @@ def crossover(parents, offspring_size, distribution='even'): Returns: offspring (numpy.ndarray): An array which represents the offspring pool. """ - fp_length = parents.shape[1] - offspring = np.zeros((offspring_size, fp_length)) + fp_length = parents.shape[1] + offspring = np.zeros((offspring_size, fp_length)) inherit_num = np.ceil( - np.random.normal(loc=fp_length/2, scale=fp_length/10, size=(offspring_size, )) + np.random.normal(loc=fp_length / 2, scale=fp_length / 10, size=(offspring_size,)) ) inherit_num = np.where( - inherit_num >= int(fp_length/5) * np.ones((offspring_size, )), - inherit_num, int(fp_length/5) * np.ones((offspring_size, )) + inherit_num >= int(fp_length / 5) * np.ones((offspring_size,)), + inherit_num, + int(fp_length / 5) * np.ones((offspring_size,)), ) inherit_num = np.where( - int(fp_length*4/5) * np.ones((offspring_size, )) <= inherit_num, - int(fp_length*4/5) * np.ones((offspring_size, )), - inherit_num + int(fp_length * 4 / 5) * np.ones((offspring_size,)) <= inherit_num, + int(fp_length * 4 / 5) * np.ones((offspring_size,)), + inherit_num, ) for k in range(offspring_size): - parent1_idx = list(set(np.random.choice(fp_length, size=int(inherit_num[k]), replace=False))) + parent1_idx = list( + set(np.random.choice(fp_length, size=int(inherit_num[k]), replace=False)) + ) parent2_idx = list(set(range(fp_length)).difference(set(parent1_idx))) - if distribution == 'even': - parent_set = parents[np.random.choice(parents.shape[0], - size=2, - replace=False)] - elif distribution == 'linear': - p_ = np.arange(parents.shape[0])[::-1] + 10 - parent_set = parents[np.random.choice(parents.shape[0], - size=2, - replace=False, - p=p_/np.sum(p_))] - elif distribution == 'softmax_linear': - p_ = np.arange(parents.shape[0])[::-1] + 10 - parent_set = parents[np.random.choice(parents.shape[0], - size=2, - replace=False, - p=scipy.special.softmax(p_))] + if distribution == "even": + parent_set = parents[np.random.choice(parents.shape[0], size=2, replace=False)] + elif distribution == "linear": + p_ = np.arange(parents.shape[0])[::-1] + 10 + parent_set = parents[ + np.random.choice(parents.shape[0], size=2, replace=False, p=p_ / np.sum(p_)) + ] + elif distribution == "softmax_linear": + p_ = np.arange(parents.shape[0])[::-1] + 10 + parent_set = parents[ + np.random.choice( + parents.shape[0], size=2, replace=False, p=scipy.special.softmax(p_) + ) + ] offspring[k, parent1_idx] = parent_set[0][parent1_idx] offspring[k, parent2_idx] = parent_set[1][parent2_idx] return offspring + def fitness_sum(element): """ Test fitness function. """ return np.sum(element) + def mutation(offspring_crossover, num_mut_per_ele=1, mut_probability=0.5): """ A function that samples an offspring set through a crossover from a mating @@ -87,44 +90,43 @@ def mutation(offspring_crossover, num_mut_per_ele=1, mut_probability=0.5): offspring_crossover (numpy.ndarray): An array represents the offspring pool after mutation. """ - b_dict = {1:0, 0:1} + b_dict = {1: 0, 0: 1} fp_length = offspring_crossover.shape[1] mut_proba = np.random.random(offspring_crossover.shape[0]) for idx in range(offspring_crossover.shape[0]): # The random value to be added to the gene. 
if mut_proba[idx] <= mut_probability: - position = np.random.choice(fp_length, - size=int(num_mut_per_ele), - replace=False) - tmp = np.array([b_dict[int(_)] for _ in offspring_crossover[idx, position]]) + position = np.random.choice(fp_length, size=int(num_mut_per_ele), replace=False) + tmp = np.array([b_dict[int(_)] for _ in offspring_crossover[idx, position]]) offspring_crossover[idx, position] = tmp else: pass return offspring_crossover -if __name__ == '__main__': - num_parents = 10 - fp_size = 128 +if __name__ == "__main__": + + num_parents = 10 + fp_size = 128 offspring_size = 30 - ngen = 100 - population = np.ceil(np.random.random(size=(num_parents, fp_size)) * 2 - 1) + ngen = 100 + population = np.ceil(np.random.random(size=(num_parents, fp_size)) * 2 - 1) - print(f'Starting with {num_parents} fps with {fp_size} bits') + print(f"Starting with {num_parents} fps with {fp_size} bits") scores = np.array([fitness_sum(_) for _ in population]) - print(f'Initial: {scores.mean():.3f} +/- {scores.std():.3f}') - print(f'Scores: {scores}') + print(f"Initial: {scores.mean():.3f} +/- {scores.std():.3f}") + print(f"Scores: {scores}") for n in range(ngen): - offspring = crossover(population, offspring_size) - offspring = mutation(offspring, num_mut_per_ele=4, mut_probability=0.5) + offspring = crossover(population, offspring_size) + offspring = mutation(offspring, num_mut_per_ele=4, mut_probability=0.5) new_population = np.concatenate([population, offspring], axis=0) - new_scores = np.array(scores.tolist() + [fitness_sum(_) for _ in offspring]) - scores = [] + new_scores = np.array(scores.tolist() + [fitness_sum(_) for _ in offspring]) + scores = [] for parent_idx in range(num_parents): max_score_idx = np.where(new_scores == np.max(new_scores))[0][0] @@ -133,5 +135,5 @@ def mutation(offspring_crossover, num_mut_per_ele=1, mut_probability=0.5): new_scores[max_score_idx] = -999999 scores = np.array(scores) - print(f'Generation {ngen}: {scores.mean()} +/- {scores.std()}') - print(f'Scores: {scores}') + print(f"Generation {ngen}: {scores.mean()} +/- {scores.std()}") + print(f"Scores: {scores}") diff --git a/src/synnet/utils/predict_utils.py b/src/synnet/utils/predict_utils.py new file mode 100644 index 00000000..147b10cb --- /dev/null +++ b/src/synnet/utils/predict_utils.py @@ -0,0 +1,389 @@ +""" +This file contains various utils for creating molecular embeddings and for +decoding synthetic trees. +""" +from typing import Callable, Tuple + +import numpy as np +import pytorch_lightning as pl +import rdkit +import torch +from rdkit import Chem +from sklearn.neighbors import BallTree + +from synnet.encoding.distances import cosine_distance, tanimoto_similarity +from synnet.encoding.fingerprints import mol_fp +from synnet.encoding.utils import one_hot_encoder +from synnet.utils.data_utils import Reaction, SyntheticTree + +# create a random seed for NumPy +np.random.seed(6) + +# general functions +def can_react(state, rxns: list[Reaction]) -> Tuple[int, list[bool]]: + """ + Determines if two molecules can react using any of the input reactions. + + Args: + state (np.ndarray): The current state in the synthetic tree. + rxns (list of Reaction objects): Contains available reaction templates. + + Returns: + np.ndarray: The sum of the reaction mask tells us how many reactions are + viable for the two molecules. + np.ndarray: The reaction mask, which masks out reactions which are not + viable for the two molecules. 
+ """ + mol1 = state.pop() + mol2 = state.pop() + reaction_mask = [int(rxn.run_reaction((mol1, mol2)) is not None) for rxn in rxns] + return sum(reaction_mask), reaction_mask + + +def get_action_mask(state: list, rxns: list[Reaction]) -> np.ndarray: + """ + Determines which actions can apply to a given state in the synthetic tree + and returns a mask for which actions can apply. + + Args: + state (np.ndarray): The current state in the synthetic tree. + rxns (list of Reaction objects): Contains available reaction templates. + + Raises: + ValueError: There is an issue with the input state. + + Returns: + np.ndarray: The action mask. Masks out unviable actions from the current + state using 0s, with 1s at the positions corresponding to viable + actions. + """ + # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) + if len(state) == 0: + mask = [1, 0, 0, 0] + elif len(state) == 1: + mask = [1, 1, 0, 1] + elif len(state) == 2: + can_react_, _ = can_react(state, rxns) + if can_react_: + mask = [0, 1, 1, 0] + else: + mask = [0, 1, 0, 0] + else: + raise ValueError("Problem with state.") + return np.asarray(mask, dtype=bool) + + +def get_reaction_mask(smi: str, rxns: list[Reaction]): + """ + Determines which reaction templates can apply to the input molecule. + + Args: + smi (str): The SMILES string corresponding to the molecule in question. + rxns (list of Reaction objects): Contains available reaction templates. + + Raises: + ValueError: There is an issue with the reactants in the reaction. + + Returns: + reaction_mask (list of ints, or None): The reaction template mask. Masks + out reaction templates which are not viable for the input molecule. + If there are no viable reaction templates identified, is simply None. + available_list (list of lists, or None): Contains available reactants if + at least one viable reaction template is identified. Else is simply + None. + """ + # Return all available reaction templates + # List of available building blocks if 2 + # Exclude the case of len(available_list) == 0 + reaction_mask = [int(rxn.is_reactant(smi)) for rxn in rxns] + + if sum(reaction_mask) == 0: + return None, None + + available_list = [] + mol = rdkit.Chem.MolFromSmiles(smi) + for i, rxn in enumerate(rxns): + if reaction_mask[i] and rxn.num_reactant == 2: + + if rxn.is_reactant_first(mol): + available_list.append(rxn.available_reactants[1]) + elif rxn.is_reactant_second(mol): + available_list.append(rxn.available_reactants[0]) + else: + raise ValueError("Check the reactants") + + if len(available_list[-1]) == 0: + reaction_mask[i] = 0 + + else: + available_list.append([]) + + return reaction_mask, available_list + + +def nn_search( + _e: np.ndarray, _tree: BallTree, _k: int = 1 +) -> Tuple[float, float]: # TODO: merge w `nn_search_rt1` + """ + Conducts a nearest neighbor search to find the molecule from the tree most + simimilar to the input embedding. + + Args: + _e (np.ndarray): A specific point in the dataset. + _tree (sklearn.neighbors._kd_tree.KDTree, optional): A k-d tree. + _k (int, optional): Indicates how many nearest neighbors to get. + Defaults to 1. + + Returns: + float: The distance to the nearest neighbor. + int: The indices of the nearest neighbor. 
+ """ + dist, ind = _tree.query(_e, k=_k) + return dist[0][0], ind[0][0] + + +def nn_search_rt1(_e: np.ndarray, _tree: BallTree, _k: int = 1) -> Tuple[np.ndarray, np.ndarray]: + dist, ind = _tree.query(_e, k=_k) + return dist[0], ind[0] + + +def set_embedding( + z_target: np.ndarray, state: list[str], nbits: int, _mol_embedding: Callable +) -> np.ndarray: + """ + Computes embeddings for all molecules in the input space. + Embedding = [z_mol1, z_mol2, z_target] + + Args: + z_target (np.ndarray): Molecular embedding of the target molecule. + state (list): State of the synthetic tree, i.e. list of root molecules. + nbits (int): Length of fingerprint. + _mol_embedding (Callable): Computes the embeddings of molecules in the state. + + Returns: + embedding (np.ndarray): shape (1,d+2*nbits) + """ + z_target = np.atleast_2d(z_target) # (1,d) + if len(state) == 0: + z_mol1 = np.zeros((1, nbits)) + z_mol2 = np.zeros((1, nbits)) + elif len(state) == 1: + z_mol1 = np.atleast_2d(_mol_embedding(state[0])) + z_mol2 = np.zeros((1, nbits)) + elif len(state) == 2: + z_mol1 = np.atleast_2d(_mol_embedding(state[0])) + z_mol2 = np.atleast_2d(_mol_embedding(state[1])) + else: + raise ValueError + embedding = np.concatenate([z_mol1, z_mol2, z_target], axis=1) + return embedding # (1,d+2*nbits) + + +def synthetic_tree_decoder( + z_target: np.ndarray, + building_blocks: list[str], + bb_dict: dict[str, int], + reaction_templates: list[Reaction], + mol_embedder, + action_net: pl.LightningModule, + reactant1_net: pl.LightningModule, + rxn_net: pl.LightningModule, + reactant2_net: pl.LightningModule, + bb_emb: np.ndarray, + rxn_template: str, + n_bits: int, + max_step: int = 15, + k_reactant1: int = 1, +) -> Tuple[SyntheticTree, int]: + """ + Computes a synthetic tree given an input molecule embedding. + Uses the Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. + + Args: + z_target (np.ndarray): Embedding for the target molecule + building_blocks (list of str): Contains available building blocks + bb_dict (dict): Building block dictionary + reaction_templates (list of Reactions): Contains reaction templates + mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining + molecular embeddings + action_net (synth_net.models.mlp.MLP): The action network + reactant1_net (synth_net.models.mlp.MLP): The reactant1 network + rxn_net (synth_net.models.mlp.MLP): The reaction network + reactant2_net (synth_net.models.mlp.MLP): The reactant2 network + bb_emb (list): Contains purchasable building block embeddings. + rxn_template (str): Specifies the set of reaction templates to use. + n_bits (int): Length of fingerprint. + max_step (int, optional): Maximum number of steps to include in the + synthetic tree + + Returns: + tree (SyntheticTree): The final synthetic tree. + act (int): The final action (to know if the tree was "properly" + terminated). + """ + # Initialization + tree = SyntheticTree() + mol_recent = None + kdtree = mol_embedder # TODO: dont mis-use this arg + + # Start iteration + # TODO: tree decoder can exceed this an still return a tree, but action is not equal to 3 + # Raise error instead like in syntree generation? 
+    # Template-set sizes, used for one-hot encoding of the reaction id.
+    NUMBER_OF_REACTION_TEMPLATES = {
+        "hb": 91,
+        "pis": 4700,
+        "unittest": 3,
+    }  # TODO: Refactor / use class
+
+    for i in range(max_step):
+        # Encode current state
+        state = tree.get_state()  # a list
+        z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp)
+
+        # Predict action type, masked selection
+        # Action: (Add: 0, Expand: 1, Merge: 2, End: 3)
+        action_proba = action_net(torch.Tensor(z_state))  # (1,4)
+        action_proba = action_proba.squeeze().detach().numpy() + 1e-10
+        action_mask = get_action_mask(tree.get_state(), reaction_templates)
+        act = np.argmax(action_proba * action_mask)
+
+        # Continue growing tree?
+        if act == 3:  # End
+            break
+
+        z_mol1 = reactant1_net(torch.Tensor(z_state))
+        z_mol1 = z_mol1.detach().numpy()  # (1,dimension_output_embedding), default: (1,256)
+
+        # Select first molecule
+        if act == 0:  # Add
+            # Select `k` for the kNN search of the 1st reactant.
+            # Use k>1 for the first action, and k==1 for all others.
+            # Idea: Increase the chances of generating a better tree.
+            k = k_reactant1 if mol_recent is None else 1
+
+            _, idxs = kdtree.query(z_mol1, k=k)  # idxs.shape = (1,k)
+            mol1 = building_blocks[idxs[0][k - 1]]
+        elif act == 1 or act == 2:  # Expand or Merge
+            mol1 = mol_recent
+        else:
+            raise ValueError(f"Unexpected action {act}.")
+
+        z_mol1 = mol_fp(mol1)
+        z_mol1 = np.atleast_2d(z_mol1)  # (1,4096)
+
+        # Select reaction
+        z = np.concatenate([z_state, z_mol1], axis=1)
+        reaction_proba = rxn_net(torch.Tensor(z))
+        reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10  # (nReactionTemplate,)
+
+        if act == 0 or act == 1:  # Add or Expand
+            reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates)
+        else:  # Merge
+            _, reaction_mask = can_react(tree.get_state(), reaction_templates)
+            available_list = [
+                [] for rxn in reaction_templates
+            ]  # TODO: if act=merge, this is not used at all
+
+        # If we ended up in a state where no reaction is possible, stop decoding.
+        if reaction_mask is None:
+            if len(state) == 1:  # only a single root mol, so this syntree is valid
+                act = 3
+                break
+            else:
+                break  # action != 3, so in our analysis we will see this tree as "invalid"
+
+        # Select reaction template
+        rxn_id = np.argmax(reaction_proba * reaction_mask)
+        rxn = reaction_templates[rxn_id]
+
+        # Select 2nd reactant
+        if rxn.num_reactant == 2:
+            if act == 2:  # Merge
+                temp = set(state) - {mol1}
+                mol2 = temp.pop()
+            else:  # Add or Expand
+                x_rxn = one_hot_encoder(rxn_id, NUMBER_OF_REACTION_TEMPLATES[rxn_template])
+                x_rct2 = np.concatenate([z_state, z_mol1, x_rxn], axis=1)
+                z_mol2 = reactant2_net(torch.Tensor(x_rct2))
+                z_mol2 = z_mol2.detach().numpy()
+                available = available_list[rxn_id]  # list[str], reactants compatible with this rxn
+                available = [bb_dict[smi] for smi in available]  # list[int]
+                temp_emb = bb_emb[available]
+                available_tree = BallTree(
+                    temp_emb, metric=cosine_distance
+                )  # TODO: evaluate if a distance matrix is faster/feasible as this BallTree is discarded immediately.
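+                # Resolve the predicted embedding to a concrete building block:
+                # query the temporary BallTree restricted to compatible reactants,
+                # then map the local index back into the global building-block list.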
+                dist, ind = nn_search(z_mol2, _tree=available_tree)
+                mol2 = building_blocks[available[ind]]
+        else:
+            mol2 = None
+
+        # Run reaction
+        mol_product = rxn.run_reaction((mol1, mol2))
+        if mol_product is None or Chem.MolFromSmiles(mol_product) is None:
+            if len(state) == 1:  # only a single root mol, so this syntree is valid
+                act = 3
+                break
+            else:
+                break  # action != 3, so in our analysis we will see this tree as "invalid"
+
+        # Update
+        tree.update(act, int(rxn_id), mol1, mol2, mol_product)
+        mol_recent = mol_product
+
+    if act == 3:  # properly terminated: record the "End" action
+        tree.update(act, None, None, None, None)
+
+    return tree, act
+
+
+def synthetic_tree_decoder_greedy_search(
+    beam_width: int = 3, **kwargs
+) -> Tuple[str, float, SyntheticTree, int]:
+    """
+    Wrapper around `synthetic_tree_decoder` with variable `k` for the kNN search of the 1st reactant.
+    Keeps the syntree that contains the molecule most similar to the target molecule.
+
+    Args:
+        beam_width (int): The beam width to use for the Reactant 1 search. Defaults to 3.
+        kwargs: Identical to the wrapped function.
+
+    Returns:
+        smi (str): SMILES of the most target-similar molecule across all trees.
+        similarity (float): Tanimoto similarity of `smi` to the target.
+        tree (SyntheticTree): The final synthetic tree.
+        act (int): The final action (to know if the tree was "properly" terminated).
+    """
+    z_target = kwargs["z_target"]
+    trees: list[SyntheticTree] = []
+    smiles: list[str] = []
+    similarities: list[float] = []
+    acts: list[int] = []
+
+    for i in range(beam_width):
+        tree, act = synthetic_tree_decoder(k_reactant1=i + 1, **kwargs)
+
+        # Find the chemical in this tree that is most similar to the target.
+        # Note: This does not have to be the final root mol, but any, as we can truncate the tree to our liking.
+        similarities_in_tree = np.array(
+            tanimoto_similarity(z_target, [node.smiles for node in tree.chemicals])
+        )
+        max_similar_idx = np.argmax(similarities_in_tree)
+        max_similarity = similarities_in_tree[max_similar_idx]
+
+        # Keep track of max similarities (across syntrees)
+        similarities.append(max_similarity)
+
+        # Keep track of generated syntrees
+        smiles.append(tree.chemicals[max_similar_idx].smiles)
+        trees.append(tree)
+        acts.append(act)
+
+    # Identify the most similar among all trees
+    max_similar_idx = np.argmax(similarities)
+    similarity = similarities[max_similar_idx]
+    tree = trees[max_similar_idx]
+    smi = smiles[max_similar_idx]
+    act = acts[max_similar_idx]
+
+    return smi, similarity, tree, act
diff --git a/src/synnet/utils/prep_utils.py b/src/synnet/utils/prep_utils.py
new file mode 100644
index 00000000..d9298b87
--- /dev/null
+++ b/src/synnet/utils/prep_utils.py
@@ -0,0 +1,171 @@
+"""
+This file contains various utils for data preparation and preprocessing.
+"""
+import functools
+import logging
+from pathlib import Path
+from typing import Iterator, Union
+
+import numpy as np
+from rdkit import Chem
+from scipy import sparse
+from sklearn.preprocessing import OneHotEncoder
+
+logger = logging.getLogger(__name__)
+
+
+def rdkit2d_embedding(smi):
+    """
+    Computes an embedding using RDKit 2D descriptors.
+
+    Args:
+        smi (str): SMILES string.
+
+    Returns:
+        np.ndarray: A molecular embedding corresponding to the input molecule.
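+
+    Example:
+        A hedged sketch; requires the PyTDC package (`tdc.chem_utils.MolConvert`),
+        and the 200-dim zero vector mirrors the `None` branch below:
+
+            emb = rdkit2d_embedding("CCO")  # np.ndarray of shape (200,)
+            pad = rdkit2d_embedding(None)   # zero vector of the same shape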
+ """ + from tdc.chem_utils import MolConvert + + if smi is None: + return np.zeros(200).reshape((-1,)) + else: + # define the RDKit 2D descriptor + rdkit2d = MolConvert(src="SMILES", dst="RDKit2D") + return rdkit2d(smi).reshape( + -1, + ) + + +import functools + + +@functools.lru_cache(maxsize=1) +def _fetch_gin_pretrained_model(model_name: str): + from dgllife.model import load_pretrained + + """Get a GIN pretrained model to use for creating molecular embeddings""" + device = "cpu" + model = load_pretrained(model_name).to(device) + model.eval() + return model + + +def split_data_into_Xy( + dataset_type: str, + steps_file: str, + states_file: str, + output_dir: Path, + num_rxn: int, + out_dim: int, +) -> None: + """Split the featurized data into X,y-chunks for the {act,rt1,rxn,rt2}-networks. + + Args: + num_rxn (int): Number of reactions in the dataset. + out_dim (int): Size of the output feature vectors (used in kNN-search for rt1,rt2) + """ + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True, parents=True) + + # Load data # TODO: separate functionality? + states = sparse.load_npz(states_file) # (n,3*4096) + steps = sparse.load_npz(steps_file) # (n,1+256+91+256+4096) + + # Extract data for each network... + + # ... action data + # X: [z_state] + # y: [action id] (int) + X = states + y = steps[:, 0] + sparse.save_npz(output_dir / f"X_act_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_act_{dataset_type}.npz", y) + logger.info(f' saved data for "Action" to {output_dir}') + + # ... reaction data + # X: [state, z_reactant_1] + # y: [reaction_id] (int) + # but: delete all steps where we *end* syntrees, as that will not be followed by a reaction + actions = steps[:, 0].A # (n,1) as array to allow boolean + isActionEnd = (actions == 3).squeeze() # (n,) + states = states[~isActionEnd] + steps = steps[~isActionEnd] + X = sparse.hstack([states, steps[:, (2 * out_dim + 2) :]]) # (n,4*4096) + y = steps[:, out_dim + 1] # (n,1) + sparse.save_npz(output_dir / f"X_rxn_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_rxn_{dataset_type}.npz", y) + logger.info(f' saved data for "Reaction" to {output_dir}') + + # ... reactant 2 data + # X: [state,z_mol1,OneHotEnc(rxn_id)] + # y: [z_mol2] + # but: delete all steps where we *merge* syntrees, as in that case we already have reactant1+2 + actions = steps[:, 0].A # (n',1) as array to allow boolean + isActionMerge = (actions == 2).squeeze() # (n',) + steps = steps[~isActionMerge] + states = states[~isActionMerge] + z_mol1 = steps[:, (2 * out_dim + 2) :] + rxn_ids = steps[:, (1 + out_dim)] + z_rxn_id = OneHotEncoder().fit(np.arange(num_rxn)[:, None]).transform(rxn_ids.A) + X = sparse.hstack((states, z_mol1, z_rxn_id)) # (n,3*4096+4096+91) + y = steps[:, (2 + out_dim) : (2 * out_dim + 2)] + sparse.save_npz(output_dir / f"X_rt2_{dataset_type}.npz", X) + sparse.save_npz(output_dir / f"y_rt2_{dataset_type}.npz", y) + logger.info(f' saved data for "Reactant 2" to {output_dir}') + + # ... 
+
+
+class Sdf2SmilesExtractor:
+    """Helper class for data generation."""
+
+    def __init__(self) -> None:
+        self.smiles: Iterator[str]
+
+    def from_sdf(self, file: Union[str, Path]):
+        """Extract chemicals as SMILES from a `*.sdf` file.
+
+        See also:
+            https://www.rdkit.org/docs/GettingStartedInPython.html#reading-sets-of-molecules
+        """
+        file = str(Path(file).resolve())
+        suppl = Chem.SDMolSupplier(file)
+        self.smiles = (Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False) for mol in suppl)
+        logger.info(f"Read data from {file}")
+
+        return self
+
+    def _to_csv_gz(self, file: Path) -> None:
+        import gzip
+
+        with gzip.open(file, "wt") as f:
+            f.write("SMILES\n")
+            f.writelines((s + "\n" for s in self.smiles))
+
+    def _to_txt(self, file: Path) -> None:
+        with open(file, "wt") as f:
+            f.write("SMILES\n")
+            f.writelines((s + "\n" for s in self.smiles))
+
+    def to_file(self, file: Union[str, Path]) -> None:
+        if Path(file).suffixes == [".csv", ".gz"]:
+            self._to_csv_gz(file)
+        else:
+            self._to_txt(file)
+        logger.info(f"Saved data to {file}")
diff --git a/src/synnet/visualize/drawers.py b/src/synnet/visualize/drawers.py
new file mode 100644
index 00000000..8b6c0511
--- /dev/null
+++ b/src/synnet/visualize/drawers.py
@@ -0,0 +1,56 @@
+import uuid
+from pathlib import Path
+from typing import Optional, Union
+
+import rdkit.Chem as Chem
+from rdkit.Chem import Draw
+
+
+class MolDrawer:
+    """Draws molecules as images."""
+
+    def __init__(self, path: Optional[str], subfolder: str = "assets"):
+        # Init outfolder
+        if not (path is not None and Path(path).exists()):
+            raise NotADirectoryError(path)
+        self.outfolder = Path(path) / subfolder
+        self.outfolder.mkdir(exist_ok=True)
+
+        # Placeholder
+        self.lookup: dict[str, str] = None
+
+    def _hash(self, smiles: list[str]) -> "MolDrawer":
+        """Hashing for amateurs.
+
+        Goal: Get a short, valid, and hopefully unique filename for each molecule.
+        Note: A truncated UUID4 is not collision-proof, but collisions are
+        unlikely for the handful of molecules in a single tree.
+        """
+        self.lookup = {smile: str(uuid.uuid4())[:8] for smile in smiles}
+        return self
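+
+    # Typical flow (illustrative):
+    #   drawer = MolDrawer(path="figures")        # SVGs go to figures/assets/
+    #   drawer.plot(["CCO", "c1ccccc1"])          # one SVG per molecule
+    #   files = drawer.get_molecule_filesnames()  # {smiles: 8-char filename stem}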
+    def get_path(self) -> str:
+        return str(self.outfolder)
+
+    def get_molecule_filesnames(self):
+        return self.lookup
+
+    def plot(self, smiles: Union[list[str], str]):
+        """Plot SMILES as 2D molecules and save to `self.outfolder/*.svg`."""
+        self._hash(smiles)
+
+        for k, v in self.lookup.items():
+            fname = self.outfolder / f"{v}.svg"
+            mol = Chem.MolFromSmiles(k)
+            # Plot
+            drawer = Draw.rdMolDraw2D.MolDraw2DSVG(300, 150)
+            drawer.DrawMolecule(mol)
+            drawer.FinishDrawing()
+            p = drawer.GetDrawingText()
+
+            with open(fname, "w") as f:
+                f.write(p)
+
+        return self
+
+
+if __name__ == "__main__":
+    pass
diff --git a/src/synnet/visualize/visualizer.py b/src/synnet/visualize/visualizer.py
new file mode 100644
index 00000000..7d5e9593
--- /dev/null
+++ b/src/synnet/visualize/visualizer.py
@@ -0,0 +1,177 @@
+from pathlib import Path
+from typing import Union
+
+from synnet.utils.data_utils import NodeChemical, NodeRxn, SyntheticTree
+from synnet.visualize.drawers import MolDrawer
+from synnet.visualize.writers import subgraph
+
+
+class SynTreeVisualizer:
+    actions_taken: dict[int, str]
+    CHEMICALS: dict[str, NodeChemical]
+    outfolder: Union[str, Path]
+    version: int
+
+    ACTIONS = {
+        0: "Add",
+        1: "Expand",
+        2: "Merge",
+        3: "End",
+    }
+
+    def __init__(self, syntree: SyntheticTree, outfolder: str = "./syntree-viz/st"):
+        self.syntree = syntree
+        self.actions_taken = {
+            depth: self.ACTIONS[action] for depth, action in enumerate(syntree.actions)
+        }
+        self.CHEMICALS = {node.smiles: node for node in syntree.chemicals}
+
+        # Placeholder for images for molecules.
+        self.drawer: Union[MolDrawer, None]
+        self.molecule_filesnames: Union[None, dict[str, str]] = None
+
+        # Folders
+        outfolder = Path(outfolder)
+        self.version = self._get_next_version(outfolder)
+        self.path = outfolder.with_name(outfolder.name + f"_{self.version}")
+        return None
+
+    def _get_next_version(self, dir: str) -> int:
+        root_dir = Path(dir).parent
+        name = Path(dir).name
+
+        existing_versions = []
+        for d in Path(root_dir).glob(f"{name}_*"):
+            d = str(d.resolve())
+            existing_versions.append(int(d.split("_")[1]))
+
+        if len(existing_versions) == 0:
+            return 0
+
+        return max(existing_versions) + 1
+
+    def with_drawings(self, drawer: MolDrawer):
+        """Init `MolDrawer` to plot molecules in the nodes."""
+        self.path.mkdir(parents=True)
+        self.drawer = drawer(self.path)
+        return self
+
+    def plot(self):
+        """Plots molecules via `self.drawer.plot()`."""
+        if self.drawer is None:
+            raise ValueError("Must initialize drawer beforehand.")
+        self.drawer.plot(self.CHEMICALS)
+        self.molecule_filesnames = self.drawer.get_molecule_filesnames()
+        return self
+
+    def _define_chemicals(
+        self,
+        chemicals: dict[str, NodeChemical] = None,
+    ) -> list[str]:
+        chemicals = self.CHEMICALS if chemicals is None else chemicals
+
+        if self.drawer.outfolder is None or self.molecule_filesnames is None:
+            raise NotImplementedError("Must provide drawer via `with_drawings()` before plotting.")
+
+        out: list[str] = []
+
+        for node in chemicals.values():
+            name = f'"{node.smiles}"'
+            name = f'<img src="{self.drawer.outfolder.name}/{self.molecule_filesnames[node.smiles]}.svg" height=75px/>'
+            classdef = self._map_node_type_to_classdef(node)
+            info = f"n{node.index}[{name}]:::{classdef}"
+            out += [info]
+        return out
+
+    def _map_node_type_to_classdef(self, node: NodeChemical) -> str:
+        """Map a node to a pre-defined mermaid class for styling."""
+        if node.is_leaf:
+            classdef = "buildingblock"
+        elif node.is_root:
+            classdef = "final"
+        else:
+            classdef = "intermediate"
+        return classdef
+
+    def _write_reaction_connectivity(
+        self, reactants: list[NodeChemical], product: NodeChemical
+    ) -> list[str]:
+        """Write the connectivity of the graph.
+        Unimolecular reactions have one edge, bimolecular two.
+
+        Examples:
+            n1 --> n3
+            n2 --> n3
+        """
+        NODE_PREFIX = "n"
+        r1, r2 = reactants
+        out = [f"{NODE_PREFIX}{r1.index} --> {NODE_PREFIX}{product.index}"]
+        if r2 is not None:
+            out += [f"{NODE_PREFIX}{r2.index} --> {NODE_PREFIX}{product.index}"]
+        return out
+
+    def write(self) -> list[str]:
+        """Write markdown with mermaid block."""
+        # 1. Plot images
+        self.plot()
+        # 2. Write markdown (with references to the image files)
+        rxns: list[NodeRxn] = self.syntree.reactions
+        text = []
+
+        # Add node definitions
+        text.extend(self._define_chemicals(self.CHEMICALS))
+
+        # Add subgraphs (<=> actions taken)
+        for i, action in self.actions_taken.items():
+            if action == "End":
+                continue
+            rxn = rxns[i]
+            product: str = rxn.parent
+            reactant1: str = rxn.child[0]
+            reactant2: str = rxn.child[1] if rxn.rtype == 2 else None
+
+            @subgraph(f'"{i:>2d} : {action}"')
+            def __printer():
+                return self._write_reaction_connectivity(
+                    [self.CHEMICALS.get(reactant1), self.CHEMICALS.get(reactant2)],
+                    self.CHEMICALS.get(product),
+                )
+
+            out = __printer()
+            text.extend(out)
+        return text
+
+
+def demo():
+    """Demo syntree visualisation"""
+    # 1. Load syntree
+    import json
+
+    infile = "tests/assets/syntree-small.json"
+    with open(infile, "rt") as f:
+        data = json.load(f)
+
+    st = SyntheticTree()
+    st.read(data)
+
+    from synnet.visualize.drawers import MolDrawer
+    from synnet.visualize.visualizer import SynTreeVisualizer
+    from synnet.visualize.writers import SynTreeWriter
+
+    outpath = Path("./figures/syntrees/generation/st")
+    outpath.mkdir(parents=True, exist_ok=True)
+
+    # 2. Plot & write the mermaid markup diagram
+    stviz = SynTreeVisualizer(syntree=st, outfolder=outpath).with_drawings(drawer=MolDrawer)
+    mermaid_txt = stviz.write()
+    # 3. Write everything to a markdown doc
+    outfile = stviz.path / "syntree.md"
+    SynTreeWriter().write(mermaid_txt).to_file(outfile)
+    print("Generated markdown file.")
+    print("  Input file:", infile)
+    print("  Output file:", outfile)
+    return None
+
+
+if __name__ == "__main__":
+    demo()
diff --git a/src/synnet/visualize/writers.py b/src/synnet/visualize/writers.py
new file mode 100644
index 00000000..1832ef54
--- /dev/null
+++ b/src/synnet/visualize/writers.py
@@ -0,0 +1,105 @@
+from functools import wraps
+from typing import Callable
+
+
+class PrefixWriter:
+    def __init__(self, file: str = None):
+        self.prefix = self._default_prefix() if file is None else self._load(file)
+
+    def _default_prefix(self):
+        md = [
+            "# Synthetic Tree Visualisation",
+            "",
+            "Legend",
+            "- :green_square: Building Block",
+            "- :orange_square: Intermediate",
+            "- :blue_square: Final Molecule",
+            "- :red_square: Target Molecule",
+            "",
+        ]
+        start = ["```mermaid"]
+        theming = [
+            "%%{init: {",
+            " 'theme': 'base',",
+            " 'themeVariables': {",
+            " 'background': '#ffffff',",
+            " 'primaryColor': '#ffffff',",
+            " 'clusterBkg': '#ffffff',",
+            " 'clusterBorder': '#000000',",
+            " 'edgeLabelBackground':'#dbe1e1',",
+            " 'fontSize': '20px'",
+            " }",
+            " }",
+            "}%%",
+        ]
+        diagram_id = ["graph BT"]
+        style = [
+            "classDef buildingblock stroke:#00d26a,stroke-width:2px",
+            "classDef intermediate stroke:#ff6723,stroke-width:2px",
+            "classDef final stroke:#0074ba,stroke-width:2px",
+            "classDef target stroke:#f8312f,stroke-width:2px",
+        ]
+        return md + start + theming + diagram_id + style
+
+    def _load(self, file):
+        with open(file, "rt") as f:
+            out = [l.removesuffix("\n") for l in f]
+        return out
+
+    def write(self) -> list[str]:
+        return self.prefix
+
+
+class PostfixWriter:
+    def write(self) -> list[str]:
+        return ["```"]
+
+
+class SynTreeWriter:
+    def __init__(self, prefixer=PrefixWriter(), postfixer=PostfixWriter()):
+        self.prefixer = prefixer
+        self.postfixer = postfixer
+        self._text: list[str] = None
+
+    def write(self, out) -> list[str]:
+        out = self.prefixer.write() + out + self.postfixer.write()
+        self._text = out
+        return self
+
+    def to_file(self, file: str, text: list[str] = None):
+        if text is None:
+            text = self._text
+
+        with open(file, "wt") as f:
+            f.writelines((l.rstrip() + "\n" for l in text))
+        return None
+
+    @property
+    def text(self) -> list[str]:
+        return self._text
+
+
+def subgraph(argument: str = "") -> Callable:
+    """Decorator that wraps the output of a function in a named mermaid subgraph.
+
+    Example output:
+    ```
+    subgraph argument
+
+    end
+    ```
+    """
+
+    def _subgraph(func) -> Callable:
+        @wraps(func)
+        def wrapper(*args, **kwargs) -> list[str]:
+            out = f"subgraph {argument}"
+            inner = func(*args, **kwargs)
+            # add a tab to inner
+            TAB_CHAR = " " * 4
+            inner = [f"{TAB_CHAR}{line}" for line in inner]
+            return [out] + inner + ["end"]
+
+        return wrapper
+
+    return _subgraph
diff --git a/syn_net/data_generation/_mp_make.py b/syn_net/data_generation/_mp_make.py
deleted file mode 100644
index badd519c..00000000
--- a/syn_net/data_generation/_mp_make.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""
-This file contains a function to generate a single synthetic tree, prepared for
-multiprocessing.
-""" -import pandas as pd -import numpy as np -# import dill as pickle -# import gzip - -from syn_net.data_generation.make_dataset import synthetic_tree_generator -from syn_net.utils.data_utils import ReactionSet - - -path_reaction_file = '/pool001/whgao/data/synth_net/st_pis/reactions_pis.json.gz' -path_to_building_blocks = '/pool001/whgao/data/synth_net/st_pis/enamine_us_matched.csv.gz' - -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -r_set = ReactionSet() -r_set.load(path_reaction_file) -rxns = r_set.rxns -# with gzip.open(path_reaction_file, 'rb') as f: -# rxns = pickle.load(f) - -print('Finish reading the templates and building blocks list!') - -def func(_): - np.random.seed(_) - tree, action = synthetic_tree_generator(building_blocks, rxns, max_step=15) - return tree, action diff --git a/syn_net/data_generation/_mp_process.py b/syn_net/data_generation/_mp_process.py deleted file mode 100644 index 21947daf..00000000 --- a/syn_net/data_generation/_mp_process.py +++ /dev/null @@ -1,13 +0,0 @@ -""" -This file contains a function for search available building blocks -for a matching reaction template. Prepared for multiprocessing. -""" -import pandas as pd - -path_to_building_blocks = '/home/whgao/scGen/synth_net/data/enamine_us.csv.gz' -building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() -print('Finish reading the building blocks list!') - -def func(rxn_): - rxn_.set_available_reactants(building_blocks) - return rxn_ diff --git a/syn_net/data_generation/filter_unmatch.py b/syn_net/data_generation/filter_unmatch.py deleted file mode 100644 index a81db3a6..00000000 --- a/syn_net/data_generation/filter_unmatch.py +++ /dev/null @@ -1,25 +0,0 @@ -""" -Filters out purchasable building blocks which don't match a single template. -""" -from syn_net.utils.data_utils import * -import pandas as pd -from tqdm import tqdm - - -if __name__ == '__main__': - r_path = '/pool001/whgao/data/synth_net/st_pis/reactions_pis.json.gz' - bb_path = '/home/whgao/scGen/synth_net/data/enamine_us.csv.gz' - r_set = ReactionSet() - r_set.load(r_path) - matched_mols = set() - for r in tqdm(r_set.rxns): - for a_list in r.available_reactants: - matched_mols = matched_mols | set(a_list) - - original_mols = pd.read_csv(bb_path, compression='gzip')['SMILES'].tolist() - - print('Total building blocks number:', len(original_mols)) - print('Matched building blocks number:', len(matched_mols)) - - df = pd.DataFrame({'SMILES': list(matched_mols)}) - df.to_csv('/pool001/whgao/data/synth_net/st_pis/enamine_us_matched.csv.gz', compression='gzip') diff --git a/syn_net/data_generation/make_dataset.py b/syn_net/data_generation/make_dataset.py deleted file mode 100644 index 05dc3a74..00000000 --- a/syn_net/data_generation/make_dataset.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -This file generates synthetic tree data in a sequential fashion. 
-""" -import dill as pickle -import gzip -import pandas as pd -import numpy as np -from tqdm import tqdm -from syn_net.utils.data_utils import SyntheticTreeSet -from syn_net.utils.prep_utils import synthetic_tree_generator - - - -if __name__ == '__main__': - path_reaction_file = '/home/whgao/shared/Data/scGen/reactions_pis.pickle.gz' - path_to_building_blocks = '/home/whgao/shared/Data/scGen/enamine_building_blocks_nochiral_matched.csv.gz' - - np.random.seed(6) - - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - with gzip.open(path_reaction_file, 'rb') as f: - rxns = pickle.load(f) - - Trial = 5 - num_finish = 0 - num_error = 0 - num_unfinish = 0 - - trees = [] - for _ in tqdm(range(Trial)): - tree, action = synthetic_tree_generator(building_blocks, rxns, max_step=15) - if action == 3: - trees.append(tree) - num_finish += 1 - elif action == -1: - num_error += 1 - else: - num_unfinish += 1 - - print('Total trial: ', Trial) - print('num of finished trees: ', num_finish) - print('num of unfinished tree: ', num_unfinish) - print('num of error processes: ', num_error) - - synthetic_tree_set = SyntheticTreeSet(sts=trees) - synthetic_tree_set.save('st_data.json.gz') - - # data_file = gzip.open('st_data.pickle.gz', 'wb') - # pickle.dump(trees, data_file) - # data_file.close() diff --git a/syn_net/data_generation/make_dataset_mp.py b/syn_net/data_generation/make_dataset_mp.py deleted file mode 100644 index b3d596b4..00000000 --- a/syn_net/data_generation/make_dataset_mp.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -This file generates synthetic tree data in a multi-thread fashion. - -Usage: - python make_dataset_mp.py -""" -import numpy as np -import multiprocessing as mp -from time import time - -from syn_net.utils.data_utils import SyntheticTreeSet -import syn_net.data_generation._mp_make as make - - -if __name__ == '__main__': - - pool = mp.Pool(processes=100) - - NUM_TREES = 600000 - - t = time() - results = pool.map(make.func, np.arange(NUM_TREES).tolist()) - print('Time: ', time() - t, 's') - - trees = [r[0] for r in results if r[1] == 3] - actions = [r[1] for r in results] - - num_finish = actions.count(3) - num_error = actions.count(-1) - num_unfinish = NUM_TREES - num_finish - num_error - - print('Total trial: ', NUM_TREES) - print('num of finished trees: ', num_finish) - print('num of unfinished tree: ', num_unfinish) - print('num of error processes: ', num_error) - - tree_set = SyntheticTreeSet(trees) - tree_set.save('/pool001/whgao/data/synth_net/st_pis/st_data.json.gz') diff --git a/syn_net/data_generation/process_rxn_mp.py b/syn_net/data_generation/process_rxn_mp.py deleted file mode 100644 index a57faf5f..00000000 --- a/syn_net/data_generation/process_rxn_mp.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -This file processes a set of reaction templates and finds applicable -reactants from a list of purchasable building blocks. 
- -Usage: - python process_rxn_mp.py -""" -import multiprocessing as mp -from time import time - -from syn_net.utils.data_utils import Reaction, ReactionSet -import syn_net.data_generation._mp_process as process -import shutup -shutup.please() - - -if __name__ == '__main__': - name = 'pis' - path_to_rxn_templates = '/home/whgao/scGen/synth_net/data/rxn_set_' + name + '.txt' - rxn_templates = [] - for line in open(path_to_rxn_templates, 'rt'): - rxn = Reaction(line.split('|')[1].strip()) - rxn_templates.append(rxn) - - pool = mp.Pool(processes=64) - - t = time() - rxns = pool.map(process.func, rxn_templates) - print('Time: ', time() - t, 's') - - r = ReactionSet(rxns) - r.save('/pool001/whgao/data/synth_net/st_pis/reactions_' + name + '.json.gz') diff --git a/syn_net/models/act.py b/syn_net/models/act.py deleted file mode 100644 index b2cd6c9f..00000000 --- a/syn_net/models/act.py +++ /dev/null @@ -1,97 +0,0 @@ -""" -Action network. -""" -import time -import torch -import pytorch_lightning as pl -from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array -from scipy import sparse - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - args = parser.parse_args() - - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - batch_size = args.batch_size - ncpu = args.ncpu - - X = sparse.load_npz(main_dir + 'X_act_train.npz') - y = sparse.load_npz(main_dir + 'y_act_train.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir + 'X_act_valid.npz') - y = sparse.load_npz(main_dir + 'y_act_valid.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) - - pl.seed_everything(0) - if args.featurize == 'fp': - mlp = MLP(input_dim=int(3 * args.nbits), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - mlp = MLP(input_dim=int(2 * args.nbits + args.out_dim), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - 
loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - - tb_logger = pl_loggers.TensorBoardLogger(f'act_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_logs/') - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, valid_data_iter) - print(time.time() - t, 's') - - print('Finish!') diff --git a/syn_net/models/mlp.py b/syn_net/models/mlp.py deleted file mode 100644 index 4db3574f..00000000 --- a/syn_net/models/mlp.py +++ /dev/null @@ -1,172 +0,0 @@ -""" -Multi-layer perceptron (MLP) class. -""" -import time -import torch -from torch import nn -import torch.nn.functional as F -import pytorch_lightning as pl -from pytorch_lightning import loggers as pl_loggers -from sklearn.neighbors import BallTree -import numpy as np - - -class MLP(pl.LightningModule): - - def __init__(self, input_dim=3072, - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=16): - super().__init__() - - self.loss = loss - self.valid_loss = valid_loss - self.optimizer = optimizer - self.learning_rate = learning_rate - self.ncpu = ncpu - self.val_freq = val_freq - - modules = [] - modules.append(nn.Linear(input_dim, hidden_dim)) - modules.append(nn.BatchNorm1d(hidden_dim)) - modules.append(nn.ReLU()) - - for i in range(num_layers-2): - modules.append(nn.Linear(hidden_dim, hidden_dim)) - modules.append(nn.BatchNorm1d(hidden_dim)) - modules.append(nn.ReLU()) - if i > num_layers - 3 - num_dropout_layers: - modules.append(nn.Dropout(dropout)) - - modules.append(nn.Linear(hidden_dim, output_dim)) - if task == 'classification': - modules.append(nn.Softmax()) - - self.layers = nn.Sequential(*modules) - - def forward(self, x): - return self.layers(x) - - def training_step(self, batch, batch_idx): - x, y = batch - y_hat = self.layers(x) - if self.loss == 'cross_entropy': - loss = F.cross_entropy(y_hat, y) - elif self.loss == 'mse': - loss = F.mse_loss(y_hat, y) - elif self.loss == 'l1': - loss = F.l1_loss(y_hat, y) - elif self.loss == 'huber': - loss = F.huber_loss(y_hat, y) - else: - raise ValueError('Not specified loss function') - self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) - return loss - - def validation_step(self, batch, batch_idx): - if self.trainer.current_epoch % self.val_freq == 0: - out_feat = self.valid_loss[12:] - if out_feat == 'gin': - bb_emb_gin = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_gin.npy') - kdtree = BallTree(bb_emb_gin, metric='euclidean') - elif out_feat == 'fp_4096': - bb_emb_fp_4096 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_4096.npy') - kdtree = BallTree(bb_emb_fp_4096, metric='euclidean') - elif out_feat == 'fp_256': - bb_emb_fp_256 = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_fp_256.npy') - kdtree = BallTree(bb_emb_fp_256, metric=cosine_distance) - elif out_feat == 'rdkit2d': - bb_emb_rdkit2d = np.load('/pool001/whgao/data/synth_net/st_hb/enamine_us_emb_rdkit2d.npy') - kdtree = BallTree(bb_emb_rdkit2d, metric='euclidean') - x, y = batch - y_hat = self.layers(x) - if self.valid_loss == 'cross_entropy': - loss = F.cross_entropy(y_hat, y) - elif self.valid_loss == 'accuracy': - y_hat = torch.argmax(y_hat, axis=1) - loss = 1 - (sum(y_hat == y) / len(y)) - 
elif self.valid_loss[:11] == 'nn_accuracy': - y = nn_search_list(y.detach().cpu().numpy(), out_feat=out_feat, kdtree=kdtree) - y_hat = nn_search_list(y_hat.detach().cpu().numpy(), out_feat=out_feat, kdtree=kdtree) - loss = 1 - (sum(y_hat == y) / len(y)) - # import ipdb; ipdb.set_trace(context=11) - elif self.valid_loss == 'mse': - loss = F.mse_loss(y_hat, y) - elif self.valid_loss == 'l1': - loss = F.l1_loss(y_hat, y) - elif self.valid_loss == 'huber': - loss = F.huber_loss(y_hat, y) - else: - raise ValueError('Not specified validation loss function') - self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) - else: - pass - - def configure_optimizers(self): - if self.optimizer == 'adam': - optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) - elif self.optimizer == 'sgd': - optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate) - return optimizer - -def load_array(data_arrays, batch_size, is_train=True, ncpu=-1): - dataset = torch.utils.data.TensorDataset(*data_arrays) - return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train, num_workers=ncpu) - -def cosine_distance(v1, v2, eps=1e-15): - return 1 - np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps) - -def nn_search(_e, _tree, _k=1): - dist, ind = _tree.query(_e, k=_k) - return ind[0][0] - -def nn_search_list(y, out_feat, kdtree): - if out_feat == 'gin': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - elif out_feat == 'fp_4096': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - elif out_feat == 'fp_256': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - elif out_feat == 'rdkit2d': - return np.array([nn_search(emb.reshape(1, -1), _tree=kdtree) for emb in y]) - else: - raise ValueError - - -if __name__ == '__main__': - - states_list = [] - steps_list = [] - for i in range(1): - states_list.append(np.load('/home/rociomer/data/synth_net/pis_fp/states_' + str(i) + '_valid.npz', allow_pickle=True)) - steps_list.append(np.load('/home/rociomer/data/synth_net/pis_fp/steps_' + str(i) + '_valid.npz', allow_pickle=True)) - - states = np.concatenate(states_list, axis=0) - steps = np.concatenate(steps_list, axis=0) - - X = states - y = steps[:, 0] - - X_train = torch.Tensor(X) - y_train = torch.LongTensor(y) - - batch_size = 64 - train_data_iter = load_array((X_train, y_train), batch_size, is_train=True) - - pl.seed_everything(0) - mlp = MLP() - tb_logger = pl_loggers.TensorBoardLogger('temp_logs/') - - trainer = pl.Trainer(gpus=[0], max_epochs=30, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, train_data_iter) - print(time.time() - t, 's') diff --git a/syn_net/models/prepare_data.py b/syn_net/models/prepare_data.py deleted file mode 100644 index 04400c2d..00000000 --- a/syn_net/models/prepare_data.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Prepares the training, testing, and validation data by reading in the states -and steps for the reaction data and re-writing it as separate one-hot encoded -Action, Reactant 1, Reactant 2, and Reaction files. 
-""" -from syn_net.utils.prep_utils import prep_data - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--outputembedding", type=str, default='gin', - help="Choose from ['fp_4096', 'fp_256', 'gin', 'rdkit2d']") - args = parser.parse_args() - rxn_template = args.rxn_template - featurize = args.featurize - output_emb = args.outputembedding - - main_dir = '/pool001/whgao/data/synth_net/' + rxn_template + '_' + featurize + '_' + str(args.radius) + '_' + str(args.nbits) + '_' + str(args.outputembedding) + '/' - if rxn_template == 'hb': - num_rxn = 91 - elif rxn_template == 'pis': - num_rxn = 4700 - - if output_emb == 'gin': - out_dim = 300 - elif output_emb == 'rdkit2d': - out_dim = 200 - elif output_emb == 'fp_4096': - out_dim = 4096 - elif output_emb == 'fp_256': - out_dim = 256 - - prep_data(main_dir=main_dir, out_dim=out_dim) - - - print('Finish!') diff --git a/syn_net/models/rt1.py b/syn_net/models/rt1.py deleted file mode 100644 index 1fe8026c..00000000 --- a/syn_net/models/rt1.py +++ /dev/null @@ -1,100 +0,0 @@ -""" -Reactant1 network (for predicting 1st reactant). -""" -import time -import numpy as np -import torch -import pytorch_lightning as pl -from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array -from scipy import sparse - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - args = parser.parse_args() - - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - batch_size = args.batch_size - ncpu = args.ncpu - - X = sparse.load_npz(main_dir + 'X_rt1_train.npz') - y = sparse.load_npz(main_dir + 'y_rt1_train.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir + 'X_rt1_valid.npz') - y = sparse.load_npz(main_dir + 'y_rt1_valid.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - _idx = 
np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - valid_data_iter = load_array((X[_idx], y[_idx]), batch_size, ncpu=ncpu, is_train=False) - - pl.seed_everything(0) - if args.featurize == 'fp': - mlp = MLP(input_dim=int(3 * args.nbits), - output_dim=args.out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - mlp = MLP(input_dim=int(2 * args.nbits + args.out_dim), - output_dim=args.out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - tb_logger = pl_loggers.TensorBoardLogger( - f'rt1_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}_logs/' - ) - - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, valid_data_iter) - print(time.time() - t, 's') - print('Finish!') diff --git a/syn_net/models/rt2.py b/syn_net/models/rt2.py deleted file mode 100644 index 40ca5237..00000000 --- a/syn_net/models/rt2.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Reactant2 network (for predicting 2nd reactant). -""" -import time -import numpy as np -import torch -import pytorch_lightning as pl -from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array -from scipy import sparse - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=256, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epoches.") - args = parser.parse_args() - - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - batch_size = args.batch_size - ncpu = args.ncpu - - - X = sparse.load_npz(main_dir + 'X_rt2_train.npz') - y = sparse.load_npz(main_dir + 'y_rt2_train.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir + 'X_rt2_valid.npz') - y = sparse.load_npz(main_dir + 'y_rt2_valid.npz') - X = torch.Tensor(X.A) - y = torch.Tensor(y.A) - _idx = np.random.choice(list(range(X.shape[0])), size=int(X.shape[0]/10), replace=False) - valid_data_iter = load_array((X[_idx], y[_idx]), 
batch_size, ncpu=ncpu, is_train=False) - - pl.seed_everything(0) - if args.featurize == 'fp': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(4 * args.nbits + 91), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(4 * args.nbits + 4700), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim + 91), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim + 4700), - output_dim=args.out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss=validation_option, - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - - tb_logger = pl_loggers.TensorBoardLogger( - f'rt2_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}_logs/' - ) - - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - trainer.fit(mlp, train_data_iter, valid_data_iter) - print(time.time() - t, 's') - - print('Finish!') diff --git a/syn_net/models/rxn.py b/syn_net/models/rxn.py deleted file mode 100644 index 69f7ce87..00000000 --- a/syn_net/models/rxn.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Reaction network. 
-""" -import time -import torch -import pytorch_lightning as pl -from pytorch_lightning import loggers as pl_loggers -from syn_net.models.mlp import MLP, load_array -from scipy import sparse - - -if __name__ == '__main__': - - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("-f", "--featurize", type=str, default='fp', - help="Choose from ['fp', 'gin']") - parser.add_argument("-r", "--rxn_template", type=str, default='hb', - help="Choose from ['hb', 'pis']") - parser.add_argument("--radius", type=int, default=2, - help="Radius for Morgan fingerprint.") - parser.add_argument("--nbits", type=int, default=4096, - help="Number of Bits for Morgan fingerprint.") - parser.add_argument("--out_dim", type=int, default=300, - help="Output dimension.") - parser.add_argument("--ncpu", type=int, default=8, - help="Number of cpus") - parser.add_argument("--batch_size", type=int, default=64, - help="Batch size") - parser.add_argument("--epoch", type=int, default=2000, - help="Maximum number of epochs.") - parser.add_argument("--restart", type=bool, default=False, - help="Indicates whether to restart training.") - parser.add_argument("-v", "--version", type=int, default=1, - help="Version") - args = parser.parse_args() - - if args.out_dim == 300: - validation_option = 'nn_accuracy_gin' - elif args.out_dim == 4096: - validation_option = 'nn_accuracy_fp_4096' - elif args.out_dim == 256: - validation_option = 'nn_accuracy_fp_256' - elif args.out_dim == 200: - validation_option = 'nn_accuracy_rdkit2d' - else: - raise ValueError - - main_dir = f'/pool001/whgao/data/synth_net/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_{validation_option[12:]}/' - batch_size = args.batch_size - ncpu = args.ncpu - - X = sparse.load_npz(main_dir + 'X_rxn_train.npz') - y = sparse.load_npz(main_dir + 'y_rxn_train.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True) - - X = sparse.load_npz(main_dir + 'X_rxn_valid.npz') - y = sparse.load_npz(main_dir + 'y_rxn_valid.npz') - X = torch.Tensor(X.A) - y = torch.LongTensor(y.A.reshape(-1, )) - valid_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=False) - - pl.seed_everything(0) - param_path = f'/pool001/rociomer/data/pre-trained-models/{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_v{args.version}/' - path_to_rxn = f'{param_path}rxn.ckpt' - if not args.restart: - - if args.featurize == 'fp': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(4 * args.nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(4 * args.nbits), - output_dim=4700, - hidden_dim=4500, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.featurize == 'gin': - if args.rxn_template == 'hb': - mlp = MLP(input_dim=int(3 * args.nbits + args.out_dim), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - elif args.rxn_template == 'pis': - mlp = MLP(input_dim=int(3 * 
args.nbits + args.out_dim), - output_dim=4700, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - val_freq=10, - ncpu=ncpu) - else: - if args.rxn_template == 'hb': - mlp = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(4 * args.nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu - ) - elif args.rxn_template == 'pis': - mlp = MLP.load_from_checkpoint( - path_to_rxn, - input_dim=int(4 * args.nbits), - output_dim=4700, - hidden_dim=4500, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu - ) - - tb_logger = pl_loggers.TensorBoardLogger(f'rxn_{args.rxn_template}_{args.featurize}_{args.radius}_{args.nbits}_logs/') - trainer = pl.Trainer(gpus=[0], max_epochs=args.epoch, progress_bar_refresh_rate=20, logger=tb_logger) - t = time.time() - - trainer.fit(mlp, train_data_iter, valid_data_iter) - - print(time.time() - t, 's') - - print('Finish!') diff --git a/syn_net/utils/data_utils.py b/syn_net/utils/data_utils.py deleted file mode 100644 index 7d055c34..00000000 --- a/syn_net/utils/data_utils.py +++ /dev/null @@ -1,824 +0,0 @@ -""" -Here we define the following classes for working with synthetic tree data: -* `Reaction` -* `ReactionSet` -* `NodeChemical` -* `NodeRxn` -* `SyntheticTree` -* `SyntheticTreeSet` -""" -import gzip -import json -import pandas as pd -from tqdm import tqdm -import rdkit.Chem as Chem -from rdkit.Chem import Draw -from rdkit.Chem import AllChem -from rdkit.Chem import rdChemReactions - -# the definition of reaction classes below -class Reaction: - """ - This class models a chemical reaction based on a SMARTS transformation. - - Args: - template (str): SMARTS string representing a chemical reaction. - rxnname (str): The name of the reaction for downstream analysis. - smiles: (str): A reaction SMILES string that macthes the SMARTS pattern. - reference (str): Reference information for the reaction. 
- """ - def __init__(self, template=None, rxnname=None, smiles=None, reference=None): - - if template is not None: - # define a few attributes based on the input - self.smirks = template - self.rxnname = rxnname - self.smiles = smiles - self.reference = reference - - # compute a few additional attributes - rxn = AllChem.ReactionFromSmarts(self.smirks) - rdChemReactions.ChemicalReaction.Initialize(rxn) - self.num_reactant = rxn.GetNumReactantTemplates() - if self.num_reactant == 0 or self.num_reactant > 2: - raise ValueError('This reaction is neither uni- nor bi-molecular.') - self.num_agent = rxn.GetNumAgentTemplates() - self.num_product = rxn.GetNumProductTemplates() - if self.num_reactant == 1: - self.reactant_template = list((self.smirks.split('>')[0], )) - else: - self.reactant_template = list((self.smirks.split('>')[0].split('.')[0], self.smirks.split('>')[0].split('.')[1])) - self.product_template = self.smirks.split('>')[2] - self.agent_template = self.smirks.split('>')[1] - - del rxn - else: - self.smirks = None - - def load(self, smirks, num_reactant, num_agent, num_product, reactant_template, - product_template, agent_template, available_reactants, rxnname, smiles, reference): - """ - This function loads a set of elements and reconstructs a `Reaction` object. - """ - self.smirks = smirks - self.num_reactant = num_reactant - self.num_agent = num_agent - self.num_product = num_product - self.reactant_template = list(reactant_template) - self.product_template = product_template - self.agent_template = agent_template - self.available_reactants = list(available_reactants) - self.rxnname = rxnname - self.smiles = smiles - self.reference = reference - - def get_mol(self, smi): - """ - A internal function that returns an `RDKit.Chem.Mol` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - RDKit.Chem.Mol - """ - if isinstance(smi, str): - return Chem.MolFromSmiles(smi) - elif isinstance(smi, Chem.Mol): - return smi - else: - raise TypeError('The input should be either a SMILES string or an ' - 'RDKit.Chem.Mol object.') - - def visualize(self, name='./reaction1_highlight.o.png'): - """ - A function that plots the chemical translation into a PNG figure. - One can use "from IPython.display import Image ; Image(name)" to see it - in a Python notebook. - - Args: - name (str): The path to the figure. - - Returns: - name (str): The path to the figure. - """ - rxn = AllChem.ReactionFromSmarts(self.smirks) - d2d = Draw.MolDraw2DCairo(800,300) - d2d.DrawReaction(rxn, highlightByReactant=True) - png = d2d.GetDrawingText() - open(name,'wb+').write(png) - del rxn - return name - - def is_reactant(self, smi): - """ - A function that checks if a molecule is a reactant of the reaction - defined by the `Reaction` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is a reactant of the reaction. - """ - rxn = self.get_rxnobj() - smi = self.get_mol(smi) - result = rxn.IsMoleculeReactant(smi) - del rxn - return result - - def is_agent(self, smi): - """ - A function that checks if a molecule is an agent in the reaction defined - by the `Reaction` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is an agent in the reaction. 
- """ - rxn = self.get_rxnobj() - smi = self.get_mol(smi) - result = rxn.IsMoleculeAgent(smi) - del rxn - return result - - def is_product(self, smi): - """ - A function that checks if a molecule is the product in the reaction defined - by the `Reaction` object. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is the product in the reaction. - """ - rxn = self.get_rxnobj() - smi = self.get_mol(smi) - result = rxn.IsMoleculeProduct(smi) - del rxn - return result - - def is_reactant_first(self, smi): - """ - A function that checks if a molecule is the first reactant in the reaction - defined by the `Reaction` object, where the order of the reactants is - determined by the SMARTS pattern. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is the first reactant in - the reaction. - """ - smi = self.get_mol(smi) - if smi.HasSubstructMatch(Chem.MolFromSmarts(self.get_reactant_template(0))): - return True - else: - return False - - def is_reactant_second(self, smi): - """ - A function that checks if a molecule is the second reactant in the reaction - defined by the `Reaction` object, where the order of the reactants is - determined by the SMARTS pattern. - - Args: - smi (str or RDKit.Chem.Mol): The query molecule, as either a SMILES - string or an `RDKit.Chem.Mol` object. - - Returns: - result (bool): Indicates if the molecule is the second reactant in - the reaction. - """ - smi = self.get_mol(smi) - if smi.HasSubstructMatch(Chem.MolFromSmarts(self.get_reactant_template(1))): - return True - else: - return False - - def get_smirks(self): - """ - A function that returns the SMARTS pattern which represents the reaction. - - Returns: - self.smirks (str): SMARTS pattern representing the reaction. - """ - return self.smirks - - def get_rxnobj(self): - """ - A function that returns the RDKit Reaction object. - - Returns: - rxn (rdChem.Reactions.ChemicalReaction): RDKit reaction object. - """ - rxn = AllChem.ReactionFromSmarts(self.smirks) - rdChemReactions.ChemicalReaction.Initialize(rxn) - return rxn - - def get_reactant_template(self, ind=0): - """ - A function that returns the SMARTS pattern which represents the specified - reactant. - - Args: - ind (int): The index of the reactant. Defaults to 0. - - Returns: - reactant_template (str): SMARTS pattern representing the reactant. - """ - return self.reactant_template[ind] - - def get_product_template(self): - """ - A function that returns the SMARTS pattern which represents the product. - - Returns: - product_template (str): SMARTS pattern representing the product. - """ - return self.product_template - - def run_reaction(self, reactants, keep_main=True): - """ - A function that transform the reactants into the corresponding product. - - Args: - reactants (list): Contains SMILES strings for the reactants. - keep_main (bool): Indicates whether to return only the main product, - or all possible products. Defaults to True. - - Returns: - uniqps (str): SMILES string representing the product. 
- """ - rxn = self.get_rxnobj() - - if self.num_reactant == 1: - - if isinstance(reactants, (tuple, list)): - if len(reactants) == 1: - r = self.get_mol(reactants[0]) - elif len(reactants) == 2 and reactants[1] is None: - r = self.get_mol(reactants[0]) - else: - return None - - elif isinstance(reactants, (str, Chem.Mol)): - r = self.get_mol(reactants) - else: - raise TypeError('The input of a uni-molecular reaction should ' - 'be a SMILES, an rdkit.Chem.Mol object, or a ' - 'tuple/list of length 1 or 2.') - - if not self.is_reactant(r): - return None - - ps = rxn.RunReactants((r, )) - - elif self.num_reactant == 2: - if isinstance(reactants, (tuple, list)) and len(reactants) == 2: - r1 = self.get_mol(reactants[0]) - r2 = self.get_mol(reactants[1]) - else: - raise TypeError('The input of a bi-molecular reaction should ' - 'be a tuple/list of length 2.') - - if self.is_reactant_first(r1) and self.is_reactant_second(r2): - pass - elif self.is_reactant_first(r2) and self.is_reactant_second(r1): - r1, r2 = (r2, r1) - else: - return None - - ps = rxn.RunReactants((r1, r2)) - - else: - raise ValueError('This reaction is neither uni- nor bi-molecular.') - - uniqps = [] - for p in ps: - smi = Chem.MolToSmiles(p[0]) - uniqps.append(smi) - - uniqps = list(set(uniqps)) - - assert len(uniqps) >= 1 - - del rxn - - if keep_main: - return uniqps[0] - else: - return uniqps - - def _filter_reactants(self, smi_list): - """ - Filters reactants which do not match the reaction. - - Args: - smi_list (list): Contains SMILES to search through for matches. - - Raises: - ValueError: Raised if the `Reaction` object does not describe a uni- - or bi-molecular reaction. - - Returns: - tuple: Contains list(s) of SMILES which match either the first - reactant, or, if applicable, the second reactant. - """ - if self.num_reactant == 1: # uni-molecular reaction - smi_w_patt = [] - for smi in tqdm(smi_list): - if self.is_reactant_first(smi): - smi_w_patt.append(smi) - return (smi_w_patt, ) - - elif self.num_reactant == 2: # bi-molecular reaction - smi_w_patt1 = [] - smi_w_patt2 = [] - for smi in tqdm(smi_list): - if self.is_reactant_first(smi): - smi_w_patt1.append(smi) - if self.is_reactant_second(smi): - smi_w_patt2.append(smi) - return (smi_w_patt1, smi_w_patt2) - else: - raise ValueError('This reaction is neither uni- nor bi-molecular.') - - def set_available_reactants(self, building_block_list): - """ - A function that finds the applicable building blocks from a list of - purchasable building blocks. - - Args: - building_block_list (list): The list of purchasable building blocks, - where building blocks are represented as SMILES strings. - """ - self.available_reactants = list(self._filter_reactants(building_block_list)) - return None - - -class ReactionSet: - """ - A class representing a set of reactions, for saving and loading purposes. - - Arritbutes: - rxns (list or None): Contains `Reaction` objects. One can initialize the - class with a list or None object, the latter of which is used to - define an empty list. - """ - def __init__(self, rxns=None): - if rxns is None: - self.rxns = [] - else: - self.rxns = rxns - - def load(self, json_file): - """ - A function that loads reactions from a JSON-formatted file. - - Args: - json_file (str): The path to the stored reaction file. 
- """ - - with gzip.open(json_file, 'r') as f: - data = json.loads(f.read().decode('utf-8')) - - for r_dict in data['reactions']: - r = Reaction() - r.load(**r_dict) - self.rxns.append(r) - - def save(self, json_file): - """ - A function that saves the reaction set to a JSON-formatted file. - - Args: - json_file (str): The path to the stored reaction file. - """ - r_list = {'reactions': [r.__dict__ for r in self.rxns]} - with gzip.open(json_file, 'w') as f: - f.write(json.dumps(r_list).encode('utf-8')) - - def __len__(self): - return len(self.rxns) - - def _print(self, x=3): - # For debugging - for i, r in enumerate(self.rxns): - if i >= x: - break - print(r.__dict__) - - -# the definition of classes for defining synthetic trees below -class NodeChemical: - """ - A class representing a chemical node in a synthetic tree. - - Args: - smiles (None or str): SMILES string representing molecule. - parent (None or int): - child (None or int): Indicates reaction which molecule participates in. - is_leaf (bool): Indicates if this is a leaf node. - is_root (bool): Indicates if this is a root node. - depth (float): - index (int): Indicates the order of this chemical node in the tree. - """ - def __init__(self, smiles=None, parent=None, child=None, is_leaf=False, - is_root=False, depth=0, index=0): - self.smiles = smiles - self.parent = parent - self.child = child - self.is_leaf = is_leaf - self.is_root = is_root - self.depth = depth - self.index = index - - -class NodeRxn: - """ - A class representing a reaction node in a synthetic tree. - - Args: - rxn_id (None or int): Index corresponding to reaction in a one-hot vector - of reaction templates. - rtype (None or int): Indicates if uni- (1) or bi-molecular (2) reaction. - parent (None or list): - child (None or list): Contains SMILES strings of reactants which lead to - the specified reaction. - depth (float): - index (int): Indicates the order of this reaction node in the tree. - """ - def __init__(self, rxn_id=None, rtype=None, parent=[], - child=None, depth=0, index=0): - self.rxn_id = rxn_id - self.rtype = rtype - self.parent = parent - self.child = child - self.depth = depth - self.index = index - - -class SyntheticTree: - """ - A class representing a synthetic tree. - - Args: - chemicals (list): A list of chemical nodes, in order of addition. - reactions (list): A list of reaction nodes, in order of addition. - actions (list): A list of actions, in order of addition. - root (NodeChemical): The root node. - depth (int): The depth of the tree. - rxn_id2type (dict): A dictionary that maps reaction indices to reaction - type (uni- or bi-molecular). - """ - def __init__(self, tree=None): - self.chemicals = [] - self.reactions = [] - self.root = None - self.depth = 0 - self.actions = [] - self.rxn_id2type = None - - if tree is not None: - self.read(tree) - - def read(self, data): - """ - A function that loads a dictionary from synthetic tree data. - - Args: - data (dict): A dictionary representing a synthetic tree. - """ - self.root = NodeChemical(**data['root']) - self.depth = data['depth'] - self.actions = data['actions'] - self.rxn_id2type = data['rxn_id2type'] - - for r_dict in data['reactions']: - r = NodeRxn(**r_dict) - self.reactions.append(r) - - for m_dict in data['chemicals']: - r = NodeChemical(**m_dict) - self.chemicals.append(r) - - def output_dict(self): - """ - A function that exports dictionary-formatted synthetic tree data. - - Returns: - data (dict): A dictionary representing a synthetic tree. 
- """ - return {'reactions': [r.__dict__ for r in self.reactions], - 'chemicals': [m.__dict__ for m in self.chemicals], - 'root': self.root.__dict__, - 'depth': self.depth, - 'actions': self.actions, - 'rxn_id2type': self.rxn_id2type} - - def _print(self): - """ - A function that prints the contents of the synthetic tree. - """ - print('===============Stored Molecules===============') - for node in self.chemicals: - print(node.smiles, node.is_root) - print('===============Stored Reactions===============') - for node in self.reactions: - print(node.rxn_id, node.rtype) - print('===============Followed Actions===============') - print(self.actions) - - def get_node_index(self, smi): - """ - Returns the index of the node matching the input SMILES. - - Args: - smi (str): A SMILES string that represents the query molecule. - - Returns: - index (int): Index of chemical node corresponding to the query - molecule. If the query moleucle is not in the tree, return None. - """ - for node in self.chemicals: - if smi == node.smiles: - return node.index - return None - - def get_state(self): - """ - Returns the state of the synthetic tree. The most recent root node has 0 - as its index. - - Returns: - state (list): A list contains all root node molecules. - """ - state = [] - for mol in self.chemicals: - if mol.is_root: - state.append(mol.smiles) - return state[::-1] - - def update(self, action, rxn_id, mol1, mol2, mol_product): - """ - A function that updates a synthetic tree by adding a reaction step. - - Args: - action (int): Action index, where the indices (0, 1, 2, 3) represent - (Add, Expand, Merge, and End), respectively. - rxn_id (int): Index of the reaction occured, where the index can be - anything in the range [0, len(template_list)-1]. - mol1 (str): SMILES string representing the first reactant. - mol2 (str): SMILES string representing the second reactant. - mol_product (str): SMILES string representing the product. 
- """ - self.actions.append(int(action)) - - if action == 3: - # End - self.root = self.chemicals[-1] - self.depth = self.root.depth - - elif action == 2: - # Merge with bi-mol rxn - node_mol1 = self.chemicals[self.get_node_index(mol1)] - node_mol2 = self.chemicals[self.get_node_index(mol2)] - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=2, - parent=None, - child=[node_mol1.smiles, node_mol2.smiles], - depth=max(node_mol1.depth, node_mol2.depth)+0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=node_rxn.depth+0.5, - index=len(self.chemicals)) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - node_mol2.parent = node_rxn.rxn_id - node_mol1.is_root = False - node_mol2.is_root = False - - self.chemicals.append(node_product) - self.reactions.append(node_rxn) - - elif action == 1 and mol2 is None: - # Expand with uni-mol rxn - node_mol1 = self.chemicals[self.get_node_index(mol1)] - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=1, - parent=None, - child=[node_mol1.smiles], - depth=node_mol1.depth+0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=node_rxn.depth+0.5, - index=len(self.chemicals)) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - node_mol1.is_root = False - - self.chemicals.append(node_product) - self.reactions.append(node_rxn) - - elif action == 1 and mol2 is not None: - # Expand with bi-mol rxn - node_mol1 = self.chemicals[self.get_node_index(mol1)] - node_mol2 = NodeChemical(smiles=mol2, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)) - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=2, - parent=None, - child=[node_mol1.smiles, - node_mol2.smiles], - depth=max(node_mol1.depth, node_mol2.depth)+0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=node_rxn.depth+0.5, - index=len(self.chemicals)+1) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - node_mol2.parent = node_rxn.rxn_id - node_mol1.is_root = False - - self.chemicals.append(node_mol2) - self.chemicals.append(node_product) - self.reactions.append(node_rxn) - - elif action == 0 and mol2 is None: - # Add with uni-mol rxn - node_mol1 = NodeChemical(smiles=mol1, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)) - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=1, - parent=None, - child=[node_mol1.smiles], - depth=0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=1, - index=len(self.chemicals)+1) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - - self.chemicals.append(node_mol1) - self.chemicals.append(node_product) - self.reactions.append(node_rxn) - - elif action == 0 and mol2 is not None: - # Add with bi-mol rxn - node_mol1 = NodeChemical(smiles=mol1, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)) - node_mol2 = NodeChemical(smiles=mol2, - parent=None, - child=None, - is_leaf=True, - is_root=False, - depth=0, - index=len(self.chemicals)+1) - node_rxn = NodeRxn(rxn_id=rxn_id, - rtype=2, - parent=None, - 
                                 child=[node_mol1.smiles, node_mol2.smiles], - depth=0.5, - index=len(self.reactions)) - node_product = NodeChemical(smiles=mol_product, - parent=None, - child=node_rxn.rxn_id, - is_leaf=False, - is_root=True, - depth=1, - index=len(self.chemicals)+2) - - node_rxn.parent = node_product.smiles - node_mol1.parent = node_rxn.rxn_id - node_mol2.parent = node_rxn.rxn_id - - self.chemicals.append(node_mol1) - self.chemicals.append(node_mol2) - self.chemicals.append(node_product) - self.reactions.append(node_rxn) - - else: - raise ValueError('Check input') - - return None - - -class SyntheticTreeSet: - """ - A class representing a set of synthetic trees, for saving and loading purposes. - - Attributes: - sts (list): Contains `SyntheticTree`s. One can initialize the class with - either a list of synthetic trees or None, in which case an empty - list is created. - """ - def __init__(self, sts=None): - if sts is None: - self.sts = [] - else: - self.sts = sts - - def load(self, json_file): - """ - A function that loads a JSON-formatted synthetic tree file. - - Args: - json_file (str): The path to the stored synthetic tree file. - """ - with gzip.open(json_file, 'r') as f: - data = json.loads(f.read().decode('utf-8')) - - for st_dict in data['trees']: - if st_dict is None: - self.sts.append(None) - else: - st = SyntheticTree(st_dict) - self.sts.append(st) - - def save(self, json_file): - """ - A function that saves the synthetic tree set to a JSON-formatted file. - - Args: - json_file (str): The path to the stored synthetic tree file. - """ - st_list = { - 'trees': [st.output_dict() if st is not None else None for st in self.sts] - } - with gzip.open(json_file, 'w') as f: - f.write(json.dumps(st_list).encode('utf-8')) - - def __len__(self): - return len(self.sts) - - def _print(self, x=3): - # For debugging - for i, r in enumerate(self.sts): - if i >= x: - break - print(r.output_dict()) - - -if __name__ == '__main__': - """ - A test run to find available reactants for a set of reaction templates. - """ - path_to_building_blocks = '/home/whgao/shared/Data/scGen/enamine_5k.csv.gz' - # path_to_rxn_templates = '/home/whgao/shared/Data/scGen/rxn_set_hartenfeller.txt' - path_to_rxn_templates = '/home/whgao/shared/Data/scGen/rxn_set_pis_test.txt' - - building_blocks = pd.read_csv(path_to_building_blocks, compression='gzip')['SMILES'].tolist() - rxns = [] - for line in open(path_to_rxn_templates, 'rt'): - rxn = Reaction(line.split('|')[1].strip()) - rxn.set_available_reactants(building_blocks) - rxns.append(rxn) - - r = ReactionSet(rxns) - r.save('reactions_pis_test.json.gz') diff --git a/syn_net/utils/predict_beam_utils.py b/syn_net/utils/predict_beam_utils.py deleted file mode 100644 index 2f24126e..00000000 --- a/syn_net/utils/predict_beam_utils.py +++ /dev/null @@ -1,468 +0,0 @@ -""" -This file contains various utils for decoding synthetic trees using beam search. -""" -import numpy as np -from rdkit import Chem -from syn_net.utils.data_utils import SyntheticTree -from sklearn.neighbors import BallTree, KDTree -from syn_net.utils.predict_utils import * - - -np.random.seed(6) - - -def softmax(x): - """ - Computes softmax values for each set of scores in x. - - Args: - x (np.ndarray or list): Values to normalize. - Returns: - (np.ndarray): Softmaxed values. - """ - e_x = np.exp(x - np.max(x)) - return e_x / e_x.sum(axis=0) - -def nn_search(_e, _tree, _k=1): - """ - Conducts a nearest neighbor search to find the molecule from the tree most - similar to the input embedding.
- - Args: - _e (np.ndarray): A specific point in the dataset. - _tree (sklearn.neighbors._kd_tree.KDTree, optional): A k-d tree. - _k (int, optional): Indicates how many nearest neighbors to get. - Defaults to 1. - - Returns: - float: The distance to the nearest neighbor. - int: The indices of the nearest neighbor. - """ - dist, ind = _tree.query(_e, k=_k) - return dist[0], ind[0] - -def synthetic_tree_decoder(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - beam_width, - rxn_template, - n_bits, - max_step=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining - molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - beam_width (int): The beam width to use for Reactant 1 search. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - max_step (int, optional): Maximum number of steps to include in the - synthetic tree - - Returns: - tree (SyntheticTree): The final synthetic tree. - act (int): The final action (to know if the tree was "properly" - terminated). 
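The scoring used throughout the decoder converts k-NN distances into candidate negative log-likelihoods via a softmax. A self-contained numeric sketch (the distances are made up):

```python
import numpy as np

def softmax(x):
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

dist = np.array([0.12, 0.35, 0.60])  # cosine distances of three beam candidates
probas = softmax(-0.1 * dist)        # smaller distance -> higher probability
nlls = -np.log(probas)               # lower NLL -> better-ranked candidate
print(nlls.argsort())                # [0 1 2]: ranking follows the distances
```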
- """ - # Initialization - tree = SyntheticTree() - kdtree = BallTree(bb_emb, metric=cosine_distance) - mol_recent = None - - # Start iteration - # try: - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - z_state = set_embedding(z_target, state, nbits=n_bits, mol_fp=mol_fp) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) - action_proba = action_proba.squeeze().detach().numpy() + 1e-10 - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - reactant1_net_input = torch.Tensor( - np.concatenate([z_state, one_hot_encoder(act, 4)], axis=1) - ) - z_mol1 = reactant1_net(reactant1_net_input) - z_mol1 = z_mol1.detach().numpy() - - # Select first molecule - if act == 3: - # End - nlls = [0.0] - break - elif act == 0: - # Add - # **don't try to sample more points than there are in the tree - # beam search for mol1 candidates - dist, ind = nn_search(z_mol1, _tree=kdtree, _k=min(len(bb_emb), beam_width)) - try: - mol1_probas = softmax(- 0.1 * dist) - mol1_nlls = -np.log(mol1_probas) - except: # exception for beam search of length 1 - mol1_nlls = [-np.log(0.5)] - mol1_list = [building_blocks[idx] for idx in ind] - nlls = mol1_nlls - else: - # Expand or Merge - mol1_list = [mol_recent] - nlls = [-np.log(0.5)] - - rxn_list = [] - rxn_id_list = [] - mol2_list = [] - act_list = [act] * beam_width - for mol1_idx, mol1 in enumerate(mol1_list): - - z_mol1 = mol_fp(mol1) - act = act_list[mol1_idx] - - # Select reaction - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() - - if act != 2: - reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] - - if reaction_mask is None: - if len(state) == 1: - act = 3 - nlls[mol1_idx] += -np.log(action_proba * reaction_mask)[act] # correct the NLL - act_list[mol1_idx] = act - rxn_list.append(None) - rxn_id_list.append(None) - mol2_list.append(None) - continue - else: - act_list[mol1_idx] = act - rxn_list.append(None) - rxn_id_list.append(None) - mol2_list.append(None) - continue - - rxn_id = np.argmax(reaction_proba * reaction_mask) - rxn = reaction_templates[rxn_id] - rxn_nll = -np.log(reaction_proba * reaction_mask)[rxn_id] - - rxn_list.append(rxn) - rxn_id_list.append(rxn_id) - nlls[mol1_idx] += rxn_nll - - if np.isinf(rxn_nll): - mol2_list.append(None) - continue - elif rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: - # Add or Expand - if rxn_template == 'hb': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 91)], axis=1))) - elif rxn_template == 'pis': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 4700)], axis=1))) - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] - temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) - dist, ind = nn_search(z_mol2, _tree=available_tree, _k=min(len(temp_emb), beam_width)) - try: - mol2_probas = softmax(-dist) - mol2_nll = -np.log(mol2_probas)[0] - 
except: - mol2_nll = 0.0 - mol2 = building_blocks[available[ind[0]]] - nlls[mol1_idx] += mol2_nll - else: - mol2 = None - - mol2_list.append(mol2) - - # Run reaction until get a valid (non-None) product - for i in range(0, len(nlls)): - best_idx = np.argsort(nlls)[i] - rxn = rxn_list[best_idx] - rxn_id = rxn_id_list[best_idx] - mol2 = mol2_list[best_idx] - act = act_list[best_idx] - try: - mol_product = rxn.run_reaction([mol1, mol2]) - except: - mol_product = None - else: - if mol_product is None: - continue - else: - break - - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(tree.get_state()) == 1: - act = 3 - break - else: - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act - -def set_embedding_fullbeam(z_target, state, _mol_embedding, nbits): - """ - Computes embeddings for all molecules in input state. - - Args: - z_target (np.ndarray): Embedding for the target molecule. - state (list): Contains molecules in the current state, if not the - initial state. - _mol_embedding (Callable): Function to use for computing the embeddings - of the first and second molecules in the state (e.g. Morgan fingerprint). - nbits (int): Number of bits to use for the embedding. - - Returns: - np.ndarray: Embedding consisting of the concatenation of the target - molecule with the current molecules (if available) in the input - state. - """ - if len(state) == 0: - z_target = np.expand_dims(z_target, axis=0) - return np.concatenate([np.zeros((1, 2 * nbits)), z_target], axis=1) - else: - e1 = _mol_embedding(state[0]) - e1 = np.expand_dims(e1, axis=0) - if len(state) == 1: - e2 = np.zeros((1, nbits)) - else: - e2 = _mol_embedding(state[1]) - e2 = np.expand_dims(e2, axis=0) - z_target = np.expand_dims(z_target, axis=0) - return np.concatenate([e1, e2, z_target], axis=1) - -def synthetic_tree_decoder_fullbeam(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - beam_width, - rxn_template, - n_bits, - max_step=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a beam search. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - beam_width (int): The beam width to use for Reactant 1 search. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. 
- max_step (int, optional): Maximum number of steps to include in the synthetic tree - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - # Initialization - tree = SyntheticTree() - mol_recent = None - kdtree = KDTree(bb_emb, metric='euclidean') - - # Start iteration - # try: - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - z_state = set_embedding_fullbeam(z_target, state, mol_fp, nbits=n_bits) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) - action_proba = action_proba.squeeze().detach().numpy() - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - z_mol1 = reactant1_net(torch.Tensor(np.concatenate([z_state, one_hot_encoder(act, 4)], axis=1))) - z_mol1 = z_mol1.detach().numpy() - - # Select first molecule - if act == 3: - # End - mol1_nlls = [0.0] - break - elif act == 0: - # Add - # **don't try to sample more points than there are in the tree - # beam search for mol1 candidates - dist, ind = nn_search(z_mol1, _tree=kdtree, _k=min(len(bb_emb), beam_width)) - try: - mol1_probas = softmax(- 0.1 * dist) - mol1_nlls = -np.log(mol1_probas) - except: # exception for beam search of length 1 - mol1_nlls = [-np.log(0.5)] - mol1_list = [building_blocks[idx] for idx in ind] - else: - # Expand or Merge - mol1_list = [mol_recent] - mol1_nlls = [-np.log(0.5)] - - action_tuples = [] # list of action tuples created by beam search - act_list = [act] * beam_width - for mol1_idx, mol1 in enumerate(mol1_list): - - z_mol1 = mol_fp(mol1, nBits=n_bits) - act = act_list[mol1_idx] - - # Select reaction - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() - - if act != 2: - reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] - - if reaction_mask is None: - if len(state) == 1: - act = 3 - mol1_nlls[mol1_idx] += -np.log(action_proba * reaction_mask)[act] # correct the NLL - act_list[mol1_idx] = act - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([mol1_nlls[mol1_idx], act, mol1, None, None, None]) - continue - else: - act_list[mol1_idx] = act - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([mol1_nlls[mol1_idx], act, mol1, None, None, None]) - continue - - rxn_ids = np.argsort(-reaction_proba * reaction_mask)[:beam_width] - rxn_nlls = mol1_nlls[mol1_idx] - np.log(reaction_proba * reaction_mask) - - for rxn_id in rxn_ids: - rxn = reaction_templates[rxn_id] - rxn_nll = rxn_nlls[rxn_id] - - if np.isinf(rxn_nll): - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([rxn_nll, act, mol1, rxn, rxn_id, None]) - continue - elif rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([rxn_nll, act, mol1, rxn, rxn_id, mol2]) - else: - # Add or Expand - if rxn_template == 'hb': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 91)], axis=1))) - elif rxn_template == 'pis': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, 
one_hot_encoder(rxn_id, 4700)], axis=1))) - - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] - temp_emb = bb_emb[available] - available_tree = KDTree(temp_emb, metric='euclidean') - dist, ind = nn_search(z_mol2, _tree=available_tree, _k=min(len(temp_emb), beam_width)) - try: - mol2_probas = softmax(-dist) - mol2_nlls = rxn_nll - np.log(mol2_probas) - except: - mol2_nlls = [rxn_nll + 0.0] - mol2_list = [building_blocks[available[idc]] for idc in ind] - for mol2_idx, mol2 in enumerate(mol2_list): - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([mol2_nlls[mol2_idx], act, mol1, rxn, rxn_id, mol2]) - else: - # nll, act, mol1, rxn, rxn_id, mol2 - action_tuples.append([rxn_nll, act, mol1, rxn, rxn_id, None]) - - # Run reaction until get a valid (non-None) product - for i in range(0, len(action_tuples)): - nlls = list(zip(*action_tuples))[0] - best_idx = np.argsort(nlls)[i] - act = action_tuples[best_idx][1] - mol1 = action_tuples[best_idx][2] - rxn = action_tuples[best_idx][3] - rxn_id = action_tuples[best_idx][4] - mol2 = action_tuples[best_idx][5] - try: - mol_product = rxn.run_reaction([mol1, mol2]) - except: - mol_product = None - else: - if mol_product is None: - continue - else: - break - - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(tree.get_state()) == 1: - act = 3 - break - else: - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act diff --git a/syn_net/utils/predict_utils.py b/syn_net/utils/predict_utils.py deleted file mode 100644 index 57133aa0..00000000 --- a/syn_net/utils/predict_utils.py +++ /dev/null @@ -1,1067 +0,0 @@ -""" -This file contains various utils for creating molecular embeddings and for -decoding synthetic trees. -""" -import numpy as np -import rdkit -from tqdm import tqdm -import torch -from rdkit import Chem -from rdkit import DataStructs -from rdkit.Chem import AllChem -from sklearn.neighbors import BallTree -from dgl.nn.pytorch.glob import AvgPooling -from dgl.nn.pytorch.glob import AvgPooling -from dgllife.model import load_pretrained -from dgllife.utils import mol_to_bigraph, PretrainAtomFeaturizer, PretrainBondFeaturizer -from tdc.chem_utils import MolConvert -from syn_net.models.mlp import MLP -from syn_net.utils.data_utils import SyntheticTree - - -# create a random seed for NumPy -np.random.seed(6) - -# get a GIN pretrained model to use for creating molecular embeddings -model_type = 'gin_supervised_contextpred' -device = 'cpu' -gin_pretrained_model = load_pretrained(model_type).to(device) # used to learn embedding -gin_pretrained_model.eval() - - -# general functions -def can_react(state, rxns): - """ - Determines if two molecules can react using any of the input reactions. - - Args: - state (np.ndarray): The current state in the synthetic tree. - rxns (list of Reaction objects): Contains available reaction templates. - - Returns: - np.ndarray: The sum of the reaction mask tells us how many reactions are - viable for the two molecules. - np.ndarray: The reaction mask, which masks out reactions which are not - viable for the two molecules. 
- """ - mol1 = state.pop() - mol2 = state.pop() - reaction_mask = [int(rxn.run_reaction([mol1, mol2]) is not None) for rxn in rxns] - return sum(reaction_mask), reaction_mask - -def get_action_mask(state, rxns): - """ - Determines which actions can apply to a given state in the synthetic tree - and returns a mask for which actions can apply. - - Args: - state (np.ndarray): The current state in the synthetic tree. - rxns (list of Reaction objects): Contains available reaction templates. - - Raises: - ValueError: There is an issue with the input state. - - Returns: - np.ndarray: The action mask. Masks out unviable actions from the current - state using 0s, with 1s at the positions corresponding to viable - actions. - """ - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - if len(state) == 0: - return np.array([1, 0, 0, 0]) - elif len(state) == 1: - return np.array([1, 1, 0, 1]) - elif len(state) == 2: - can_react_, _ = can_react(state, rxns) - if can_react_: - return np.array([0, 1, 1, 0]) - else: - return np.array([0, 1, 0, 0]) - else: - raise ValueError('Problem with state.') - -def get_reaction_mask(smi, rxns): - """ - Determines which reaction templates can apply to the input molecule. - - Args: - smi (str): The SMILES string corresponding to the molecule in question. - rxns (list of Reaction objects): Contains available reaction templates. - - Raises: - ValueError: There is an issue with the reactants in the reaction. - - Returns: - reaction_mask (list of ints, or None): The reaction template mask. Masks - out reaction templates which are not viable for the input molecule. - If there are no viable reaction templates identified, is simply None. - available_list (list of lists, or None): Contains available reactants if - at least one viable reaction template is identified. Else is simply - None. - """ - # Return all available reaction templates - # List of available building blocks if 2 - # Exclude the case of len(available_list) == 0 - reaction_mask = [int(rxn.is_reactant(smi)) for rxn in rxns] - - if sum(reaction_mask) == 0: - return None, None - available_list = [] - mol = rdkit.Chem.MolFromSmiles(smi) - for i, rxn in enumerate(rxns): - if reaction_mask[i] and rxn.num_reactant == 2: - - if rxn.is_reactant_first(mol): - available_list.append(rxn.available_reactants[1]) - elif rxn.is_reactant_second(mol): - available_list.append(rxn.available_reactants[0]) - else: - raise ValueError('Check the reactants') - - if len(available_list[-1]) == 0: - reaction_mask[i] = 0 - - else: - available_list.append([]) - - return reaction_mask, available_list - -def graph_construction_and_featurization(smiles): - """ - Constructs graphs from SMILES and featurizes them. - - Args: - smiles (list of str): Contains SMILES of molecules to embed. - - Returns: - graphs (list of DGLGraph): List of graphs constructed and featurized. - success (list of bool): Indicators for whether the SMILES string can be - parsed by RDKit. - """ - graphs = [] - success = [] - for smi in tqdm(smiles): - try: - mol = Chem.MolFromSmiles(smi) - if mol is None: - success.append(False) - continue - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - graphs.append(g) - success.append(True) - except: - success.append(False) - - return graphs, success - -def one_hot_encoder(dim, space): - """ - Create a one-hot encoded vector of length=`space`, with a non-zero element - at the index given by `dim`. 
- - Args: - dim (int): Non-zero bit in one-hot vector. - space (int): Length of one-hot encoded vector. - - Returns: - vec (np.ndarray): One-hot encoded vector. - """ - vec = np.zeros((1, space)) - vec[0, dim] = 1 - return vec - -def mol_embedding(smi, device='cpu', readout=AvgPooling()): - """ - Constructs a graph embedding using the GIN network for an input SMILES. - - Args: - smi (str): A SMILES string. - device (str): Indicates the device to run on ('cpu' or 'cuda:0'). Default 'cpu'. - - Returns: - list or np.ndarray: A zero vector if the input is None, otherwise the - graph embedding as a flat list. - """ - - # get the embedding - if smi is None: - return np.zeros(300) - else: - mol = Chem.MolFromSmiles(smi) - # convert RDKit.Mol into featurized bi-directed DGLGraph - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - bg = g.to(device) - nfeats = [bg.ndata.pop('atomic_number').to(device), - bg.ndata.pop('chirality_type').to(device)] - efeats = [bg.edata.pop('bond_type').to(device), - bg.edata.pop('bond_direction_type').to(device)] - with torch.no_grad(): - node_repr = gin_pretrained_model(bg, nfeats, efeats) - return readout(bg, node_repr).detach().cpu().numpy().reshape(-1, ).tolist() - - -def get_mol_embedding(smi, model, device='cpu', readout=AvgPooling()): - """ - Computes the molecular graph embedding for the input SMILES. - - Args: - smi (str): SMILES of molecule to embed. - model (dgllife.model): Pre-trained NN model to use for - computing the embedding. - device (str, optional): Indicates the device to run on. Defaults to 'cpu'. - readout (dgl.nn.pytorch.glob, optional): Readout function to use for - computing the graph embedding. Defaults to AvgPooling(). - - Returns: - np.ndarray: Learned embedding for the input molecule. - """ - mol = Chem.MolFromSmiles(smi) - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - bg = g.to(device) - nfeats = [bg.ndata.pop('atomic_number').to(device), - bg.ndata.pop('chirality_type').to(device)] - efeats = [bg.edata.pop('bond_type').to(device), - bg.edata.pop('bond_direction_type').to(device)] - with torch.no_grad(): - node_repr = model(bg, nfeats, efeats) - return readout(bg, node_repr).detach().cpu().numpy()[0] - -def mol_fp(smi, _radius=2, _nBits=4096): - """ - Computes the Morgan fingerprint for the input SMILES. - - Args: - smi (str): SMILES for molecule to compute fingerprint for. - _radius (int, optional): Fingerprint radius to use. Defaults to 2. - _nBits (int, optional): Length of fingerprint. Defaults to 4096. - - Returns: - features (np.ndarray): For valid SMILES, this is the fingerprint. - Otherwise, if the input SMILES is bad, this will be a zero vector. - """ - if smi is None: - return np.zeros(_nBits) - else: - mol = Chem.MolFromSmiles(smi) - features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) - return np.array(features_vec) - -def cosine_distance(v1, v2, eps=1e-15): - """ - Computes the cosine distance between two vectors. - - Args: - v1 (np.ndarray): First vector. - v2 (np.ndarray): Second vector. - eps (float, optional): Small value, for numerical stability. Defaults - to 1e-15. - - Returns: - float: The cosine distance, i.e. one minus the cosine similarity. - """ - return (1 - np.dot(v1, v2) - / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2) + eps)) - -def ce_distance(y, y_pred, eps=1e-15): - """ - Computes the cross-entropy between two vectors.
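A quick sketch of `mol_fp` plus `cosine_distance` in combination; one would expect a structurally close pair to score a smaller distance than a dissimilar pair:

```python
# Assumes mol_fp and cosine_distance from above are in scope.
fp_ethanol = mol_fp("CCO")        # 4096-bit Morgan fingerprint
fp_propanol = mol_fp("CCCO")
fp_benzene = mol_fp("c1ccccc1")

print(cosine_distance(fp_ethanol, fp_propanol))  # small (similar molecules)
print(cosine_distance(fp_ethanol, fp_benzene))   # close to 1 (few shared bits)
```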
- - Args: - y (np.ndarray): First vector. - y_pred (np.ndarray): Second vector. - eps (float, optional): Small value, for numerical stability. Defaults - to 1e-15. - - Returns: - float: The cross-entropy. - """ - y_pred = np.clip(y_pred, eps, 1 - eps) - return - np.sum((y * np.log(y_pred) + (1 - y) * np.log(1 - y_pred))) - - -def nn_search(_e, _tree, _k=1): - """ - Conducts a nearest neighbor search to find the molecule from the tree most - similar to the input embedding. - - Args: - _e (np.ndarray): A specific point in the dataset. - _tree (sklearn.neighbors._kd_tree.KDTree, optional): A k-d tree. - _k (int, optional): Indicates how many nearest neighbors to get. - Defaults to 1. - - Returns: - float: The distance to the nearest neighbor. - int: The index of the nearest neighbor. - """ - dist, ind = _tree.query(_e, k=_k) - return dist[0][0], ind[0][0] - -def graph_construction_and_featurization(smiles): - """ - Constructs graphs from SMILES and featurizes them. - - Args: - smiles (list of str): SMILES of molecules, for embedding computation. - - Returns: - graphs (list of DGLGraph): List of graphs constructed and featurized. - success (list of bool): Indicators for whether the SMILES string can be - parsed by RDKit. - """ - graphs = [] - success = [] - for smi in tqdm(smiles): - try: - mol = Chem.MolFromSmiles(smi) - if mol is None: - success.append(False) - continue - g = mol_to_bigraph(mol, add_self_loop=True, - node_featurizer=PretrainAtomFeaturizer(), - edge_featurizer=PretrainBondFeaturizer(), - canonical_atom_order=False) - graphs.append(g) - success.append(True) - except: - success.append(False) - - return graphs, success - -def set_embedding(z_target, state, nbits, _mol_embedding=get_mol_embedding): - """ - Computes embeddings for all molecules in the input state. - - Args: - z_target (np.ndarray): Embedding for the target molecule. - state (list): Contains molecules in the current state, if not the - initial state. - nbits (int): Length of fingerprint. - _mol_embedding (Callable, optional): Function to use for computing the - embeddings of the first and second molecules in the state. Defaults - to `get_mol_embedding`. - - Returns: - np.ndarray: Embedding consisting of the concatenation of the target - molecule with the current molecules (if available) in the input state.
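`set_embedding` always yields a fixed-width state of three fingerprint slots, `[mol1 | mol2 | target]`, zero-padding the slots that are empty. A shape sketch using `mol_fp` as the embedder:

```python
import numpy as np

# Assumes set_embedding and mol_fp from above are in scope.
nbits = 4096
z_target = np.expand_dims(mol_fp("CCO", _nBits=nbits), axis=0)  # (1, 4096)

z_empty = set_embedding(z_target, [], nbits, _mol_embedding=mol_fp)
z_one = set_embedding(z_target, ["CCO"], nbits, _mol_embedding=mol_fp)
print(z_empty.shape, z_one.shape)  # (1, 12288) (1, 12288) -- always 3 * nbits
```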
- """ - if len(state) == 0: - return np.concatenate([np.zeros((1, 2 * nbits)), z_target], axis=1) - else: - e1 = np.expand_dims(_mol_embedding(state[0]), axis=0) - if len(state) == 1: - e2 = np.zeros((1, nbits)) - else: - e2 = _mol_embedding(state[1]) - return np.concatenate([e1, e2, z_target], axis=1) - -def synthetic_tree_decoder(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, - max_step=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining - molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - max_step (int, optional): Maximum number of steps to include in the - synthetic tree - - Returns: - tree (SyntheticTree): The final synthetic tree. - act (int): The final action (to know if the tree was "properly" - terminated). - """ - # Initialization - tree = SyntheticTree() - kdtree = BallTree(bb_emb, metric=cosine_distance) - mol_recent = None - - # Start iteration - # try: - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) - action_proba = action_proba.squeeze().detach().numpy() + 1e-10 - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() - - # Select first molecule - if act == 3: - # End - break - elif act == 0: - # Add - dist, ind = nn_search(z_mol1, _tree=kdtree) - mol1 = building_blocks[ind] - else: - # Expand or Merge - mol1 = mol_recent - - z_mol1 = mol_fp(mol1) - - # Select reaction - try: - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - except: - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 - - if act != 2: - reaction_mask, available_list = get_reaction_mask(smi=mol1, - rxns=reaction_templates) - else: - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] - - if reaction_mask is None: - if len(state) == 1: - act = 3 - break - else: - break - - rxn_id = np.argmax(reaction_proba * reaction_mask) - rxn = reaction_templates[rxn_id] - - if rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: - # Add or Expand - if rxn_template 
== 'hb': - num_rxns = 91 - elif rxn_template == 'pis': - num_rxns = 4700 - else: - num_rxns = 3 # unit testing uses only 3 reaction templates - reactant2_net_input = torch.Tensor( - np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, num_rxns)], - axis=1) - ) - z_mol2 = reactant2_net(reactant2_net_input) - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] - temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) - dist, ind = nn_search(z_mol2, _tree=available_tree) - mol2 = building_blocks[available[ind]] - else: - mol2 = None - - # Run reaction - mol_product = rxn.run_reaction([mol1, mol2]) - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - if len(tree.get_state()) == 1: - act = 3 - break - else: - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act - -def load_modules_from_checkpoint(path_to_act, path_to_rt1, path_to_rxn, path_to_rt2, featurize, rxn_template, out_dim, nbits, ncpu): - - if rxn_template == 'unittest': - - act_net = MLP.load_from_checkpoint(path_to_act, - input_dim=int(3 * nbits), - output_dim=4, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(3 * nbits), - output_dim=out_dim, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(4 * nbits), - output_dim=3, - hidden_dim=100, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(4 * nbits + 3), - output_dim=out_dim, - hidden_dim=100, - num_layers=3, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - elif featurize == 'fp': - - act_net = MLP.load_from_checkpoint(path_to_act, - input_dim=int(3 * nbits), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(3 * nbits), - output_dim=out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - if rxn_template == 'hb': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(4 * nbits), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(4 * nbits + 91), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - 
loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - elif rxn_template == 'pis': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(4 * nbits), - output_dim=4700, - hidden_dim=4500, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(4 * nbits + 4700), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - elif featurize == 'gin': - - act_net = MLP.load_from_checkpoint(path_to_act, - input_dim=int(2 * nbits + out_dim), - output_dim=4, - hidden_dim=1000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt1_net = MLP.load_from_checkpoint(path_to_rt1, - input_dim=int(2 * nbits + out_dim), - output_dim=out_dim, - hidden_dim=1200, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - if rxn_template == 'hb': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(3 * nbits + out_dim), - output_dim=91, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(3 * nbits + out_dim + 91), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - elif rxn_template == 'pis': - - rxn_net = MLP.load_from_checkpoint(path_to_rxn, - input_dim=int(3 * nbits + out_dim), - output_dim=4700, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='classification', - loss='cross_entropy', - valid_loss='accuracy', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - rt2_net = MLP.load_from_checkpoint(path_to_rt2, - input_dim=int(3 * nbits + out_dim + 4700), - output_dim=out_dim, - hidden_dim=3000, - num_layers=5, - dropout=0.5, - num_dropout_layers=1, - task='regression', - loss='mse', - valid_loss='mse', - optimizer='adam', - learning_rate=1e-4, - ncpu=ncpu) - - act_net.eval() - rt1_net.eval() - rxn_net.eval() - rt2_net.eval() - - return act_net, rt1_net, rxn_net, rt2_net - -def _tanimoto_similarity(fp1, fp2): - """ - Returns the Tanimoto similarity between two molecular fingerprints. - - Args: - fp1 (np.ndarray): Molecular fingerprint 1. - fp2 (np.ndarray): Molecular fingerprint 2. - - Returns: - np.float: Tanimoto similarity. - """ - return np.sum(fp1 * fp2) / (np.sum(fp1) + np.sum(fp2) - np.sum(fp1 * fp2)) - -def tanimoto_similarity(target_fp, smis): - """ - Returns the Tanimoto similarities between a target fingerprint and molecules - in an input list of SMILES. - - Args: - target_fp (np.ndarray): Contains the reference (target) fingerprint. - smis (list of str): Contains SMILES to compute similarity to. - - Returns: - list of np.ndarray: Contains Tanimoto similarities. 
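A worked numeric check of the bit-vector Tanimoto formula defined below:

```python
import numpy as np

fp1 = np.array([1, 1, 0, 1, 0])
fp2 = np.array([1, 0, 0, 1, 1])
# intersection = 2; union = 3 + 3 - 2 = 4
sim = np.sum(fp1 * fp2) / (np.sum(fp1) + np.sum(fp2) - np.sum(fp1 * fp2))
print(sim)  # 0.5
```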
- """ - fps = [mol_fp(smi, 2, 4096) for smi in smis] - return [_tanimoto_similarity(target_fp, fp) for fp in fps] - - -# functions used in the *_multireactant.py -def nn_search_rt1(_e, _tree, _k=1): - dist, ind = _tree.query(_e, k=_k) - return dist[0], ind[0] - -def synthetic_tree_decoder_rt1(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, - max_step=15, - rt1_index=0): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining - molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - beam_width (int): The beam width to use for Reactant 1 search. Defaults - to 3. - max_step (int, optional): Maximum number of steps to include in the - synthetic tree - rt1_index (int, optional): Index for molecule in the building blocks - corresponding to reactant 1. - - Returns: - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" - terminated). 
- """ - # Initialization - tree = SyntheticTree() - mol_recent = None - kdtree = BallTree(bb_emb, metric=cosine_distance) - - # Start iteration - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - try: - z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) - except: - z_target = np.expand_dims(z_target, axis=0) - z_state = set_embedding(z_target, state, nbits=n_bits, _mol_embedding=mol_fp) - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = action_net(torch.Tensor(z_state)) - action_proba = action_proba.squeeze().detach().numpy() + 1e-10 - action_mask = get_action_mask(tree.get_state(), reaction_templates) - act = np.argmax(action_proba * action_mask) - - z_mol1 = reactant1_net(torch.Tensor(z_state)) - z_mol1 = z_mol1.detach().numpy() - - # Select first molecule - if act == 3: - # End - break - elif act == 0: - # Add - if mol_recent is not None: - dist, ind = nn_search(z_mol1, _tree=kdtree) - mol1 = building_blocks[ind] - else: - dist, ind = nn_search_rt1(z_mol1, _tree=kdtree, _k=rt1_index+1) - mol1 = building_blocks[ind[rt1_index]] - else: - # Expand or Merge - mol1 = mol_recent - - # z_mol1 = get_mol_embedding(mol1, mol_embedder) - z_mol1 = mol_fp(mol1) - - # Select reaction - try: - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - except: - z_mol1 = np.expand_dims(z_mol1, axis=0) - reaction_proba = rxn_net(torch.Tensor(np.concatenate([z_state, z_mol1], axis=1))) - reaction_proba = reaction_proba.squeeze().detach().numpy() + 1e-10 - - if act != 2: - reaction_mask, available_list = get_reaction_mask(mol1, reaction_templates) - else: - _, reaction_mask = can_react(tree.get_state(), reaction_templates) - available_list = [[] for rxn in reaction_templates] - - if reaction_mask is None: - if len(state) == 1: - act = 3 - break - else: - break - - rxn_id = np.argmax(reaction_proba * reaction_mask) - rxn = reaction_templates[rxn_id] - - if rxn.num_reactant == 2: - # Select second molecule - if act == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: - # Add or Expand - if rxn_template == 'hb': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 91)], axis=1))) - elif rxn_template == 'pis': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 4700)], axis=1))) - elif rxn_template == 'unittest': - z_mol2 = reactant2_net(torch.Tensor(np.concatenate([z_state, z_mol1, one_hot_encoder(rxn_id, 3)], axis=1))) - z_mol2 = z_mol2.detach().numpy() - available = available_list[rxn_id] - available = [bb_dict[available[i]] for i in range(len(available))] - temp_emb = bb_emb[available] - available_tree = BallTree(temp_emb, metric=cosine_distance) - dist, ind = nn_search(z_mol2, _tree=available_tree) - mol2 = building_blocks[available[ind]] - else: - mol2 = None - - # Run reaction - mol_product = rxn.run_reaction([mol1, mol2]) - if mol_product is None or Chem.MolFromSmiles(mol_product) is None: - act = 3 - break - - # Update - tree.update(act, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - if act != 3: - tree = tree - else: - tree.update(act, None, None, None, None) - - return tree, act - -def synthetic_tree_decoder_multireactant(z_target, - building_blocks, - bb_dict, - reaction_templates, - mol_embedder, - action_net, - reactant1_net, - rxn_net, - reactant2_net, - bb_emb, - rxn_template, - n_bits, - beam_width : int=3, - 
                                         max_step : int=15): - """ - Computes the synthetic tree given an input molecule embedding, using the - Action, Reaction, Reactant1, and Reactant2 networks and a greedy search that - is restarted from each of the `beam_width` nearest first-reactant candidates. - - Args: - z_target (np.ndarray): Embedding for the target molecule - building_blocks (list of str): Contains available building blocks - bb_dict (dict): Building block dictionary - reaction_templates (list of Reactions): Contains reaction templates - mol_embedder (dgllife.model.gnn.gin.GIN): GNN to use for obtaining molecular embeddings - action_net (synth_net.models.mlp.MLP): The action network - reactant1_net (synth_net.models.mlp.MLP): The reactant1 network - rxn_net (synth_net.models.mlp.MLP): The reaction network - reactant2_net (synth_net.models.mlp.MLP): The reactant2 network - bb_emb (list): Contains purchasable building block embeddings. - rxn_template (str): Specifies the set of reaction templates to use. - n_bits (int): Length of fingerprint. - beam_width (int): The beam width to use for Reactant 1 search. Defaults to 3. - max_step (int, optional): Maximum number of steps to include in the synthetic tree - - Returns: - smi (str): SMILES of the molecule in the best tree most similar to the target. - similarity (float): Tanimoto similarity between the target and `smi`. - tree (SyntheticTree): The final synthetic tree - act (int): The final action (to know if the tree was "properly" terminated) - """ - trees = [] - smiles = [] - similarities = [] - acts = [] - - for i in range(beam_width): - tree, act = synthetic_tree_decoder_rt1(z_target=z_target, - building_blocks=building_blocks, - bb_dict=bb_dict, - reaction_templates=reaction_templates, - mol_embedder=mol_embedder, - action_net=action_net, - reactant1_net=reactant1_net, - rxn_net=rxn_net, - reactant2_net=reactant2_net, - bb_emb=bb_emb, - rxn_template=rxn_template, - n_bits=n_bits, - max_step=max_step, - rt1_index=i) - - - similarities_ = np.array(tanimoto_similarity(z_target, [node.smiles for node in tree.chemicals])) - max_simi_idx = np.where(similarities_ == np.max(similarities_))[0][0] - - similarities.append(np.max(similarities_)) - smiles.append(tree.chemicals[max_simi_idx].smiles) - trees.append(tree) - acts.append(act) - - max_simi_idx = np.where(similarities == np.max(similarities))[0][0] - similarity = similarities[max_simi_idx] - tree = trees[max_simi_idx] - smi = smiles[max_simi_idx] - act = acts[max_simi_idx] - - return smi, similarity, tree, act - -def fp_embedding(smi, _radius=2, _nBits=4096): - """ - General function for building variable-size & -radius Morgan fingerprints. - - Args: - smi (str): The SMILES to encode. - _radius (int, optional): Morgan fingerprint radius. Defaults to 2. - _nBits (int, optional): Morgan fingerprint length. Defaults to 4096. - - Returns: - np.ndarray: A Morgan fingerprint generated using the specified parameters.
- """ - if smi is None: - return np.zeros(_nBits).reshape((-1, )).tolist() - else: - mol = Chem.MolFromSmiles(smi) - features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, _radius, _nBits) - features = np.zeros((1,)) - DataStructs.ConvertToNumpyArray(features_vec, features) - return features.reshape((-1, )).tolist() - -def fp_4096(smi): - return fp_embedding(smi, _radius=2, _nBits=4096) - -def fp_2048(smi): - return fp_embedding(smi, _radius=2, _nBits=2048) - -def fp_1024(smi): - return fp_embedding(smi, _radius=2, _nBits=1024) - -def fp_512(smi): - return fp_embedding(smi, _radius=2, _nBits=512) - -def fp_256(smi): - return fp_embedding(smi, _radius=2, _nBits=256) - -def rdkit2d_embedding(smi): - # define the RDKit 2D descriptors conversion function - rdkit2d = MolConvert(src = 'SMILES', dst = 'RDKit2D') - - if smi is None: - return np.zeros(200).reshape((-1, )).tolist() - else: - return rdkit2d(smi).tolist() \ No newline at end of file diff --git a/syn_net/utils/prep_utils.py b/syn_net/utils/prep_utils.py deleted file mode 100644 index 9839bfe3..00000000 --- a/syn_net/utils/prep_utils.py +++ /dev/null @@ -1,302 +0,0 @@ -""" -This file contains various utils for data preparation and preprocessing. -""" -import numpy as np -from scipy import sparse -from dgllife.model import load_pretrained -from tdc.chem_utils import MolConvert -from sklearn.preprocessing import OneHotEncoder -from syn_net.utils.data_utils import SyntheticTree -from syn_net.utils.predict_utils import (can_react, get_action_mask, - get_reaction_mask, mol_fp, - get_mol_embedding) - - -def rdkit2d_embedding(smi): - """ - Computes an embedding using RDKit 2D descriptors. - - Args: - smi (str): SMILES string. - - Returns: - np.ndarray: A molecular embedding corresponding to the input molecule. - """ - if smi is None: - return np.zeros(200).reshape((-1, )) - else: - # define the RDKit 2D descriptor - rdkit2d = MolConvert(src = 'SMILES', dst = 'RDKit2D') - return rdkit2d(smi).reshape(-1, ) - - -def organize(st, d_mol=300, target_embedding='fp', radius=2, nBits=4096, - output_embedding='gin'): - """ - Organizes the states and steps from the input synthetic tree into sparse - matrices. - - Args: - st (SyntheticTree): The input synthetic tree to organize. - d_mol (int, optional): The molecular embedding size. Defaults to 300. - target_embedding (str, optional): Indicates what embedding type to use - for the input target (Morgan fingerprint --> 'fp' or GIN --> 'gin'). - Defaults to 'fp'. - radius (int, optional): Morgan fingerprint radius to use. Defaults to 2. - nBits (int, optional): Number of bits to use in the Morgan fingerprints. - Defaults to 4096. - output_embedding (str, optional): Indicates what type of embedding to - use for the output node states. Defaults to 'gin'. - - Raises: - ValueError: Raised if target embedding not supported. - - Returns: - sparse.csc_matrix: Node states pulled from the tree. - sparse.csc_matrix: Actions pulled from the tree. 
- """ - # define model to use for molecular embedding - model_type = 'gin_supervised_contextpred' - device = 'cpu' - model = load_pretrained(model_type).to(device) - model.eval() - - states = [] - steps = [] - - if output_embedding == 'gin': - d_mol = 300 - elif output_embedding == 'fp_4096': - d_mol = 4096 - elif output_embedding == 'fp_256': - d_mol = 256 - elif output_embedding == 'rdkit2d': - d_mol = 200 - - if target_embedding == 'fp': - target = mol_fp(st.root.smiles, radius, nBits).tolist() - elif target_embedding == 'gin': - target = get_mol_embedding(st.root.smiles, model=model).tolist() - else: - raise ValueError('Target embedding only supports fp and gin.') - - most_recent_mol = None - other_root_mol = None - for i, action in enumerate(st.actions): - - most_recent_mol_embedding = mol_fp(most_recent_mol, radius, nBits).tolist() - other_root_mol_embedding = mol_fp(other_root_mol, radius, nBits).tolist() - state = most_recent_mol_embedding + other_root_mol_embedding + target - - if action == 3: - step = [3] + [0]*d_mol + [-1] + [0]*d_mol + [0]*nBits - - else: - r = st.reactions[i] - mol1 = r.child[0] - if len(r.child) == 2: - mol2 = r.child[1] - else: - mol2 = None - - if output_embedding == 'gin': - step = ([action] - + get_mol_embedding(mol1, model=model).tolist() - + [r.rxn_id] - + get_mol_embedding(mol2, model=model).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - elif output_embedding == 'fp_4096': - step = ([action] - + mol_fp(mol1, 2, 4096).tolist() - + [r.rxn_id] - + mol_fp(mol2, 2, 4096).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - elif output_embedding == 'fp_256': - step = ([action] - + mol_fp(mol1, 2, 256).tolist() - + [r.rxn_id] - + mol_fp(mol2, 2, 256).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - elif output_embedding == 'rdkit2d': - step = ([action] - + rdkit2d_embedding(mol1).tolist() - + [r.rxn_id] - + rdkit2d_embedding(mol2).tolist() - + mol_fp(mol1, radius, nBits).tolist()) - - if action == 2: - most_recent_mol = r.parent - other_root_mol = None - - elif action == 1: - most_recent_mol = r.parent - - elif action == 0: - other_root_mol = most_recent_mol - most_recent_mol = r.parent - - states.append(state) - steps.append(step) - - return sparse.csc_matrix(np.array(states)), sparse.csc_matrix(np.array(steps)) - -def synthetic_tree_generator(building_blocks, reaction_templates, max_step=15): - """ - Generates a synthetic tree from the available building blocks and reaction - templates. Used in preparing the training/validation/testing data. - - Args: - building_blocks (list): Contains SMILES strings for purchasable building - blocks. - reaction_templates (list): Contains `Reaction` objects. - max_step (int, optional): Indicates the maximum number of reaction steps - to use for building the synthetic tree data. Defaults to 15. - - Returns: - tree (SyntheticTree): The built up synthetic tree. - action (int): Index corresponding to a specific action. 
- """ - # Initialization - tree = SyntheticTree() - mol_recent = None - - # Start iteration - try: - for i in range(max_step): - # Encode current state - state = tree.get_state() # a set - - # Predict action type, masked selection - # Action: (Add: 0, Expand: 1, Merge: 2, End: 3) - action_proba = np.random.rand(4) - action_mask = get_action_mask(tree.get_state(), reaction_templates) - action = np.argmax(action_proba * action_mask) - - # Select first molecule - if action == 3: - # End - break - elif action == 0: - # Add - mol1 = np.random.choice(building_blocks) - else: - # Expand or Merge - mol1 = mol_recent - - # Select reaction - reaction_proba = np.random.rand(len(reaction_templates)) - - if action != 2: - rxn_mask, available = get_reaction_mask(smi=mol1, - rxns=reaction_templates) - else: - _, rxn_mask = can_react(tree.get_state(), reaction_templates) - available = [[] for rxn in reaction_templates] - - if rxn_mask is None: - if len(state) == 1: - action = 3 - break - else: - break - - rxn_id = np.argmax(reaction_proba * rxn_mask) - rxn = reaction_templates[rxn_id] - - if rxn.num_reactant == 2: - # Select second molecule - if action == 2: - # Merge - temp = set(state) - set([mol1]) - mol2 = temp.pop() - else: - # Add or Expand - mol2 = np.random.choice(available[rxn_id]) - else: - mol2 = None - - # Run reaction - mol_product = rxn.run_reaction([mol1, mol2]) - - # Update - tree.update(action, int(rxn_id), mol1, mol2, mol_product) - mol_recent = mol_product - - except Exception as e: - print(e) - action = -1 - tree = None - - if action != 3: - tree = None - else: - tree.update(action, None, None, None, None) - - return tree, action - -def prep_data(main_dir, num_rxn, out_dim): - """ - Loads the states and steps from preprocessed *.npz files and saves data - specific to the Action, Reactant 1, Reaction, and Reactant 2 networks in - their own *.npz files. - - Args: - main_dir (str): The path to the directory containing the *.npz files. - num_rxn (int): Number of reactions in the dataset. - out_dim (int): Size of the output feature vectors. 
- """ - - for dataset in ['train', 'valid', 'test']: - - print('Reading ' + dataset + ' data ......') - states_list = [] - steps_list = [] - for i in range(1): - states_list.append(sparse.load_npz(f'{main_dir}states_{i}_{dataset}.npz')) - steps_list.append(sparse.load_npz(f'{main_dir}steps_{i}_{dataset}.npz')) - - states = sparse.csc_matrix(sparse.vstack(states_list)) - steps = sparse.csc_matrix(sparse.vstack(steps_list)) - - # extract Action data - X = states - y = steps[:, 0] - sparse.save_npz(f'{main_dir}X_act_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_act_{dataset}.npz', y) - - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 3).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 3).reshape(-1, )]) - - # extract Reaction data - X = sparse.hstack([states, steps[:, (2 * out_dim + 2):]]) - y = steps[:, out_dim + 1] - sparse.save_npz(f'{main_dir}X_rxn_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_rxn_{dataset}.npz', y) - - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 2).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 2).reshape(-1, )]) - - enc = OneHotEncoder(handle_unknown='ignore') - enc.fit([[i] for i in range(num_rxn)]) - # import ipdb; ipdb.set_trace(context=9) - - # extract Reactant 2 data - X = sparse.hstack( - [states, - steps[:, (2 * out_dim + 2):], - sparse.csc_matrix(enc.transform(steps[:, out_dim+1].A.reshape((-1, 1))).toarray())] - ) - y = steps[:, (out_dim+2): (2 * out_dim + 2)] - sparse.save_npz(f'{main_dir}X_rt2_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_rt2_{dataset}.npz', y) - - states = sparse.csc_matrix(states.A[(steps[:, 0].A != 1).reshape(-1, )]) - steps = sparse.csc_matrix(steps.A[(steps[:, 0].A != 1).reshape(-1, )]) - - # extract Reactant 1 data - X = states - y = steps[:, 1: (out_dim+1)] - sparse.save_npz(f'{main_dir}X_rt1_{dataset}.npz', X) - sparse.save_npz(f'{main_dir}y_rt1_{dataset}.npz', y) - - return None diff --git a/tests/README.md b/tests/README.md index 1fcb317f..28fe8f2f 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,19 +1,12 @@ # Unit tests -## Instructions -To run the unit tests, start from the main SynNet directory and run: +Sadly, the only working unittests are for the genetic algorithm for molecular optimization. 
-```
-export PYTHONPATH=`pwd`:$PYTHONPATH
-```
-
-Then, activate the SynNet conda environment, and from the current unit tests directory, run:
-
-```
-python -m unittest
-```
+
+> :warning: **TODO**: write/fix unit tests and remove this TODO (old tests are prefixed with `_test*`)
 
 ## Dataset
+
 The data used for unit testing consists of:
-
-* 3 randomly sampled reaction templates from the Hartenfeller-Button dataset (*rxn_set_hb_test.txt*)
-* 100 randomly sampled matching building blocks from Enamine (*building_blocks_matched.csv.gz*)
+
+- 3 randomly sampled reaction templates from the Hartenfeller-Button dataset (*rxn_set_hb_test.txt*)
+- 100 randomly sampled matching building blocks from Enamine (*building_blocks_matched.csv.gz*)
diff --git a/syn_net/utils/__init__.py b/tests/__init__.py
similarity index 100%
rename from syn_net/utils/__init__.py
rename to tests/__init__.py
diff --git a/tests/filter_unmatch_tests.py b/tests/_filter_unmatch_tests.py
similarity index 90%
rename from tests/filter_unmatch_tests.py
rename to tests/_filter_unmatch_tests.py
index d4c6fe67..ca968cc4 100644
--- a/tests/filter_unmatch_tests.py
+++ b/tests/_filter_unmatch_tests.py
@@ -4,14 +4,13 @@
 """
 import pandas as pd
 from tqdm import tqdm
-from syn_net.utils.data_utils import *
+from synnet.utils.data_utils import *
 
 if __name__ == '__main__':
     r_path = './data/ref/rxns_hb.json.gz'
     bb_path = '/home/whgao/scGen/synth_net/data/enamine_us.csv.gz'
-    r_set = ReactionSet()
-    r_set.load(r_path)
+    r_set = ReactionSet().load(r_path)
     matched_mols = set()
     for r in tqdm(r_set.rxns):
         for a_list in r.available_reactants:
diff --git a/tests/test_DataPreparation.py b/tests/_test_DataPreparation.py
similarity index 70%
rename from tests/test_DataPreparation.py
rename to tests/_test_DataPreparation.py
index be9af8d9..0697064e 100644
--- a/tests/test_DataPreparation.py
+++ b/tests/_test_DataPreparation.py
@@ -12,9 +12,9 @@
 from scipy import sparse
 from tqdm import tqdm
 
-from syn_net.utils.predict_utils import get_mol_embedding
-from syn_net.utils.prep_utils import organize, synthetic_tree_generator, prep_data
-from syn_net.utils.data_utils import SyntheticTreeSet, Reaction, ReactionSet
+from synnet.encoding.gins import get_mol_embedding
+from synnet.utils.prep_utils import organize, synthetic_tree_generator, prep_data
+from synnet.utils.data_utils import SyntheticTreeSet, Reaction, ReactionSet
 
 TEST_DIR = Path(__file__).parent
 
@@ -31,10 +31,10 @@ def test_process_rxn_templates(self):
         """
         # the following file contains the three templates at the top of
         # 'SynNet/data/rxn_set_hb.txt'
-        path_to_rxn_templates = f"{TEST_DIR}/data/rxn_set_hb_test.txt"
+        path_to_rxn_templates = f"{TEST_DIR}/assets/rxn_set_hb_test.txt"
 
         # load the reference building blocks (100 here)
-        path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz"
+        path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz"
         building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")[
             "SMILES"
         ].tolist()
@@ -53,8 +53,7 @@ def test_process_rxn_templates(self):
 
         # load the reference reaction templates
         path_to_ref_rxn_templates = f"{TEST_DIR}/data/ref/rxns_hb.json.gz"
-        r_ref = ReactionSet()
-        r_ref.load(path_to_ref_rxn_templates)
+        r_ref = ReactionSet().load(path_to_ref_rxn_templates)
 
         # check here that the templates were correctly saved as a ReactionSet by
         # comparing to a provided reference file in 'SynNet/tests/data/ref/'
@@ -76,7 +75,7 @@ def test_synthetic_tree_prep(self):
         rxns = r_ref.rxns
 
         # load the reference building blocks (100 here)
-        path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz"
+        path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz"
         building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")[
             "SMILES"
         ].tolist()
@@ -107,8 +106,7 @@ def test_synthetic_tree_prep(self):
 
         # check here that the synthetic trees were correctly saved by
         # comparing to a provided reference file in 'SynNet/tests/data/ref/'
-        sts_ref = SyntheticTreeSet()
-        sts_ref.load(f"{TEST_DIR}/data/ref/st_data.json.gz")
+        sts_ref = SyntheticTreeSet().load(f"{TEST_DIR}/data/ref/st_data.json.gz")
         for st_idx, st in enumerate(sts_ref.sts):
             st = st.__dict__
             ref_st = sts_ref.sts[st_idx].__dict__
@@ -128,15 +126,13 @@ def test_featurization(self):
         save_dir = f"{TEST_DIR}/data/"
         reference_data_dir = f"{TEST_DIR}/data/ref/"
 
-        st_set = SyntheticTreeSet()
-        st_set.load(path_st)
+        st_set = SyntheticTreeSet().load(path_st)
         data = st_set.sts
         del st_set
 
         states = []
         steps = []
 
-        save_idx = 0
         for st in tqdm(data):
             try:
                 state, step = organize(
@@ -158,15 +154,15 @@ def test_featurization(self):
         if not os.path.exists(save_dir):
             os.makedirs(save_dir)
 
-        sparse.save_npz(f"{save_dir}states_{save_idx}_{dataset_type}.npz", states)
-        sparse.save_npz(f"{save_dir}steps_{save_idx}_{dataset_type}.npz", steps)
+        sparse.save_npz(f"{save_dir}states_{dataset_type}.npz", states)
+        sparse.save_npz(f"{save_dir}steps_{dataset_type}.npz", steps)
 
         # load the reference data, which we will compare against
         states_ref = sparse.load_npz(
-            f"{reference_data_dir}states_{save_idx}_{dataset_type}.npz"
+            f"{reference_data_dir}states_{dataset_type}.npz"
         )
         steps_ref = sparse.load_npz(
-            f"{reference_data_dir}steps_{save_idx}_{dataset_type}.npz"
+            f"{reference_data_dir}steps_{dataset_type}.npz"
         )
 
         # check here that states and steps were correctly saved (need to convert the
@@ -186,8 +182,8 @@ def test_dataprep(self):
         main_dir = f"{TEST_DIR}/data/"
         ref_dir = f"{TEST_DIR}/data/ref/"
         # copy data from the reference directory to use for this particular test
-        copyfile(f"{ref_dir}states_0_train.npz", f"{main_dir}states_0_train.npz")
-        copyfile(f"{ref_dir}steps_0_train.npz", f"{main_dir}steps_0_train.npz")
+        copyfile(f"{ref_dir}states_train.npz", f"{main_dir}states_train.npz")
+        copyfile(f"{ref_dir}steps_train.npz", f"{main_dir}steps_train.npz")
 
         # the lines below will save Action-, Reactant 1-, Reaction-, and Reactant 2-
         # specific files directly to the 'SynNet/tests/data/' directory (e.g.
@@ -195,50 +191,24 @@
         # 'X_rt1_{train/test/valid}.npz' and 'y_rt1_{train/test/valid}.npz'
         # 'X_rxn_{train/test/valid}.npz' and 'y_rxn_{train/test/valid}.npz'
         # 'X_rt2_{train/test/valid}.npz' and 'y_rt2_{train/test/valid}.npz'
-        prep_data(main_dir=main_dir, num_rxn=3, out_dim=300)
+        prep_data(main_dir=main_dir, num_rxn=3, out_dim=300, datasets=["train"])
 
         # check that the saved files match the reference files in
         # 'SynNet/tests/data/ref':
+        def _compare_to_reference(network_type: str):
+            X = sparse.load_npz(f"{main_dir}X_{network_type}_train.npz")
+            y = sparse.load_npz(f"{main_dir}y_{network_type}_train.npz")
 
-        # Action network data
-        X_act = sparse.load_npz(f"{main_dir}X_act_train.npz")
-        y_act = sparse.load_npz(f"{main_dir}y_act_train.npz")
-
-        X_act_ref = sparse.load_npz(f"{ref_dir}X_act_train.npz")
-        y_act_ref = sparse.load_npz(f"{ref_dir}y_act_train.npz")
-
-        self.assertEqual(X_act.toarray().all(), X_act_ref.toarray().all())
-        self.assertEqual(y_act.toarray().all(), y_act_ref.toarray().all())
-
-        # Reactant 1 network data
-        X_rt1 = sparse.load_npz(f"{main_dir}X_rt1_train.npz")
-        y_rt1 = sparse.load_npz(f"{main_dir}y_rt1_train.npz")
-
-        X_rt1_ref = sparse.load_npz(f"{ref_dir}X_rt1_train.npz")
-        y_rt1_ref = sparse.load_npz(f"{ref_dir}y_rt1_train.npz")
+            Xref = sparse.load_npz(f"{ref_dir}X_{network_type}_train.npz")
+            yref = sparse.load_npz(f"{ref_dir}y_{network_type}_train.npz")
 
-        self.assertEqual(X_rt1.toarray().all(), X_rt1_ref.toarray().all())
-        self.assertEqual(y_rt1.toarray().all(), y_rt1_ref.toarray().all())
+            self.assertEqual(X.toarray().all(), Xref.toarray().all(), msg=f"{network_type=}")
+            self.assertEqual(y.toarray().all(), yref.toarray().all(), msg=f"{network_type=}")
 
-        # Reaction network data
-        X_rxn = sparse.load_npz(f"{main_dir}X_rxn_train.npz")
-        y_rxn = sparse.load_npz(f"{main_dir}y_rxn_train.npz")
+        for network in ["act", "rt1", "rxn", "rt2"]:
+            _compare_to_reference(network)
 
-        X_rxn_ref = sparse.load_npz(f"{ref_dir}X_rxn_train.npz")
-        y_rxn_ref = sparse.load_npz(f"{ref_dir}y_rxn_train.npz")
-        self.assertEqual(X_rxn.toarray().all(), X_rxn_ref.toarray().all())
-        self.assertEqual(y_rxn.toarray().all(), y_rxn_ref.toarray().all())
-
-        # Reactant 2 network data
-        X_rt2 = sparse.load_npz(f"{main_dir}X_rt2_train.npz")
-        y_rt2 = sparse.load_npz(f"{main_dir}y_rt2_train.npz")
-
-        X_rt2_ref = sparse.load_npz(f"{ref_dir}X_rt2_train.npz")
-        y_rt2_ref = sparse.load_npz(f"{ref_dir}y_rt2_train.npz")
-
-        self.assertEqual(X_rt2.toarray().all(), X_rt2_ref.toarray().all())
-        self.assertEqual(y_rt2.toarray().all(), y_rt2_ref.toarray().all())
 
     def test_bb_emb(self):
         """
@@ -255,7 +225,7 @@ def test_bb_emb(self):
         model.eval()
 
         # load the building blocks
-        path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz"
+        path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz"
         building_blocks = pd.read_csv(path_to_building_blocks, compression="gzip")[
             "SMILES"
         ].tolist()
@@ -272,3 +242,7 @@ def test_bb_emb(self):
         embeddings_ref = np.load(f"{ref_dir}building_blocks_emb.npy")
 
         self.assertEqual(embeddings.all(), embeddings_ref.all())
+
+
+if __name__ == "__main__":
+    TestDataPrep()
\ No newline at end of file
diff --git a/tests/test_Predict.py b/tests/_test_Predict.py
similarity index 88%
rename from tests/test_Predict.py
rename to tests/_test_Predict.py
index a2b5454a..626fe49a 100644
--- a/tests/test_Predict.py
+++ b/tests/_test_Predict.py
@@ -7,13 +7,12 @@
 import numpy as np
 import pandas as pd
 
-from syn_net.utils.predict_utils import (
-    synthetic_tree_decoder_multireactant,
+from synnet.utils.predict_utils import (
+    synthetic_tree_decoder_greedy_search,
     mol_fp,
-    load_modules_from_checkpoint,
 )
-from syn_net.utils.data_utils import SyntheticTreeSet, ReactionSet
-
+from synnet.utils.data_utils import SyntheticTreeSet, ReactionSet
+from syn_net.models.chkpt_loader import load_modules_from_checkpoint
 
 TEST_DIR = Path(__file__).parent
 
@@ -42,7 +41,7 @@ def test_predict(self):
 
         # define path to the reaction templates and purchasable building blocks
         path_to_reaction_file = f"{ref_dir}rxns_hb.json.gz"
-        path_to_building_blocks = f"{TEST_DIR}/data/building_blocks_matched.csv.gz"
+        path_to_building_blocks = f"{TEST_DIR}/assets/building_blocks_matched.csv.gz"
 
         # define paths to pretrained modules
         path_to_act = f"{ref_dir}act.ckpt"
@@ -57,8 +56,7 @@ def test_predict(self):
         bb_dict = {building_blocks[i]: i for i in range(len(building_blocks))}
 
         # load the reaction templates as a ReactionSet object
-        rxn_set = ReactionSet()
-        rxn_set.load(path_to_reaction_file)
+        rxn_set = ReactionSet().load(path_to_reaction_file)
         rxns = rxn_set.rxns
 
         # load the pre-trained modules
@@ -76,8 +74,7 @@ def test_predict(self):
 
         # load the query molecules (i.e. molecules to decode)
         path_to_data = f"{ref_dir}st_data.json.gz"
-        sts = SyntheticTreeSet()
-        sts.load(path_to_data)
+        sts = SyntheticTreeSet().load(path_to_data)
         smis_query = [st.root.smiles for st in sts.sts]
 
         # start to decode the query molecules (no multiprocessing for the unit tests here)
@@ -86,7 +83,7 @@ def test_predict(self):
         trees = []
         for smi in smis_query:
             emb = mol_fp(smi)
-            smi, similarity, tree, action = synthetic_tree_decoder_multireactant(
+            smi, similarity, tree, action = synthetic_tree_decoder_greedy_search(
                 z_target=emb,
                 building_blocks=building_blocks,
                 bb_dict=bb_dict,
diff --git a/tests/test_Training.py b/tests/_test_Training.py
similarity index 72%
rename from tests/test_Training.py
rename to tests/_test_Training.py
index 75ff83fd..0cfbc039 100644
--- a/tests/test_Training.py
+++ b/tests/_test_Training.py
@@ -4,17 +4,33 @@
 from pathlib import Path
 import unittest
 import shutil
-
+from multiprocessing import cpu_count
 import pytorch_lightning as pl
 from pytorch_lightning import loggers as pl_loggers
 from scipy import sparse
 import torch
 
-from syn_net.models.mlp import MLP, load_array
+from synnet.models.mlp import MLP, load_array
+from synnet.MolEmbedder import MolEmbedder
 
 TEST_DIR = Path(__file__).parent
+REACTION_TEMPLATES_FILE = f"{TEST_DIR}/assets/rxn_set_hb_test.txt"
+
+def _fetch_molembedder():
+    file = "tests/data/building_blocks_emb.npy"
+    molembedder = MolEmbedder().load_precomputed(file).init_balltree(metric="euclidean")
+    return molembedder
+
+class TestReactionTemplateFile(unittest.TestCase):
+
+    def test_number_of_reaction_templates(self):
+        """Count number of lines in file, i.e. the number of reaction templates."""
+        with open(REACTION_TEMPLATES_FILE, "r") as f:
+            nReactionTemplates = sum(1 for _ in f)
+        self.assertEqual(nReactionTemplates, 3)
 
 class TestTraining(unittest.TestCase):
     """
@@ -22,6 +38,11 @@ class TestTraining(unittest.TestCase):
     reaction network, (4) reactant 2 network.
     """
 
+    def setUp(self) -> None:
+        import warnings
+        warnings.filterwarnings("ignore", ".*does not have many workers.*")
+        warnings.filterwarnings("ignore", ".*GPU available but not used.*")
+
     def test_action_network(self):
         """
         Tests the Action Network.
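> *Editor's note:* the hunks below swap `assertEqual` for `assertAlmostEqual` on the recorded training losses. Exact float equality is brittle here, since losses drift with library versions, BLAS backends, and seeding. The following is a small, self-contained illustration of the stdlib tolerance semantics (demo code, not part of this diff):

```python
import unittest


class FloatToleranceDemo(unittest.TestCase):
    def test_default_places(self):
        loss = 0.1 + 0.2  # accumulates float error: 0.30000000000000004
        # assertEqual(loss, 0.3) would fail; assertAlmostEqual rounds the
        # difference to 7 decimal places by default, so tiny noise passes.
        self.assertAlmostEqual(loss, 0.3)

    def test_negative_places(self):
        # places=-6 (used for the reaction network loss below) rounds the
        # difference to the nearest 1e6 -- an extremely loose tolerance.
        self.assertAlmostEqual(1_000_000.0, 1_400_000.0, places=-6)


if __name__ == "__main__":
    unittest.main()
```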
@@ -31,12 +52,14 @@ def test_action_network(self):
         nbits = 4096
         batch_size = 10
         epochs = 2
-        ncpu = 2
+        ncpu = min(2, cpu_count())
         validation_option = "accuracy"
 
         ref_dir = f"{TEST_DIR}/data/ref/"
         X = sparse.load_npz(ref_dir + "X_act_train.npz")
+        assert X.shape == (4, 3*nbits)  # (4,12288)
         y = sparse.load_npz(ref_dir + "y_act_train.npz")
+        assert y.shape == (4, 1)  # (4,1)
         X = torch.Tensor(X.A)
         y = torch.LongTensor(
             y.A.reshape(
@@ -69,15 +92,15 @@ def test_action_network(self):
             f"act_{embedding}_{radius}_{nbits}_logs/"
         )
         trainer = pl.Trainer(
-            max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger
+            max_epochs=epochs, logger=tb_logger, weights_summary=None,
         )
         trainer.fit(mlp, train_data_iter, valid_data_iter)
 
         train_loss = float(trainer.callback_metrics["train_loss"])
-        train_loss_ref = 1.4203987121582031
+        train_loss_ref = 1.2967982292175293
         shutil.rmtree(f"act_{embedding}_{radius}_{nbits}_logs/")
 
-        self.assertEqual(train_loss, train_loss_ref)
+        self.assertAlmostEqual(train_loss, train_loss_ref)
 
     def test_reactant1_network(self):
         """
@@ -86,17 +109,19 @@ def test_reactant1_network(self):
         embedding = "fp"
         radius = 2
         nbits = 4096
-        out_dim = 300
+        out_dim = 300  # Note: out_dim 300 = gin embedding
         batch_size = 10
         epochs = 2
-        ncpu = 2
-        validation_option = "nn_accuracy"
+        ncpu = min(2, cpu_count())
+        validation_option = "nn_accuracy_gin_unittest"
 
         ref_dir = f"{TEST_DIR}/data/ref/"
         # load the reaction data
         X = sparse.load_npz(ref_dir + "X_rt1_train.npz")
+        assert X.shape == (2, 3*nbits)  # (2,12288)
         X = torch.Tensor(X.A)
         y = sparse.load_npz(ref_dir + "y_rt1_train.npz")
+        assert y.shape == (2, 300)  # (2,300)
         y = torch.Tensor(y.A)
         train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True)
 
@@ -117,6 +142,7 @@ def test_reactant1_network(self):
             optimizer="adam",
             learning_rate=1e-4,
             val_freq=10,
+            molembedder=_fetch_molembedder(),
             ncpu=ncpu,
         )
 
@@ -124,15 +150,15 @@ def test_reactant1_network(self):
             f"rt1_{embedding}_{radius}_{nbits}_logs/"
         )
         trainer = pl.Trainer(
-            max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger
+            max_epochs=epochs, logger=tb_logger, weights_summary=None,
         )
         trainer.fit(mlp, train_data_iter, valid_data_iter)
 
         train_loss = float(trainer.callback_metrics["train_loss"])
-        train_loss_ref = 0.35571354627609253
+        train_loss_ref = 0.33368119597435
         shutil.rmtree(f"rt1_{embedding}_{radius}_{nbits}_logs/")
 
-        self.assertEqual(train_loss, train_loss_ref)
+        self.assertAlmostEqual(train_loss, train_loss_ref)
 
     def test_reaction_network(self):
         """
@@ -143,13 +169,15 @@ def test_reaction_network(self):
         nbits = 4096
         batch_size = 10
         epochs = 2
-        ncpu = 2
-        n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt'
+        ncpu = min(2, cpu_count())
+        n_templates = 3  # num templates in `REACTION_TEMPLATES_FILE`
         validation_option = "accuracy"
 
         ref_dir = f"{TEST_DIR}/data/ref/"
         X = sparse.load_npz(ref_dir + "X_rxn_train.npz")
+        assert X.shape == (2, 4*nbits)  # (2,16384)
         y = sparse.load_npz(ref_dir + "y_rxn_train.npz")
+        assert y.shape == (2, 1)  # (2, 1)
         X = torch.Tensor(X.A)
         y = torch.LongTensor(
             y.A.reshape(
@@ -182,7 +210,7 @@ def test_reaction_network(self):
             f"rxn_{embedding}_{radius}_{nbits}_logs/"
         )
         trainer = pl.Trainer(
-            max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger
+            max_epochs=epochs, logger=tb_logger, weights_summary=None,
         )
         trainer.fit(mlp, train_data_iter, valid_data_iter)
 
         train_loss = float(trainer.callback_metrics["train_loss"])
         train_loss_ref = 1.1214743852615356
         shutil.rmtree(f"rxn_{embedding}_{radius}_{nbits}_logs/")
 
-        self.assertEqual(train_loss, train_loss_ref)
+        self.assertAlmostEqual(train_loss, train_loss_ref, places=-6)
 
     def test_reactant2_network(self):
         """
@@ -199,16 +227,18 @@ def test_reactant2_network(self):
         embedding = "fp"
         radius = 2
         nbits = 4096
-        out_dim = 300
+        out_dim = 300  # Note: out_dim 300 = gin embedding
         batch_size = 10
         epochs = 2
-        ncpu = 2
+        ncpu = min(2, cpu_count())
         n_templates = 3 # num templates in 'data/rxn_set_hb_test.txt'
-        validation_option = "nn_accuracy"
+        validation_option = "nn_accuracy_gin_unittest"
 
         ref_dir = f"{TEST_DIR}/data/ref/"
         X = sparse.load_npz(ref_dir + "X_rt2_train.npz")
+        assert X.shape == (2, 4*nbits + n_templates)  # (2,16387)
         y = sparse.load_npz(ref_dir + "y_rt2_train.npz")
+        assert y.shape == (2, 300)  # (2,300)
         X = torch.Tensor(X.A)
         y = torch.Tensor(y.A)
         train_data_iter = load_array((X, y), batch_size, ncpu=ncpu, is_train=True)
@@ -230,6 +260,7 @@ def test_reactant2_network(self):
             optimizer="adam",
             learning_rate=1e-4,
             val_freq=10,
+            molembedder=_fetch_molembedder(),
             ncpu=ncpu,
         )
 
@@ -237,12 +268,12 @@ def test_reactant2_network(self):
             f"rt2_{embedding}_{radius}_{nbits}_logs/"
         )
         trainer = pl.Trainer(
-            max_epochs=epochs, progress_bar_refresh_rate=20, logger=tb_logger
+            max_epochs=epochs, logger=tb_logger, weights_summary=None,
        )
         trainer.fit(mlp, train_data_iter, valid_data_iter)
 
         train_loss = float(trainer.callback_metrics["train_loss"])
-        train_loss_ref = 0.41246509552001953
+        train_loss_ref = 0.3026905953884125
         shutil.rmtree(f"rt2_{embedding}_{radius}_{nbits}_logs/")
 
-        self.assertEqual(train_loss, train_loss_ref)
+        self.assertAlmostEqual(train_loss, train_loss_ref)
diff --git a/tests/data/building_blocks_matched.csv.gz b/tests/assets/building_blocks_matched.csv.gz
similarity index 100%
rename from tests/data/building_blocks_matched.csv.gz
rename to tests/assets/building_blocks_matched.csv.gz
diff --git a/tests/data/rxn_set_hb_test.txt b/tests/assets/rxn_set_hb_test.txt
similarity index 100%
rename from tests/data/rxn_set_hb_test.txt
rename to tests/assets/rxn_set_hb_test.txt
diff --git a/tests/assets/syntree-small.json b/tests/assets/syntree-small.json
new file mode 100644
index 00000000..8b865180
--- /dev/null
+++ b/tests/assets/syntree-small.json
@@ -0,0 +1,139 @@
+{
+    "reactions": [
+        {
+            "rxn_id": 12,
+            "rtype": 2,
+            "parent": "CCOc1ccc(CCNC(=O)CN2N=NC=C2CN2CCC(C(=O)O)CC2)cc1OCC",
+            "child": [
+                "CCOc1ccc(CCNC(=O)CCl)cc1OCC",
+                "C#CCN1CCC(C(=O)O)CC1.Cl"
+            ],
+            "depth": 0.5,
+            "index": 0
+        },
+        {
+            "rxn_id": 47,
+            "rtype": 2,
+            "parent": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C#N)cc1C)C1CC1",
+            "child": [
+                "C=C(C)C(=O)OCCN=C=O",
+                "Cc1cc(C#N)ccc1NC1CC1"
+            ],
+            "depth": 0.5,
+            "index": 1
+        },
+        {
+            "rxn_id": 15,
+            "rtype": 2,
+            "parent": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C2=NNC(C3CCN(Cc4cnnn4CC(=O)NCCc4ccc(OCC)c(OCC)c4)CC3)=N2)cc1C)C1CC1",
+            "child": [
+                "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C#N)cc1C)C1CC1",
+                "CCOc1ccc(CCNC(=O)CN2N=NC=C2CN2CCC(C(=O)O)CC2)cc1OCC"
+            ],
+            "depth": 1.5,
+            "index": 2
+        },
+        {
+            "rxn_id": 49,
+            "rtype": 1,
+            "parent": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(-c2n[nH]c(C3CCN(Cc4cnnn4CC4=NCCc5cc(OCC)c(OCC)cc54)CC3)n2)cc1C)C1CC1",
+            "child": [
+                "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C2=NNC(C3CCN(Cc4cnnn4CC(=O)NCCc4ccc(OCC)c(OCC)c4)CC3)=N2)cc1C)C1CC1"
+            ],
+            "depth": 2.5,
+            "index": 3
+        }
+    ],
+    "chemicals": [
+        {
+            "smiles": "CCOc1ccc(CCNC(=O)CCl)cc1OCC",
+            "parent": 12,
+            "child": null,
+            "is_leaf": true,
+            "is_root": false,
+            "depth": 0,
+            "index": 0
+        },
+        {
+            "smiles": "C#CCN1CCC(C(=O)O)CC1.Cl",
+            "parent": 12,
+            "child": null,
+            "is_leaf": true,
+            "is_root": false,
+            "depth": 0,
+            "index": 1
+        },
+        {
"smiles": "CCOc1ccc(CCNC(=O)CN2N=NC=C2CN2CCC(C(=O)O)CC2)cc1OCC", + "parent": 15, + "child": 12, + "is_leaf": false, + "is_root": false, + "depth": 1, + "index": 2 + }, + { + "smiles": "C=C(C)C(=O)OCCN=C=O", + "parent": 47, + "child": null, + "is_leaf": true, + "is_root": false, + "depth": 0, + "index": 3 + }, + { + "smiles": "Cc1cc(C#N)ccc1NC1CC1", + "parent": 47, + "child": null, + "is_leaf": true, + "is_root": false, + "depth": 0, + "index": 4 + }, + { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C#N)cc1C)C1CC1", + "parent": 15, + "child": 47, + "is_leaf": false, + "is_root": false, + "depth": 1, + "index": 5 + }, + { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(C2=NNC(C3CCN(Cc4cnnn4CC(=O)NCCc4ccc(OCC)c(OCC)c4)CC3)=N2)cc1C)C1CC1", + "parent": 49, + "child": 15, + "is_leaf": false, + "is_root": false, + "depth": 2.0, + "index": 6 + }, + { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(-c2n[nH]c(C3CCN(Cc4cnnn4CC4=NCCc5cc(OCC)c(OCC)cc54)CC3)n2)cc1C)C1CC1", + "parent": null, + "child": 49, + "is_leaf": false, + "is_root": true, + "depth": 3.0, + "index": 7 + } + ], + "root": { + "smiles": "C=C(C)C(=O)OCCNC(=O)N(c1ccc(-c2n[nH]c(C3CCN(Cc4cnnn4CC4=NCCc5cc(OCC)c(OCC)cc54)CC3)n2)cc1C)C1CC1", + "parent": null, + "child": 49, + "is_leaf": false, + "is_root": true, + "depth": 3.0, + "index": 7 + }, + "depth": 3.0, + "actions": [ + 0, + 0, + 2, + 1, + 3 + ], + "rxn_id2type": null +} \ No newline at end of file diff --git a/tests/data/ref/states_0_train.npz b/tests/data/ref/states_train.npz similarity index 100% rename from tests/data/ref/states_0_train.npz rename to tests/data/ref/states_train.npz diff --git a/tests/data/ref/steps_0_train.npz b/tests/data/ref/steps_train.npz similarity index 100% rename from tests/data/ref/steps_0_train.npz rename to tests/data/ref/steps_train.npz diff --git a/tests/test_Optimization.py b/tests/test_Optimization.py index fabdd14f..9d69b139 100644 --- a/tests/test_Optimization.py +++ b/tests/test_Optimization.py @@ -3,7 +3,7 @@ """ import unittest import numpy as np -from syn_net.utils.ga_utils import crossover, mutation, fitness_sum +from synnet.utils.ga_utils import crossover, mutation, fitness_sum class TestOptimization(unittest.TestCase):