-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Zhiyuan Chen <[email protected]>
- Loading branch information
1 parent
da9bc56
commit 7cbdf5b
Showing
67 changed files
with
1,832 additions
and
170 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
datas | ||
ser | ||
marz | ||
manuel | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--- | ||
authors: | ||
- Zhiyuan Chen | ||
date: 2024-05-04 | ||
--- | ||
|
||
# MultiTask | ||
|
||
::: multimolecule.data.multitask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--- | ||
authors: | ||
- Zhiyuan Chen | ||
date: 2024-05-04 | ||
--- | ||
|
||
# MultiMoleculeConfig | ||
|
||
::: multimolecule.runners.MultiMoleculeConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--- | ||
authors: | ||
- Zhiyuan Chen | ||
date: 2024-05-04 | ||
--- | ||
|
||
# runners | ||
|
||
--8<-- "multimolecule/runners/README.md:8:" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
--- | ||
authors: | ||
- Zhiyuan Chen | ||
date: 2024-05-04 | ||
--- | ||
|
||
# MultiMoleculeRunner | ||
|
||
::: multimolecule.runners.MultiMoleculeRunner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
from .run import evaluate, inference, train | ||
|
||
__all__ = ["train", "evaluate", "inference"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
import atexit | ||
import os | ||
import warnings | ||
from typing import Type | ||
|
||
import danling as dl | ||
import torch | ||
|
||
from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner | ||
|
||
try: | ||
import nni | ||
except ImportError: | ||
nni = None | ||
|
||
|
||
def train( | ||
config: MultiMoleculeConfig = None, # type: ignore | ||
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, | ||
): | ||
if config is None: | ||
config = MultiMoleculeConfig() | ||
config = config.parse(default_config="config", no_default_config_action="warn") | ||
config.interpolate(unsafe_eval=True) | ||
if config.allow_tf32: | ||
torch.backends.cudnn.allow_tf32 = True | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
if config.reduced_precision_reduction: | ||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True | ||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True | ||
if config.get("nni", False): | ||
if nni is None: | ||
raise ValueError("Unable to retrieve nni parameters, since nni is not installed.") | ||
config.merge(nni.get_next_parameter()) | ||
with dl.debug(config.get("debug", False)): | ||
runner = runner_cls(config) | ||
atexit.register(runner.print_result) | ||
atexit.register(runner.save_result) | ||
atexit.register(runner.save_checkpoint) | ||
result = runner.train() | ||
return result | ||
|
||
|
||
def evaluate( | ||
config: MultiMoleculeConfig = None, # type: ignore | ||
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, | ||
): | ||
if config is None: | ||
config = MultiMoleculeConfig.empty() | ||
config = config.parse(default_config="config", no_default_config_action="warn") | ||
config.interpolate(unsafe_eval=True) | ||
if config.allow_tf32: | ||
torch.backends.cudnn.allow_tf32 = True | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
if config.reduced_precision_reduction: | ||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True | ||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True | ||
if "checkpoint" not in config or not isinstance(config.checkpoint, str): | ||
raise RuntimeError("Please specify `checkpoint` to run evaluate") | ||
for name, data in config.datas.items(): | ||
if "eval" not in data or not isinstance(data.eval, str): | ||
raise RuntimeError(f"Please specify `eval` to run evaluate in datas.{name}") | ||
runner = runner_cls(config) | ||
result = runner.evaluate_epoch("eval") | ||
print(result) | ||
return result | ||
|
||
|
||
def inference( | ||
config: MultiMoleculeConfig = None, # type: ignore | ||
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, | ||
): | ||
if config is None: | ||
config = MultiMoleculeConfig.empty() | ||
config = config.parse(default_config="config", no_default_config_action="warn") | ||
config.interpolate(unsafe_eval=True) | ||
if config.allow_tf32: | ||
torch.backends.cudnn.allow_tf32 = True | ||
torch.backends.cuda.matmul.allow_tf32 = True | ||
if config.reduced_precision_reduction: | ||
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True | ||
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True | ||
if "checkpoint" not in config or not isinstance(config.checkpoint, str): | ||
raise RuntimeError("Please specify `checkpoint` to run infer.") | ||
for name, data in config.datas.items(): | ||
if "inf" not in data or not isinstance(data.inf, str): | ||
raise RuntimeError(f"Please specify `inf` to run infer in datas.{name}") | ||
if "result_path" not in config or not isinstance(config.result_path, str): | ||
config.result_path = os.path.join(os.getcwd(), "result.json") | ||
warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2) | ||
runner = runner_cls(config) | ||
result = runner.inference() | ||
runner.save(result, config.result_path) | ||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
import os | ||
import shutil | ||
from statistics import mean | ||
from typing import List | ||
|
||
import chanfig | ||
import pandas as pd | ||
from chanfig import NestedDict | ||
from tqdm import tqdm | ||
|
||
|
||
class Result(NestedDict): | ||
pretrained: str | ||
id: str | ||
seed: int | ||
epoch: int | ||
validation: NestedDict | ||
test: NestedDict | ||
|
||
|
||
def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]: | ||
results = [] | ||
for root, _, files in tqdm(os.walk(experiment_root)): | ||
if "run.log" in files: | ||
if "best.json" not in files: | ||
if remove_empty: | ||
shutil.rmtree(root) | ||
continue | ||
best = NestedDict.from_json(os.path.join(root, "best.json")) | ||
if "index" not in best: | ||
if remove_empty: | ||
shutil.rmtree(root) | ||
continue | ||
config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml")) | ||
pretrained = config.pretrained.split("/")[-1] | ||
seed = config.seed | ||
pretrained, seed = "", 1 | ||
result = Result(id=best.id, pretrained=pretrained, seed=seed) | ||
result.validation = NestedDict( | ||
{k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()} | ||
) | ||
result.test = NestedDict( | ||
{k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()} | ||
) | ||
result.epoch = best.index | ||
result.pop("validation.time", None) | ||
result.pop("test.time", None) | ||
result.pop("validation.loss", None) | ||
result.pop("test.loss", None) | ||
result.pop("validation.lr", None) | ||
result.pop("test.lr", None) | ||
results.append(result) | ||
# Remove empty directories, perform twice to remove all empty directories | ||
if remove_empty: | ||
for root, dirs, files in os.walk(experiment_root): | ||
if not files and not dirs: | ||
os.rmdir(root) | ||
for root, dirs, files in os.walk(experiment_root): | ||
if not files and not dirs: | ||
os.rmdir(root) | ||
results.sort(key=lambda x: (x.pretrained, x.seed, x.id)) | ||
return results | ||
|
||
|
||
def write_result_stat(results: List[Result], path: str): | ||
results = [dict(result.all_items()) for result in results] # type: ignore[misc] | ||
df = pd.DataFrame.from_dict(results) | ||
df.insert(len(df.keys()) - 1, "comment", "") | ||
df.fillna("") | ||
df.to_csv(path, index=False) | ||
|
||
|
||
class Config(chanfig.Config): | ||
experiment_root: str = "experiments" | ||
out_path: str = "result.csv" | ||
|
||
|
||
if __name__ == "__main__": | ||
config = Config().parse() | ||
result_stat = get_result_stat(config.experiment_root) | ||
if not len(result_stat) > 0: | ||
raise ValueError("No results found") | ||
write_result_stat(result_stat, config.out_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.