Skip to content

Commit

Permalink
add runner
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Sep 10, 2024
1 parent 8ae3401 commit 7b67d5e
Show file tree
Hide file tree
Showing 46 changed files with 1,192 additions and 49 deletions.
1 change: 1 addition & 0 deletions .codespell-whitelist.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
datas
ser
marz
manuel
Expand Down
7 changes: 7 additions & 0 deletions multimolecule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .apis import evaluate, inference, train
from .data import Dataset
from .models import (
AutoModelForContactPrediction,
Expand Down Expand Up @@ -130,14 +131,20 @@
TokenKMerHead,
TokenPredictionHead,
)
from .runners import MultiMoleculeConfig, MultiMoleculeRunner
from .tasks import Task, TaskLevel, TaskType
from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
from .utils import count_parameters

__all__ = [
"train",
"evaluate",
"inference",
"modeling_auto",
"modeling_outputs",
"Dataset",
"MultiMoleculeConfig",
"MultiMoleculeRunner",
"PreTrainedConfig",
"HeadConfig",
"BaseHeadConfig",
Expand Down
19 changes: 19 additions & 0 deletions multimolecule/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .run import evaluate, inference, train

__all__ = ["train", "evaluate", "inference"]
110 changes: 110 additions & 0 deletions multimolecule/apis/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import atexit
import os
import warnings
from typing import Type

import danling as dl
import torch

from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner

try:
import nni
except ImportError:
nni = None


def train(
config: MultiMoleculeConfig = None, # type: ignore
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
if config is None:
config = MultiMoleculeConfig()
config = config.parse(default_config="config", no_default_config_action="warn")
config.interpolate(unsafe_eval=True)
if config.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if config.reduced_precision_reduction:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
if config.get("nni", False):
if nni is None:
raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
config.merge(nni.get_next_parameter())
with dl.debug(config.get("debug", False)):
runner = runner_cls(config)
atexit.register(runner.print_result)
atexit.register(runner.save_result)
atexit.register(runner.save_checkpoint)
result = runner.train()
return result


def evaluate(
config: MultiMoleculeConfig = None, # type: ignore
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
if config is None:
config = MultiMoleculeConfig.empty()
config = config.parse(default_config="config", no_default_config_action="warn")
config.interpolate(unsafe_eval=True)
if config.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if config.reduced_precision_reduction:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
if "checkpoint" not in config or not isinstance(config.checkpoint, str):
raise RuntimeError("Please specify `checkpoint` to run evaluate")
for name, data in config.datas.items():
if "eval" not in data or not isinstance(data.eval, str):
raise RuntimeError(f"Please specify `eval` to run evaluate in datas.{name}")
runner = runner_cls(config)
result = runner.evaluate_epoch("eval")
print(result)
return result


def inference(
config: MultiMoleculeConfig = None, # type: ignore
runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
if config is None:
config = MultiMoleculeConfig.empty()
config = config.parse(default_config="config", no_default_config_action="warn")
config.interpolate(unsafe_eval=True)
if config.allow_tf32:
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
if config.reduced_precision_reduction:
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
if "checkpoint" not in config or not isinstance(config.checkpoint, str):
raise RuntimeError("Please specify `checkpoint` to run infer.")
for name, data in config.datas.items():
if "inf" not in data or not isinstance(data.inf, str):
raise RuntimeError(f"Please specify `inf` to run infer in datas.{name}")
if "result_path" not in config or not isinstance(config.result_path, str):
config.result_path = os.path.join(os.getcwd(), "result.json")
warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
runner = runner_cls(config)
result = runner.inference()
runner.save(result, config.result_path)
return result
3 changes: 1 addition & 2 deletions multimolecule/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, List
from warnings import warn

Expand Down Expand Up @@ -232,7 +231,7 @@ def post(
self.update(self.map(self.map_discrete))
self.set_transform(self.transform)

@cached_property
@property
def tasks(self) -> NestedDict:
return self.infer_tasks()

Expand Down
4 changes: 2 additions & 2 deletions multimolecule/models/calm/configuration_calm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,5 +130,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
2 changes: 1 addition & 1 deletion multimolecule/models/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class PreTrainedConfig(PretrainedConfig):
Base class for all model configuration classes.
"""

head: HeadConfig
head: HeadConfig | None

hidden_size: int

Expand Down
4 changes: 2 additions & 2 deletions multimolecule/models/ernierna/configuration_ernierna.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,5 @@ def __init__(
self.pairwise_alpha = pairwise_alpha
self.is_decoder = is_decoder
self.use_cache = use_cache
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
2 changes: 1 addition & 1 deletion multimolecule/models/ernierna/modeling_ernierna.py
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,7 @@ def __init__(self, config: ErnieRnaConfig, head_config: HeadConfig | None = None
super().__init__()
if head_config is None:
head_config = config.head
self.config = head_config
self.config = head_config # type: HeadConfig # type: ignore[assignment]
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.pad_token_id = config.pad_token_id
Expand Down
4 changes: 2 additions & 2 deletions multimolecule/models/rinalmo/configuration_rinalmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,6 @@ def __init__(
self.use_cache = use_cache
self.learnable_beta = learnable_beta
self.token_dropout = token_dropout
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
self.emb_layer_norm_before = emb_layer_norm_before
4 changes: 2 additions & 2 deletions multimolecule/models/rnabert/configuration_rnabert.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
4 changes: 2 additions & 2 deletions multimolecule/models/rnaernie/configuration_rnaernie.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
4 changes: 2 additions & 2 deletions multimolecule/models/rnafm/configuration_rnafm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
4 changes: 2 additions & 2 deletions multimolecule/models/rnamsm/configuration_rnamsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,5 @@ def __init__(
self.attention_type = attention_type
self.embed_positions_msa = embed_positions_msa
self.attention_bias = attention_bias
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
4 changes: 2 additions & 2 deletions multimolecule/models/splicebert/configuration_splicebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
4 changes: 2 additions & 2 deletions multimolecule/models/utrbert/configuration_utrbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
4 changes: 2 additions & 2 deletions multimolecule/models/utrlm/configuration_utrlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
self.head = HeadConfig(**head if head is not None else {})
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
self.head = HeadConfig(**head) if head is not None else None
self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
self.ss_head = HeadConfig(**ss_head) if ss_head is not None else None
self.mfe_head = HeadConfig(**mfe_head) if mfe_head is not None else None
7 changes: 6 additions & 1 deletion multimolecule/module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .criterions import Criterion
from .criterions import Criterion, CriterionRegistry
from .embeddings import PositionEmbeddingRegistry, PositionEmbeddingRegistryHF, RotaryEmbedding, SinusoidalEmbedding
from .heads import (
BaseHeadConfig,
Expand All @@ -38,8 +38,13 @@
TokenKMerHead,
TokenPredictionHead,
)
from .model import MultiMoleculeModel
from .registry import ModelRegistry

__all__ = [
"ModelRegistry",
"MultiMoleculeModel",
"CriterionRegistry",
"Criterion",
"PositionEmbeddingRegistry",
"PositionEmbeddingRegistryHF",
Expand Down
21 changes: 21 additions & 0 deletions multimolecule/module/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from .registry import BackboneRegistry
from .sequence import SequenceBackbone
from .sequences import SequenceRegistry

__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"]
21 changes: 21 additions & 0 deletions multimolecule/module/backbones/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

from chanfig import Registry

BackboneRegistry = Registry()
46 changes: 46 additions & 0 deletions multimolecule/module/backbones/sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# MultiMolecule
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from __future__ import annotations

import torch
from chanfig import FlatDict
from danling import NestedTensor
from torch import Tensor, nn

from .registry import BackboneRegistry
from .sequences import SequenceRegistry


@BackboneRegistry.register("sequence", default=True)
class SequenceBackbone(nn.Module):
def __init__(self, sequence) -> None:
super().__init__()
self.sequence = SequenceRegistry.build(**sequence)
self.sequence_dropout = nn.Dropout(sequence.pop("dropout", 0), inplace=True)
self.config = self.sequence.config
self.out_channels = self.config.hidden_size

def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, FlatDict]:
attentions = None
input_ids, attention_mask = sequence.tensor, sequence.mask
sequence_output = self.sequence(input_ids.int(), attention_mask)
sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"])
sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"])
if "attentions" in sequence_output:
attentions = torch.stack(sequence_output["attentions"], dim=1).detach()

return sequence_output, attentions
Loading

0 comments on commit 7b67d5e

Please sign in to comment.