-
Notifications
You must be signed in to change notification settings - Fork 113
【BAAI】add MoFlow pretraining std case #397
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
6e298f7
add MoFlow std case
308bc2e
update readme
11e50f3
add case example for test_conf
2cb2ece
change to comment
2dc1140
rdkit add version
abee046
add jit & cuda_graph to mutable_params, overwritten by vendors are al…
9d23243
rename config_name to dataset_name
0109774
set time statistic variables to 0
eb1fd00
update seed and target_nuv
bfc3d1d
update 1x8 result for official bs
9d09a23
update notice for readme
67bd369
Update test_conf.py
yuzhou03 285bb7f
Merge branch 'main' into moflow
yuzhou03 e019f7a
Merge branch 'main' into moflow
yuzhou03 fc84f0d
Merge branch 'main' into moflow
shh2000 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
|
||
## Model Introduction | ||
MoFlow is a model for molecule generation that leverages Normalizing Flows. Normalizing Flows is a class of generative neural networks that directly models the probability density of the data. They consist of a sequence of invertible transformations that convert the input data that follow some hard-to-model distribution into a latent code that follows a normal distribution which can then be easily used for sampling. | ||
|
||
MoFlow was first introduced by Chengxi Zang et al. in their paper titled "MoFlow: An Invertible Flow Model for Generating Molecular Graphs" [paper](https://arxiv.org/pdf/2006.10137.pdf). | ||
|
||
|
||
|
||
## Model source code | ||
This repository includes software from [MoFlow](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/MoFlow) | ||
licensed under the Apache License, Version 2.0 | ||
|
||
Some of the files in this directory were modified by BAAI in 2024 to support FlagPerf. | ||
|
||
## Dataset | ||
### getting the data | ||
This [original source code repository](https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/DrugDiscovery/MoFlow#getting-the-data) contains the prepare_datasets.sh script that will automatically download and process the dataset. By default, data will be downloaded to the /data/ directory in the container. | ||
```bash | ||
bash prepare_datasets.sh | ||
``` | ||
### preprocess the dataset | ||
Start the container with [Dockerfile](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/DrugDiscovery/MoFlow/Dockerfile). Enter the container. | ||
excute the folowing script to preprocess the dataset. | ||
```bash | ||
python3 scripts/data_preprocess.py | ||
``` | ||
|
||
### dataset strucutres | ||
preview directory structures of \<data_dir\> | ||
```bash | ||
tree . | ||
``` | ||
|
||
``` | ||
. | ||
├── valid_idx_zinc250k.json | ||
├── zinc250k.csv | ||
└── zinc250k_relgcn_kekulized_ggnp.npz | ||
``` | ||
|
||
| FileName | Size(Bytes) | MD5 | | ||
| ---------------------------------- | ----------- | -------------------------------- | | ||
| valid_idx_zinc250k.json | 187832 | f8045b49a413c31136a0645d30c0b846 | | ||
| zinc250k.csv | 23736231 | cd330eafb7a2cc413b3c9cafaf3efece | | ||
| zinc250k_relgcn_kekulized_ggnp.npz | 375680462 | c91985e309a9f76457169859dbe1e662 | | ||
|
||
|
||
## Checkpoint | ||
- None | ||
|
||
## AI Frameworks && Accelerators supports | ||
|
||
| | Pytorch | Paddle | TensorFlow2 | | ||
| ---------- | ------------------------------------------ | ------ | ----------- | | ||
| Nvidia GPU | [✅](../../nvidia/moflow-pytorch/README.md) | N/A | N/A | |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from ._base import * | ||
from .mutable_params import mutable_params |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# DO NOT MODIFY THESE REQUIRED PARAMETERS | ||
|
||
# Required parameters | ||
vendor: str = None | ||
data_dir: str = None | ||
name: str = "moflow" | ||
cudnn_benchmark: bool = False | ||
cudnn_deterministic: bool = True | ||
|
||
# Optional parameters | ||
|
||
# ========================================================= | ||
# data | ||
# ========================================================= | ||
# The config to choose. This parameter allows one to switch between different datasets. | ||
# and their dedicated configurations of the neural network. By default, a pre-defined "zinc250k" config is used. | ||
dataset_name: str = "zinc250k" | ||
# Number of workers in the data loader. | ||
num_workers: int = 4 | ||
|
||
# ========================================================= | ||
# loss scale | ||
# ========================================================= | ||
# Base learning rate. | ||
lr: float = 0.0005 | ||
# beta1 parameter for the optimizer. | ||
beta1: float = 0.9 | ||
# beta2 parameter for the optimizer. | ||
beta2: float = 0.99 | ||
# Gradient clipping norm. | ||
clip: float = 1.0 | ||
# ========================================================= | ||
# train && evaluate | ||
# ========================================================= | ||
# Batch size per GPU for training | ||
train_batch_size: int = 512 | ||
eval_batch_size: int = 100 | ||
|
||
target_nuv: float = 87.9 | ||
|
||
# Frequency for saving checkpoints, expressed in epochs. If -1 is provided, checkpoints will not be saved. | ||
save_epochs: int = 50 | ||
# Evaluation frequency, expressed in epochs. If -1 is provided, an evaluation will not be performed. | ||
eval_epochs: int = 5 | ||
|
||
# Number of warmup steps. This value is used for benchmarking and for CUDA graph capture. | ||
warmup_steps: int = 20 | ||
# Number of steps used for training/inference. This parameter allows finishing. | ||
# training earlier than the specified number of epochs. | ||
# If used with inference, it allows generating more molecules (by default only a single batch of molecules is generated). | ||
steps: int = -1 | ||
# Temperature used for sampling. | ||
temperature: float = 0.3 | ||
first_epoch: int = 0 | ||
epochs: int = 300 | ||
|
||
allow_untrained = False | ||
|
||
do_train = True | ||
fp16 = False | ||
amp: bool = True | ||
distributed: bool = True | ||
|
||
# Directory where checkpoints are stored | ||
results_dir: str = "moflow_results" | ||
# Path to store generated molecules. If an empty string is provided, predictions will not be saved (useful for benchmarking and debugging). | ||
# predictions_path: str = "moflow_results/predictions.smi" | ||
# ========================================================= | ||
# experiment | ||
# ========================================================= | ||
# Compile the model with `torch.jit.script`. Can be used to speed up training or inference. | ||
jit: bool = False | ||
# Capture GPU kernels with CUDA graphs. This option allows to speed up training | ||
cuda_graph: bool = True | ||
# Verbosity level. Specify the following values: 0, 1, 2, 3, where 0 means minimal verbosity (errors only) and 3 - maximal (debugging). | ||
verbosity: int = 1 | ||
# Path for DLLogger log. This file will contain information about the speed and accuracy of the model during training and inference. | ||
# Note that if the file already exists, new logs will be added at the end. | ||
log_path: str = "moflow_results/moflow.json" | ||
yuzhou03 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Frequency for writing logs, expressed in steps. | ||
log_interval: int = 20 | ||
|
||
|
||
# Apply validity correction after the generation of the molecules. | ||
correct_validity: bool = False | ||
# ========================================================= | ||
# utils | ||
# ========================================================= | ||
# Random seed used to initialize the distributed loaders | ||
seed: int = 1 | ||
dist_backend: str = 'nccl' | ||
|
||
device: str = None | ||
|
||
# ========================================================= | ||
# for driver | ||
# ========================================================= | ||
# rank of the GPU, used to launch distributed training. | ||
local_rank: int = -1 | ||
use_env: bool = True | ||
log_freq: int = 500 | ||
print_freq: int = 500 | ||
n_device: int = 1 | ||
sync_bn: bool = False | ||
gradient_accumulation_steps: int = 1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
mutable_params = [ | ||
'vendor', 'data_dir', 'lr', 'train_batch_size', 'eval_batch_size', | ||
'do_train', 'amp', 'fp16', 'distributed', 'dist_backend', 'num_workers', | ||
'device', 'cudnn_benchmark', 'cudnn_deterministic', | ||
'jit', 'cuda_graph' | ||
] |
Empty file.
109 changes: 109 additions & 0 deletions
109
training/benchmarks/moflow/pytorch/data/data_frame_parser.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
# Copyright 2020 Chengxi Zang | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a | ||
# copy of this software and associated documentation files (the "Software"), | ||
# to deal in the Software without restriction, including without limitation | ||
# the rights to use, copy, modify, merge, publish, distribute, sublicense, | ||
# and/or sell copies of the Software, and to permit persons to whom the | ||
# Software is furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included | ||
# in all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||
# IN THE SOFTWARE. | ||
|
||
|
||
from logging import getLogger | ||
import traceback | ||
from typing import List | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from rdkit import Chem | ||
from tqdm import tqdm | ||
|
||
from moflow.data.encoding import MolEncoder, EncodingError | ||
from moflow.data.data_loader import NumpyTupleDataset | ||
|
||
|
||
class DataFrameParser: | ||
""" | ||
This DataFrameParser parses pandas dataframe containing SMILES and, optionally, some additional features. | ||
|
||
Args: | ||
encoder (MolEncoder): encoder instance | ||
labels (list): labels column that should be loaded | ||
smiles_col (str): smiles column | ||
""" | ||
|
||
def __init__(self, encoder: MolEncoder, | ||
labels: List[str], | ||
smiles_col: str = 'smiles'): | ||
super(DataFrameParser, self).__init__() | ||
self.labels = labels | ||
self.smiles_col = smiles_col | ||
self.logger = getLogger(__name__) | ||
self.encoder = encoder | ||
|
||
def parse(self, df: pd.DataFrame) -> NumpyTupleDataset: | ||
"""Parse DataFrame using `encoder` and prepare a dataset instance | ||
|
||
Labels are extracted from `labels` columns and input features are | ||
extracted from smiles information in `smiles` column. | ||
""" | ||
all_nodes = [] | ||
all_edges = [] | ||
|
||
total_count = df.shape[0] | ||
fail_count = 0 | ||
success_count = 0 | ||
for smiles in tqdm(df[self.smiles_col], total=df.shape[0]): | ||
try: | ||
mol = Chem.MolFromSmiles(smiles) | ||
if mol is None: | ||
fail_count += 1 | ||
continue | ||
# Note that smiles expression is not unique. | ||
# we obtain canonical smiles | ||
nodes, edges = self.encoder.encode_mol(mol) | ||
|
||
except EncodingError as e: | ||
fail_count += 1 | ||
continue | ||
except Exception as e: | ||
self.logger.warning('parse(), type: {}, {}' | ||
.format(type(e).__name__, e.args)) | ||
self.logger.info(traceback.format_exc()) | ||
fail_count += 1 | ||
continue | ||
all_nodes.append(nodes) | ||
all_edges.append(edges) | ||
success_count += 1 | ||
|
||
result = [np.array(all_nodes), np.array(all_edges), *(df[label_col].values for label_col in self.labels)] | ||
self.logger.info('Preprocess finished. FAIL {}, SUCCESS {}, TOTAL {}' | ||
.format(fail_count, success_count, total_count)) | ||
|
||
dataset = NumpyTupleDataset(result) | ||
return dataset |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
# Copyright 2020 Chengxi Zang | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a | ||
# copy of this software and associated documentation files (the "Software"), | ||
# to deal in the Software without restriction, including without limitation | ||
# the rights to use, copy, modify, merge, publish, distribute, sublicense, | ||
# and/or sell copies of the Software, and to permit persons to whom the | ||
# Software is furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included | ||
# in all copies or substantial portions of the Software. | ||
# | ||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS | ||
# IN THE SOFTWARE. | ||
|
||
|
||
import os | ||
import logging | ||
from typing import Any, Callable, Iterable, Optional, Tuple | ||
|
||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class NumpyTupleDataset(Dataset): | ||
"""Dataset of a tuple of datasets. | ||
|
||
It combines multiple datasets into one dataset. Each example is represented | ||
by a tuple whose ``i``-th item corresponds to the i-th dataset. | ||
And each ``i``-th dataset is expected to be an instance of numpy.ndarray. | ||
|
||
Args: | ||
datasets: Underlying datasets. The ``i``-th one is used for the | ||
``i``-th item of each example. All datasets must have the same | ||
length. | ||
transform: An optional function applied to an item bofre returning | ||
""" | ||
|
||
def __init__(self, datasets: Iterable[np.ndarray], transform: Optional[Callable] = None) -> None: | ||
if not datasets: | ||
raise ValueError('no datasets are given') | ||
length = len(datasets[0]) | ||
for i, dataset in enumerate(datasets): | ||
if len(dataset) != length: | ||
raise ValueError( | ||
'dataset of the index {} has a wrong length'.format(i)) | ||
self._datasets = datasets | ||
self._length = length | ||
self.transform = transform | ||
|
||
def __len__(self) -> int: | ||
return self._length | ||
|
||
def __getitem__(self, index: int) -> Tuple[Any]: | ||
item = [dataset[index] for dataset in self._datasets] | ||
|
||
if self.transform: | ||
item = self.transform(item) | ||
return item | ||
|
||
def get_datasets(self) -> Tuple[np.ndarray]: | ||
return self._datasets | ||
|
||
|
||
def save(self, filepath: str) -> None: | ||
"""save the dataset to filepath in npz format | ||
|
||
Args: | ||
filepath (str): filepath to save dataset. It is recommended to end | ||
with '.npz' extension. | ||
""" | ||
np.savez(filepath, *self._datasets) | ||
logging.info('Save {} done.'.format(filepath)) | ||
|
||
@classmethod | ||
def load(cls, filepath: str, transform: Optional[Callable] = None): | ||
logging.info('Loading file {}'.format(filepath)) | ||
if not os.path.exists(filepath): | ||
raise ValueError('Invalid filepath {} for dataset'.format(filepath)) | ||
load_data = np.load(filepath) | ||
result = [] | ||
i = 0 | ||
while True: | ||
key = 'arr_{}'.format(i) | ||
if key in load_data.keys(): | ||
result.append(load_data[key]) | ||
i += 1 | ||
else: | ||
break | ||
return cls(result, transform) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.