-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Zhiyuan Chen <[email protected]>
- Loading branch information
1 parent
8ae3401
commit 7b67d5e
Showing
46 changed files
with
1,192 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
datas | ||
ser | ||
marz | ||
manuel | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
from .run import evaluate, inference, train | ||
|
||
__all__ = ["train", "evaluate", "inference"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
import atexit | ||
import os | ||
import warnings | ||
from typing import Type | ||
|
||
import danling as dl | ||
import torch | ||
|
||
from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner | ||
|
||
try: | ||
import nni | ||
except ImportError: | ||
nni = None | ||
|
||
|
||
def train(
    config: MultiMoleculeConfig = None,  # type: ignore
    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
    """
    Parse the configuration, build a runner, and run training.

    Args:
        config: Configuration to parse. A new `MultiMoleculeConfig()` is created when `None`.
        runner_cls: Runner class that is instantiated with the parsed configuration.

    Returns:
        The value returned by `runner.train()`.

    Raises:
        ValueError: If the configuration enables `nni` but the `nni` package could not be imported.
    """
    config = MultiMoleculeConfig() if config is None else config
    config = config.parse(default_config="config", no_default_config_action="warn")
    # NOTE(review): `unsafe_eval=True` presumably evaluates expressions found in config values;
    # only feed trusted configuration sources - confirm against the chanfig documentation.
    config.interpolate(unsafe_eval=True)

    # Opt-in numeric shortcuts for CUDA kernels, driven by the configuration.
    if config.allow_tf32:
        for backend in (torch.backends.cudnn, torch.backends.cuda.matmul):
            backend.allow_tf32 = True
    if config.reduced_precision_reduction:
        matmul = torch.backends.cuda.matmul
        matmul.allow_fp16_reduced_precision_reduction = True
        matmul.allow_bf16_reduced_precision_reduction = True

    # Merge the hyper-parameters proposed by NNI for this trial, when requested.
    if config.get("nni", False):
        if nni is None:
            raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
        trial_parameters = nni.get_next_parameter()
        config.merge(trial_parameters)

    # NOTE(review): only runner construction is placed under the debug context - confirm the scope.
    with dl.debug(config.get("debug", False)):
        runner = runner_cls(config)

    # Register the reporting / saving hooks so they run when the interpreter exits.
    for hook_name in ("print_result", "save_result", "save_checkpoint"):
        atexit.register(getattr(runner, hook_name))
    return runner.train()
|
||
|
||
def evaluate(
    config: MultiMoleculeConfig = None,  # type: ignore
    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
    """
    Parse the configuration, build a runner, and evaluate on the `eval` split.

    Args:
        config: Configuration to parse. `MultiMoleculeConfig.empty()` is used when `None`.
        runner_cls: Runner class that is instantiated with the parsed configuration.

    Returns:
        The value returned by `runner.evaluate_epoch("eval")`, which is also printed.

    Raises:
        RuntimeError: If `checkpoint` is missing or not a string, or if any entry of
            `config.datas` lacks a string `eval` field.
    """
    config = MultiMoleculeConfig.empty() if config is None else config
    config = config.parse(default_config="config", no_default_config_action="warn")
    # NOTE(review): `unsafe_eval=True` presumably evaluates expressions found in config values;
    # only feed trusted configuration sources - confirm against the chanfig documentation.
    config.interpolate(unsafe_eval=True)

    # Opt-in numeric shortcuts for CUDA kernels, driven by the configuration.
    if config.allow_tf32:
        for backend in (torch.backends.cudnn, torch.backends.cuda.matmul):
            backend.allow_tf32 = True
    if config.reduced_precision_reduction:
        matmul = torch.backends.cuda.matmul
        matmul.allow_fp16_reduced_precision_reduction = True
        matmul.allow_bf16_reduced_precision_reduction = True

    # Evaluation needs a checkpoint and an `eval` entry for every dataset.
    if not ("checkpoint" in config and isinstance(config.checkpoint, str)):
        raise RuntimeError("Please specify `checkpoint` to run evaluate")
    for name, dataset in config.datas.items():
        if not ("eval" in dataset and isinstance(dataset.eval, str)):
            raise RuntimeError(f"Please specify `eval` to run evaluate in datas.{name}")

    runner = runner_cls(config)
    metrics = runner.evaluate_epoch("eval")
    print(metrics)
    return metrics
|
||
|
||
def inference(
    config: MultiMoleculeConfig = None,  # type: ignore
    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
    """
    Parse the configuration, build a runner, run inference, and save the result.

    Args:
        config: Configuration to parse. `MultiMoleculeConfig.empty()` is used when `None`.
        runner_cls: Runner class that is instantiated with the parsed configuration.

    Returns:
        The value returned by `runner.inference()`; it is also passed to `runner.save`
        together with `config.result_path`.

    Raises:
        RuntimeError: If `checkpoint` is missing or not a string, or if any entry of
            `config.datas` lacks a string `inf` field.
    """
    config = MultiMoleculeConfig.empty() if config is None else config
    config = config.parse(default_config="config", no_default_config_action="warn")
    # NOTE(review): `unsafe_eval=True` presumably evaluates expressions found in config values;
    # only feed trusted configuration sources - confirm against the chanfig documentation.
    config.interpolate(unsafe_eval=True)

    # Opt-in numeric shortcuts for CUDA kernels, driven by the configuration.
    if config.allow_tf32:
        for backend in (torch.backends.cudnn, torch.backends.cuda.matmul):
            backend.allow_tf32 = True
    if config.reduced_precision_reduction:
        matmul = torch.backends.cuda.matmul
        matmul.allow_fp16_reduced_precision_reduction = True
        matmul.allow_bf16_reduced_precision_reduction = True

    # Inference needs a checkpoint and an `inf` entry for every dataset.
    if not ("checkpoint" in config and isinstance(config.checkpoint, str)):
        raise RuntimeError("Please specify `checkpoint` to run infer.")
    for name, dataset in config.datas.items():
        if not ("inf" in dataset and isinstance(dataset.inf, str)):
            raise RuntimeError(f"Please specify `inf` to run infer in datas.{name}")

    # Fall back to `result.json` in the current working directory when no usable path is configured.
    if not ("result_path" in config and isinstance(config.result_path, str)):
        config.result_path = os.path.join(os.getcwd(), "result.json")
        warnings.warn(
            "`result_path` is not specified, default to `result.json`.",
            RuntimeWarning,
            stacklevel=2,
        )

    runner = runner_cls(config)
    predictions = runner.inference()
    runner.save(predictions, config.result_path)
    return predictions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
from .registry import BackboneRegistry | ||
from .sequence import SequenceBackbone | ||
from .sequences import SequenceRegistry | ||
|
||
__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
from __future__ import annotations | ||
|
||
from chanfig import Registry | ||
|
||
BackboneRegistry = Registry() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# MultiMolecule | ||
# Copyright (C) 2024-Present MultiMolecule | ||
|
||
# This program is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU Affero General Public License as published by | ||
# the Free Software Foundation, either version 3 of the License, or | ||
# any later version. | ||
|
||
# This program is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU Affero General Public License for more details. | ||
|
||
# You should have received a copy of the GNU Affero General Public License | ||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
from __future__ import annotations | ||
|
||
import torch | ||
from chanfig import FlatDict | ||
from danling import NestedTensor | ||
from torch import Tensor, nn | ||
|
||
from .registry import BackboneRegistry | ||
from .sequences import SequenceRegistry | ||
|
||
|
||
@BackboneRegistry.register("sequence", default=True)
class SequenceBackbone(nn.Module):
    """
    Backbone that runs a sequence model built through `SequenceRegistry` and applies dropout to its outputs.

    Attributes set in `__init__`:
        sequence: The model returned by `SequenceRegistry.build`.
        sequence_dropout: `nn.Dropout` applied to `pooler_output` and `last_hidden_state`.
        config: The `config` attribute of the built sequence model.
        out_channels: `config.hidden_size` of the built sequence model.
    """

    def __init__(self, sequence) -> None:
        """
        Build the sequence model and its output dropout.

        Args:
            sequence: Mapping of keyword arguments for `SequenceRegistry.build`. An optional
                `dropout` entry (default `0`) is removed from it and used as the dropout probability.
        """
        super().__init__()
        # Remove `dropout` before building: it configures this wrapper, not the sequence model,
        # so it must not be forwarded to `SequenceRegistry.build`.
        dropout = sequence.pop("dropout", 0)
        self.sequence = SequenceRegistry.build(**sequence)
        # Out-of-place dropout: the outputs of the sequence model may be saved for the backward
        # pass or shared with other entries of the output, so they must not be modified in place.
        self.sequence_dropout = nn.Dropout(dropout)
        self.config = self.sequence.config
        self.out_channels = self.config.hidden_size

    def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, Tensor | None]:
        """
        Run the sequence model and apply dropout to its pooled and per-token outputs.

        Args:
            sequence: Input ids. The `tensor` and `mask` attributes are read from it, so the
                input has to expose both (NOTE(review): a plain `Tensor` does not - confirm
                whether callers ever pass one).
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A pair `(sequence_output, attentions)`. `sequence_output` is the output of the sequence
            model with `pooler_output` and `last_hidden_state` replaced by their dropped-out
            versions. `attentions` is the detached stack (along dim 1) of
            `sequence_output["attentions"]` when that key is present, otherwise `None`.
        """
        attentions = None
        input_ids, attention_mask = sequence.tensor, sequence.mask
        sequence_output = self.sequence(input_ids.int(), attention_mask)
        sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"])
        sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"])
        if "attentions" in sequence_output:
            # Detached: the stacked attention maps are returned without gradient tracking.
            attentions = torch.stack(sequence_output["attentions"], dim=1).detach()

        return sequence_output, attentions
Oops, something went wrong.