Skip to content

Commit

Permalink
add runner
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Sep 23, 2024
1 parent 704d5f2 commit 40b4315
Show file tree
Hide file tree
Showing 72 changed files with 2,127 additions and 209 deletions.
1 change: 1 addition & 0 deletions .codespell-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
datas
ser
marz
manuel
Expand Down
9 changes: 9 additions & 0 deletions docs/docs/data/multitask.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
authors:
- Zhiyuan Chen
date: 2024-05-04
---

# MultiTask

::: multimolecule.data.multitask
9 changes: 9 additions & 0 deletions docs/docs/runners/config.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
authors:
- Zhiyuan Chen
date: 2024-05-04
---

# MultiMoleculeConfig

::: multimolecule.runners.MultiMoleculeConfig
9 changes: 9 additions & 0 deletions docs/docs/runners/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
authors:
- Zhiyuan Chen
date: 2024-05-04
---

# runners

--8<-- "multimolecule/runners/README.md:8:"
9 changes: 9 additions & 0 deletions docs/docs/runners/runner.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
authors:
- Zhiyuan Chen
date: 2024-05-04
---

# MultiMoleculeRunner

::: multimolecule.runners.MultiMoleculeRunner
5 changes: 5 additions & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule

nav:
- index.md
- runners:
- runners/index.md
- MultiMoleculeRunner: runners/runner.md
- MultiMoleculeConfig: runners/config.md
- data:
- data/index.md
- Dataset: data/dataset.md
- multitask: data/multitask.md
- datasets:
- datasets/index.md
- DNA:
Expand Down
15 changes: 7 additions & 8 deletions multimolecule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .apis import evaluate, infer, train
from .data import Dataset
from .models import (
AutoModelForContactPrediction,
Expand Down Expand Up @@ -111,33 +112,35 @@
HeadConfig,
HeadRegistry,
HeadTransformRegistry,
HeadTransformRegistryHF,
IdentityTransform,
LinearTransform,
MaskedLMHead,
MaskedLMHeadConfig,
NonLinearTransform,
NucleotideHeadRegistryHF,
NucleotideKMerHead,
NucleotidePredictionHead,
PositionEmbeddingRegistry,
PositionEmbeddingRegistryHF,
PredictionHead,
RotaryEmbedding,
SequencePredictionHead,
SinusoidalEmbedding,
TokenHeadRegistryHF,
TokenKMerHead,
TokenPredictionHead,
)
from .runners import MultiMoleculeConfig, MultiMoleculeRunner
from .tasks import Task, TaskLevel, TaskType
from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
from .utils import count_parameters

__all__ = [
"train",
"evaluate",
"infer",
"modeling_auto",
"modeling_outputs",
"Dataset",
"MultiMoleculeConfig",
"MultiMoleculeRunner",
"PreTrainedConfig",
"HeadConfig",
"BaseHeadConfig",
Expand Down Expand Up @@ -236,21 +239,17 @@
"HeadRegistry",
"PredictionHead",
"SequencePredictionHead",
"TokenHeadRegistryHF",
"TokenPredictionHead",
"TokenKMerHead",
"NucleotideHeadRegistryHF",
"NucleotidePredictionHead",
"NucleotideKMerHead",
"ContactPredictionHead",
"MaskedLMHead",
"HeadTransformRegistry",
"HeadTransformRegistryHF",
"LinearTransform",
"NonLinearTransform",
"IdentityTransform",
"PositionEmbeddingRegistry",
"PositionEmbeddingRegistryHF",
"RotaryEmbedding",
"SinusoidalEmbedding",
"Criterion",
Expand Down
19 changes: 19 additions & 0 deletions multimolecule/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .run import evaluate, infer, train

__all__ = ["train", "evaluate", "infer"]
115 changes: 115 additions & 0 deletions multimolecule/apis/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

# mypy: disable-error-code="attr-defined"

import atexit
import os
import warnings
from typing import Type

import danling as dl
import torch

from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner

try:
import nni
except ImportError:
nni = None


def train(
config: MultiMoleculeConfig = None, # type: ignore
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
if config is None:
config = MultiMoleculeConfig()
config = config.parse(default_config="config", no_default_config_action="warn")
config.interpolate(unsafe_eval=True)
config.training = True
if config.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if config.reduced_precision_reduction:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
if config.get("nni", False):
if nni is None:
raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
config.merge(nni.get_next_parameter())
with dl.debug(config.get("debug", False)):
runner = runner_cls(config)
atexit.register(runner.print_result)
atexit.register(runner.save_result)
atexit.register(runner.save_checkpoint)
result = runner.train()
return result


def evaluate(
config: MultiMoleculeConfig = None, # type: ignore
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
if config is None:
config = MultiMoleculeConfig.empty()
config = config.parse(default_config="config", no_default_config_action="warn")
config.interpolate(unsafe_eval=True)
config.training = False
if config.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if config.reduced_precision_reduction:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
if "checkpoint" not in config or not isinstance(config.checkpoint, str):
raise RuntimeError("Please specify `checkpoint` to run evaluate")
for name, data in config.datas.items():
if "evaluation" not in data or not isinstance(data.evaluate, str):
raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}")
runner = runner_cls(config)
result = runner.evaluate_epoch("evaluation")
print(result)
return result


def infer(
config: MultiMoleculeConfig = None, # type: ignore
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
if config is None:
config = MultiMoleculeConfig.empty()
config = config.parse(default_config="config", no_default_config_action="warn")
config.interpolate(unsafe_eval=True)
config.training = False
if config.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if config.reduced_precision_reduction:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
if "checkpoint" not in config or not isinstance(config.checkpoint, str):
raise RuntimeError("Please specify `checkpoint` to run infer.")
for name, data in config.datas.items():
if "inference" not in data or not isinstance(data.inference, str):
raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}")
if "result_path" not in config or not isinstance(config.result_path, str):
config.result_path = os.path.join(os.getcwd(), "result.json")
warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
runner = runner_cls(config)
result = runner.infer()
runner.save(result, config.result_path)
return result
99 changes: 99 additions & 0 deletions multimolecule/apis/stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import os
import shutil
from statistics import mean
from typing import List

import chanfig
import pandas as pd
from chanfig import NestedDict
from tqdm import tqdm


class Result(NestedDict):
pretrained: str
id: str
seed: int
epoch: int
validation: NestedDict
test: NestedDict


def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]:
results = []
for root, _, files in tqdm(os.walk(experiment_root)):
if "run.log" in files:
if "best.json" not in files:
if remove_empty:
shutil.rmtree(root)
continue
best = NestedDict.from_json(os.path.join(root, "best.json"))
if "index" not in best:
if remove_empty:
shutil.rmtree(root)
continue
config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
pretrained = config.pretrained.split("/")[-1]
seed = config.seed
pretrained, seed = "", 1
result = Result(id=best.id, pretrained=pretrained, seed=seed)
result.validation = NestedDict(
{k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()}
)
result.test = NestedDict(
{k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()}
)
result.epoch = best.index
result.pop("validation.time", None)
result.pop("test.time", None)
result.pop("validation.loss", None)
result.pop("test.loss", None)
result.pop("validation.lr", None)
result.pop("test.lr", None)
results.append(result)
# Remove empty directories, perform twice to remove all empty directories
if remove_empty:
for root, dirs, files in os.walk(experiment_root):
if not files and not dirs:
os.rmdir(root)
for root, dirs, files in os.walk(experiment_root):
if not files and not dirs:
os.rmdir(root)
results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
return results


def write_result_stat(results: List[Result], path: str):
results = [dict(result.all_items()) for result in results] # type: ignore[misc]
df = pd.DataFrame.from_dict(results)
df.insert(len(df.keys()) - 1, "comment", "")
df.fillna("")
df.to_csv(path, index=False)


class Config(chanfig.Config):
experiment_root: str = "experiments"
out_path: str = "result.csv"


if __name__ == "__main__":
config = Config().parse()
result_stat = get_result_stat(config.experiment_root)
if not len(result_stat) > 0:
raise ValueError("No results found")
write_result_stat(result_stat, config.out_path)
10 changes: 9 additions & 1 deletion multimolecule/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .dataset import Dataset
from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
from .utils import no_collate

__all__ = ["Dataset", "no_collate"]
__all__ = [
"Dataset",
"PandasDataset",
"MultiTaskDataset",
"MultiTaskSampler",
"DistributedMultiTaskSampler",
"no_collate",
]
Loading

0 comments on commit 40b4315

Please sign in to comment.