From 58d1cc3f8b21a35dac44a856ecb80287705ea31a Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sat, 11 May 2024 17:04:38 +0800 Subject: [PATCH] add runner Signed-off-by: Zhiyuan Chen --- .codespell-whitelist.txt | 1 + docs/docs/about/license-faq.md | 7 +- docs/docs/about/license-faq.zh.md | 7 +- docs/docs/data/multitask.md | 9 + docs/docs/runners/config.md | 9 + docs/docs/runners/index.md | 9 + docs/docs/runners/runner.md | 9 + docs/mkdocs.yml | 5 + multimolecule/__init__.py | 16 +- multimolecule/apis/__init__.py | 19 + multimolecule/apis/run.py | 115 ++++++ multimolecule/apis/stat.py | 99 ++++++ multimolecule/data/__init__.py | 10 +- multimolecule/data/dataset.py | 100 ++++-- multimolecule/data/multitask.py | 246 +++++++++++++ multimolecule/data/utils.py | 6 +- multimolecule/defaults.py | 7 +- multimolecule/models/__init__.py | 2 + .../models/calm/configuration_calm.py | 4 +- multimolecule/models/calm/modeling_calm.py | 12 +- multimolecule/models/configuration_utils.py | 14 +- .../models/ernierna/configuration_ernierna.py | 4 +- .../models/ernierna/modeling_ernierna.py | 15 +- .../models/rinalmo/configuration_rinalmo.py | 4 +- .../models/rinalmo/modeling_rinalmo.py | 12 +- .../models/rnabert/configuration_rnabert.py | 4 +- .../models/rnabert/modeling_rnabert.py | 22 +- .../models/rnaernie/configuration_rnaernie.py | 4 +- .../models/rnaernie/modeling_rnaernie.py | 12 +- .../models/rnafm/configuration_rnafm.py | 4 +- multimolecule/models/rnafm/modeling_rnafm.py | 14 +- .../models/rnamsm/configuration_rnamsm.py | 4 +- .../models/rnamsm/modeling_rnamsm.py | 14 +- .../splicebert/configuration_splicebert.py | 4 +- .../models/splicebert/modeling_splicebert.py | 12 +- .../models/utrbert/configuration_utrbert.py | 4 +- .../models/utrbert/modeling_utrbert.py | 12 +- .../models/utrlm/configuration_utrlm.py | 4 +- multimolecule/models/utrlm/modeling_utrlm.py | 14 +- multimolecule/module/__init__.py | 14 +- multimolecule/module/backbones/__init__.py | 21 ++ multimolecule/module/backbones/registry.py | 21 ++ multimolecule/module/backbones/sequence.py | 59 ++++ .../module/backbones/sequences/__init__.py | 20 ++ .../module/backbones/sequences/onehot.py | 39 +++ .../module/backbones/sequences/registry.py | 66 ++++ multimolecule/module/criterions/__init__.py | 14 +- multimolecule/module/criterions/binary.py | 44 +++ multimolecule/module/criterions/generic.py | 33 +- multimolecule/module/criterions/multiclass.py | 44 +++ multimolecule/module/criterions/multilabel.py | 44 +++ multimolecule/module/criterions/registry.py | 29 ++ multimolecule/module/criterions/regression.py | 44 +++ multimolecule/module/heads/__init__.py | 12 +- multimolecule/module/heads/config.py | 24 +- multimolecule/module/heads/contact.py | 247 +++++++++++-- multimolecule/module/heads/generic.py | 58 ++- multimolecule/module/heads/pretrain.py | 6 +- multimolecule/module/heads/registry.py | 12 +- multimolecule/module/heads/token.py | 80 +---- multimolecule/module/heads/utils.py | 29 -- multimolecule/module/model.py | 89 +++++ multimolecule/module/necks/__init__.py | 21 ++ multimolecule/module/necks/bert.py | 102 ++++++ multimolecule/module/necks/cat.py | 43 +++ multimolecule/module/necks/registry.py | 21 ++ multimolecule/module/registry.py | 35 ++ multimolecule/runners/README.md | 9 + multimolecule/runners/__init__.py | 20 ++ multimolecule/runners/base_runner.py | 331 ++++++++++++++++++ multimolecule/runners/config.py | 87 +++++ multimolecule/runners/metrics.py | 37 ++ multimolecule/runners/runner.py | 42 +++ multimolecule/tasks/task.py | 3 +- multimolecule/train.py | 20 ++ pyproject.toml | 4 +- tests/data/test_dataset.py | 14 +- 77 files changed, 2375 insertions(+), 332 deletions(-) create mode 100644 docs/docs/data/multitask.md create mode 100644 docs/docs/runners/config.md create mode 100644 docs/docs/runners/index.md create mode 100644 docs/docs/runners/runner.md create mode 100644 multimolecule/apis/__init__.py create mode 100644 multimolecule/apis/run.py create mode 100644 multimolecule/apis/stat.py create mode 100644 multimolecule/data/multitask.py create mode 100644 multimolecule/module/backbones/__init__.py create mode 100644 multimolecule/module/backbones/registry.py create mode 100644 multimolecule/module/backbones/sequence.py create mode 100644 multimolecule/module/backbones/sequences/__init__.py create mode 100644 multimolecule/module/backbones/sequences/onehot.py create mode 100644 multimolecule/module/backbones/sequences/registry.py create mode 100644 multimolecule/module/criterions/binary.py create mode 100644 multimolecule/module/criterions/multiclass.py create mode 100644 multimolecule/module/criterions/multilabel.py create mode 100644 multimolecule/module/criterions/registry.py create mode 100644 multimolecule/module/criterions/regression.py create mode 100644 multimolecule/module/model.py create mode 100644 multimolecule/module/necks/__init__.py create mode 100644 multimolecule/module/necks/bert.py create mode 100644 multimolecule/module/necks/cat.py create mode 100644 multimolecule/module/necks/registry.py create mode 100644 multimolecule/module/registry.py create mode 100644 multimolecule/runners/README.md create mode 100644 multimolecule/runners/__init__.py create mode 100644 multimolecule/runners/base_runner.py create mode 100644 multimolecule/runners/config.py create mode 100644 multimolecule/runners/metrics.py create mode 100644 multimolecule/runners/runner.py create mode 100644 multimolecule/train.py diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt index 44c7e9f5..467e5c38 100644 --- a/.codespell-whitelist.txt +++ b/.codespell-whitelist.txt @@ -1,3 +1,4 @@ +datas ser marz manuel diff --git a/docs/docs/about/license-faq.md b/docs/docs/about/license-faq.md index e8e341ff..4dd0bb15 100644 --- a/docs/docs/about/license-faq.md +++ b/docs/docs/about/license-faq.md @@ -52,7 +52,12 @@ We also consider research papers and manuscripts a special form of documentation Since research papers are considered a form of source code, publishers are legally required to open-source all materials on their server to comply with the _[License](license.md)_ if they publish papers using MultiMolecule. This is generally impractical for most publishers. -As a special exemption under section 7 of the _[License](license.md)_, we grant permission to publish research papers using MultiMolecule in fully open access journals, conferences, or preprint servers, provided all published manuscripts are made available under the [GNU Free Documentation License (GFDL)](https://www.gnu.org/licenses/fdl.html), or a [Creative Commons license](https://creativecommons.org), or an [OSI-approved license](https://opensource.org/licenses) that permits the sharing of manuscripts. +As a special exemption under section 7 of the _[License](license.md)_, we grant permission to publish research papers using MultiMolecule in fully open access journals, conferences, or preprint servers that do not charge any fee from authors, provided all published manuscripts are made available under the [GNU Free Documentation License (GFDL)](https://www.gnu.org/licenses/fdl.html), or a [Creative Commons license](https://creativecommons.org), or an [OSI-approved license](https://opensource.org/licenses) that permits the sharing of manuscripts. + +As a special exemption under section 7 of the _[License](license.md)_, we grant permission to publish research papers using MultiMolecule in certain non-profit journals, conferences, or preprint servers. Currently, the non-profit journals, conferences, or preprint servers we allow include: + +- All journals published by American Association for the Advancement of Science (AAAS) +- eLife For publishing in closed access journals or conferences, you must obtain a separate license from us. This typically involves co-authorship, a fee to support the project, or both. Contact us at [multimolecule@zyc.ai](mailto:multimolecule@zyc.ai) for more information. diff --git a/docs/docs/about/license-faq.zh.md b/docs/docs/about/license-faq.zh.md index 86ef6d07..a8d13a86 100644 --- a/docs/docs/about/license-faq.zh.md +++ b/docs/docs/about/license-faq.zh.md @@ -60,7 +60,12 @@ 由于研究论文被视为源代码的一种形式,如果发表使用 MultiMolecule 的论文,出版商必须开源其服务器上的所有材料,以符合 _[许可协议](license.zh.md)_ 的要求。对于大多数出版商来说,这是不切实际的。 -作为 _[许可协议](license.zh.md)_ 第 7 条的特别豁免,我们允许在完全开放获取的期刊、会议或预印本服务器上发表使用 MultiMolecule 的研究论文,前提是所有发表的手稿都应按照允许共享手稿的[GNU 自由文档许可协议(GFDL)](https://www.gnu.org/licenses/fdl.html)或[知识共享许可协议](https://creativecommons.org)或[OSI 批准许可协议](https://opensource.org/licenses)提供。 +作为 _[许可协议](license.zh.md)_ 第 7 条的特别豁免,我们允许在不向作者收取任何费用的完全开放获取的期刊、会议或预印本服务器上发表使用 MultiMolecule 的研究论文,前提是所有发表的手稿都应按照允许共享手稿的[GNU 自由文档许可协议(GFDL)](https://www.gnu.org/licenses/fdl.html)或[知识共享许可协议](https://creativecommons.org)或[OSI 批准许可协议](https://opensource.org/licenses)提供。 + +作为 _[许可协议](license.zh.md)_ 第 7 条的特别豁免,我们允许在部分非盈利性的杂志、会议或预印本服务器上发表使用 MultiMolecule 的研究论文。目前,我们允许的非盈利性杂志、会议或预印本服务器包括: + +- 美国科学促进会(AAAS)旗下的所有期刊 +- eLife 要在封闭获取的期刊或会议上发表论文,您必须从我们这里获得单独的许可。这通常包括共同署名、支持项目的费用或两者兼而有之。请通过 [multimolecule@zyc.ai](mailto:multimolecule@zyc.ai) 与我们联系以获取更多信息。 diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md new file mode 100644 index 00000000..054c6205 --- /dev/null +++ b/docs/docs/data/multitask.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiTask + +::: multimolecule.data.multitask diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md new file mode 100644 index 00000000..8e188199 --- /dev/null +++ b/docs/docs/runners/config.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiMoleculeConfig + +::: multimolecule.runners.MultiMoleculeConfig diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md new file mode 100644 index 00000000..75a0f528 --- /dev/null +++ b/docs/docs/runners/index.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# runners + +--8<-- "multimolecule/runners/README.md:8:" diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md new file mode 100644 index 00000000..7f2f4c12 --- /dev/null +++ b/docs/docs/runners/runner.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiMoleculeRunner + +::: multimolecule.runners.base_runner.BaseRunner diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 93d53a45..9324d9c3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule nav: - index.md + - runners: + - runners/index.md + - MultiMoleculeRunner: runners/runner.md + - MultiMoleculeConfig: runners/config.md - data: - data/index.md - Dataset: data/dataset.md + - multitask: data/multitask.md - datasets: - datasets/index.md - DNA: diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py index 240e9fcc..c9ee3a87 100644 --- a/multimolecule/__init__.py +++ b/multimolecule/__init__.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from .apis import evaluate, infer, train from .data import Dataset from .models import ( AutoModelForContactPrediction, @@ -111,30 +112,33 @@ HeadConfig, HeadRegistry, HeadTransformRegistry, - HeadTransformRegistryHF, IdentityTransform, LinearTransform, MaskedLMHead, MaskedLMHeadConfig, NonLinearTransform, PositionEmbeddingRegistry, - PositionEmbeddingRegistryHF, PredictionHead, RotaryEmbedding, SequencePredictionHead, SinusoidalEmbedding, - TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead, ) +from .runners import MultiMoleculeConfig, MultiMoleculeRunner from .tasks import Task, TaskLevel, TaskType from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer from .utils import count_parameters __all__ = [ + "train", + "evaluate", + "infer", "modeling_auto", "modeling_outputs", "Dataset", + "MultiMoleculeConfig", + "MultiMoleculeRunner", "PreTrainedConfig", "HeadConfig", "BaseHeadConfig", @@ -233,21 +237,15 @@ "HeadRegistry", "PredictionHead", "SequencePredictionHead", - "TokenHeadRegistryHF", "TokenPredictionHead", "TokenKMerHead", - "NucleotideHeadRegistryHF", - "NucleotidePredictionHead", - "NucleotideKMerHead", "ContactPredictionHead", "MaskedLMHead", "HeadTransformRegistry", - "HeadTransformRegistryHF", "LinearTransform", "NonLinearTransform", "IdentityTransform", "PositionEmbeddingRegistry", - "PositionEmbeddingRegistryHF", "RotaryEmbedding", "SinusoidalEmbedding", "Criterion", diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py new file mode 100644 index 00000000..8e3e5b3c --- /dev/null +++ b/multimolecule/apis/__init__.py @@ -0,0 +1,19 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .run import evaluate, infer, train + +__all__ = ["train", "evaluate", "infer"] diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py new file mode 100644 index 00000000..1fdb7666 --- /dev/null +++ b/multimolecule/apis/run.py @@ -0,0 +1,115 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +# mypy: disable-error-code="attr-defined" + +import atexit +import os +import warnings +from typing import Type + +import danling as dl +import torch + +from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner + +try: + import nni +except ImportError: + nni = None + + +def train( + config: MultiMoleculeConfig = None, # type: ignore + runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, +): + if config is None: + config = MultiMoleculeConfig() + config = config.parse(default_config="config", no_default_config_action="warn") + config.interpolate(unsafe_eval=True) + config.training = True + if config.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + if config.reduced_precision_reduction: + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + if config.get("nni", False): + if nni is None: + raise ValueError("Unable to retrieve nni parameters, since nni is not installed.") + config.merge(nni.get_next_parameter()) + with dl.debug(config.get("debug", False)): + runner = runner_cls(config) + atexit.register(runner.print_result) + atexit.register(runner.save_result) + atexit.register(runner.save_checkpoint) + result = runner.train() + return result + + +def evaluate( + config: MultiMoleculeConfig = None, # type: ignore + runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, +): + if config is None: + config = MultiMoleculeConfig.empty() + config = config.parse(default_config="config", no_default_config_action="warn") + config.interpolate(unsafe_eval=True) + config.training = False + if config.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + if config.reduced_precision_reduction: + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + if "checkpoint" not in config or not isinstance(config.checkpoint, str): + raise RuntimeError("Please specify `checkpoint` to run evaluate") + for name, data in config.datas.items(): + if "evaluation" not in data or not isinstance(data.evaluate, str): + raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}") + runner = runner_cls(config) + result = runner.evaluate_epoch("evaluation") + print(result) + return result + + +def infer( + config: MultiMoleculeConfig = None, # type: ignore + runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner, +): + if config is None: + config = MultiMoleculeConfig.empty() + config = config.parse(default_config="config", no_default_config_action="warn") + config.interpolate(unsafe_eval=True) + config.training = False + if config.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + if config.reduced_precision_reduction: + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True + if "checkpoint" not in config or not isinstance(config.checkpoint, str): + raise RuntimeError("Please specify `checkpoint` to run infer.") + for name, data in config.datas.items(): + if "inference" not in data or not isinstance(data.inference, str): + raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}") + if "result_path" not in config or not isinstance(config.result_path, str): + config.result_path = os.path.join(os.getcwd(), "result.json") + warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2) + runner = runner_cls(config) + result = runner.infer() + runner.save(result, config.result_path) + return result diff --git a/multimolecule/apis/stat.py b/multimolecule/apis/stat.py new file mode 100644 index 00000000..5e525d55 --- /dev/null +++ b/multimolecule/apis/stat.py @@ -0,0 +1,99 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import os +import shutil +from statistics import mean +from typing import List + +import chanfig +import pandas as pd +from chanfig import NestedDict +from tqdm import tqdm + + +class Result(NestedDict): + pretrained: str + id: str + seed: int + epoch: int + validation: NestedDict + test: NestedDict + + +def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]: + results = [] + for root, _, files in tqdm(os.walk(experiment_root)): + if "run.log" in files: + if "best.json" not in files: + if remove_empty: + shutil.rmtree(root) + continue + best = NestedDict.from_json(os.path.join(root, "best.json")) + if "index" not in best: + if remove_empty: + shutil.rmtree(root) + continue + config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml")) + pretrained = config.pretrained.split("/")[-1] + seed = config.seed + pretrained, seed = "", 1 + result = Result(id=best.id, pretrained=pretrained, seed=seed) + result.validation = NestedDict( + {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()} + ) + result.test = NestedDict( + {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()} + ) + result.epoch = best.index + result.pop("validation.time", None) + result.pop("test.time", None) + result.pop("validation.loss", None) + result.pop("test.loss", None) + result.pop("validation.lr", None) + result.pop("test.lr", None) + results.append(result) + # Remove empty directories, perform twice to remove all empty directories + if remove_empty: + for root, dirs, files in os.walk(experiment_root): + if not files and not dirs: + os.rmdir(root) + for root, dirs, files in os.walk(experiment_root): + if not files and not dirs: + os.rmdir(root) + results.sort(key=lambda x: (x.pretrained, x.seed, x.id)) + return results + + +def write_result_stat(results: List[Result], path: str): + results = [dict(result.all_items()) for result in results] # type: ignore[misc] + df = pd.DataFrame.from_dict(results) + df.insert(len(df.keys()) - 1, "comment", "") + df.fillna("") + df.to_csv(path, index=False) + + +class Config(chanfig.Config): + experiment_root: str = "experiments" + out_path: str = "result.csv" + + +if __name__ == "__main__": + config = Config().parse() + result_stat = get_result_stat(config.experiment_root) + if not len(result_stat) > 0: + raise ValueError("No results found") + write_result_stat(result_stat, config.out_path) diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py index 62196c10..f2366d77 100644 --- a/multimolecule/data/__init__.py +++ b/multimolecule/data/__init__.py @@ -15,6 +15,14 @@ # along with this program. If not, see . from .dataset import Dataset +from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler from .utils import no_collate -__all__ = ["Dataset", "no_collate"] +__all__ = [ + "Dataset", + "PandasDataset", + "MultiTaskDataset", + "MultiTaskSampler", + "DistributedMultiTaskSampler", + "no_collate", +] diff --git a/multimolecule/data/dataset.py b/multimolecule/data/dataset.py index 54565349..6c5b9f6e 100644 --- a/multimolecule/data/dataset.py +++ b/multimolecule/data/dataset.py @@ -80,10 +80,13 @@ class Dataset(datasets.Dataset): preprocess: Whether to preprocess the dataset. Preprocessing involves pre-tokenizing the sequences using the tokenizer. Defaults to `True`. - auto_rename_cols: Whether to automatically rename columns to standard names. - Only works when there is exactly one feature column / one label column. - You can control the naming through `multimolecule.defaults.SEQUENCE_COL_NAME` and - `multimolecule.defaults.LABEL_COL_NAME`. + auto_rename_sequence_col: Whether to automatically rename sequence columns to standard name. + Only works when there is exactly one sequence column + You can control the naming through `multimolecule.defaults.SEQUENCE_COL_NAME`. + For more refined control, use `column_names_map`. + auto_rename_label_cols: Whether to automatically rename label column to standard name. + Only works when there is exactly one label column. + You can control the naming through `multimolecule.defaults.LABEL_COL_NAME`. For more refined control, use `column_names_map`. column_names_map: A mapping of column names to new column names. This is useful for renaming columns to inputs that are expected by a model. @@ -122,7 +125,8 @@ class Dataset(datasets.Dataset): _discrete_map: Mapping preprocess: bool = True - auto_rename_cols: bool = False + auto_rename_sequence_col: bool = True + auto_rename_label_col: bool = False column_names_map: Mapping[str, str] | None = None ignored_cols: List[str] = [] @@ -136,7 +140,8 @@ def __init__( label_cols: List | None = None, id_cols: List | None = None, preprocess: bool | None = None, - auto_rename_cols: bool | None = None, + auto_rename_sequence_col: bool | None = None, + auto_rename_label_col: bool | None = None, column_names_map: Mapping[str, str] | None = None, truncation: bool | None = None, max_seq_length: int | None = None, @@ -149,8 +154,9 @@ def __init__( fingerprint: str | None = None, ignored_cols: List[str] | None = None, ): + self._tasks = NestedDict() if tasks is not None: - self._tasks = NestedDict(tasks) + self.tasks = tasks if discrete_map is not None: self._discrete_map = discrete_map arrow_table = self.build_table( @@ -166,7 +172,8 @@ def __init__( preprocess=preprocess, truncation=truncation, max_seq_length=max_seq_length, - auto_rename_cols=auto_rename_cols, + auto_rename_sequence_col=auto_rename_sequence_col, + auto_rename_label_col=auto_rename_label_col, column_names_map=column_names_map, ) self.ignored_cols = ignored_cols or self.id_cols @@ -187,13 +194,13 @@ def build_table( data = dl.load_pandas(data) if isinstance(data, DataFrame): data = data.loc[:, ~data.columns.str.contains("^Unnamed")] - data = pa.Table.from_pandas(data) + data = pa.Table.from_pandas(data, preserve_index=False) elif isinstance(data, dict): data = pa.Table.from_pydict(data) elif isinstance(data, list): data = pa.Table.from_pylist(data) elif isinstance(data, DataFrame): - data = pa.Table.from_pandas(data) + data = pa.Table.from_pandas(data, preserve_index=False) if feature_cols is not None and label_cols is not None: data = data.select(feature_cols + label_cols) data = self.process_nan(data, nan_process=nan_process, fill_value=fill_value) @@ -206,7 +213,8 @@ def post( max_seq_length: int | None = None, truncation: bool | None = None, preprocess: bool | None = None, - auto_rename_cols: bool | None = None, + auto_rename_sequence_col: bool | None = None, + auto_rename_label_col: bool | None = None, column_names_map: Mapping[str, str] | None = None, ) -> None: r""" @@ -214,7 +222,8 @@ def post( It first identifies the special columns (sequence and structure columns) in the dataset. Then it sets the feature and label columns based on the input arguments. - If `auto_rename_cols` is `True`, it will automatically rename the columns to model inputs. + If `auto_rename_sequence_col` is `True`, it will automatically rename the sequence column. + If `auto_rename_label_col` is `True`, it will automatically rename the label column. Finally, it sets the [`transform`][datasets.Dataset.set_transform] function based on the `preprocess` flag. """ if tokenizer is None: @@ -237,19 +246,24 @@ def post( self.seq_length_offset += 1 if preprocess is not None: self.preprocess = preprocess - if auto_rename_cols is not None: - self.auto_rename_cols = auto_rename_cols - if self.auto_rename_cols: - if column_names_map is not None: - raise ValueError("auto_rename_cols and column_names_map are mutually exclusive.") + if auto_rename_sequence_col is not None: + self.auto_rename_sequence_col = auto_rename_sequence_col + if auto_rename_label_col is not None: + self.auto_rename_label_col = auto_rename_label_col + if column_names_map is None: column_names_map = {} - if len(self.feature_cols) == 1: - column_names_map[self.feature_cols[0]] = defaults.SEQUENCE_COL_NAME - if len(self.label_cols) == 1: - column_names_map[self.label_cols[0]] = defaults.LABEL_COL_NAME + if self.auto_rename_sequence_col: + if len(self.sequence_cols) != 1: + raise ValueError("auto_rename_sequence_col can only be used when there is exactly one sequence column.") + column_names_map[self.sequence_cols[0]] = defaults.SEQUENCE_COL_NAME # type: ignore[index] + if self.auto_rename_label_col: + if len(self.label_cols) != 1: + raise ValueError("auto_rename_label_col can only be used when there is exactly one label column.") + column_names_map[self.label_cols[0]] = defaults.LABEL_COL_NAME # type: ignore[index] self.column_names_map = column_names_map if self.column_names_map: self.rename_columns(self.column_names_map) + self.infer_tasks() if self.preprocess: self.update(self.map(self.tokenization)) @@ -258,7 +272,7 @@ def post( if self.discrete_map: self.update(self.map(self.map_discrete)) fn_kwargs = { - "columns": [name for name, task in self.tasks.items() if task.level in ["nucleotide", "contact"]], + "columns": [name for name, task in self.tasks.items() if task.level in ["token", "contact"]], "max_seq_length": self.max_seq_length - self.seq_length_offset, } if self.truncation and 0 < self.max_seq_length < 2**32: @@ -297,20 +311,23 @@ def collate(self, col: str, data: Any) -> Tensor | NestedTensor | None: except ValueError: return NestedTensor(data) - def infer_tasks(self, tasks: Mapping | None = None, sequence_col: str | None = None) -> NestedDict: - self._tasks = tasks or NestedDict() + def infer_tasks(self, sequence_col: str | None = None) -> NestedDict: for col in self.label_cols: - if col not in self.tasks: - if col in self.secondary_structure_cols: - task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1) - self._tasks[col] = task # type: ignore[index] - warn( - f"Secondary structure columns are assumed to be {task}." - " Please explicitly specify the task if this is not the case." - ) - else: - self._tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index] - return self._tasks + if col in self.tasks: + continue + if col in self.secondary_structure_cols: + task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1) + self.tasks[col] = task # type: ignore[index] + warn( + f"Secondary structure columns are assumed to be {task}. " + "Please explicitly specify the task if this is not the case." + ) + else: + try: + self.tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index] + except ValueError: + raise ValueError(f"Unable to infer task for column {col}.") + return self.tasks def infer_task(self, label_col: str, sequence_col: str | None = None) -> Task: if sequence_col is None: @@ -404,7 +421,7 @@ def rename_columns(self, column_mapping: Mapping[str, str], new_fingerprint: str self._label_cols = [column_mapping.get(i, i) for i in self.label_cols] self._sequence_cols = [column_mapping.get(i, i) for i in self.sequence_cols] self._secondary_structure_cols = [column_mapping.get(i, i) for i in self.secondary_structure_cols] - self._tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()} + self.tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()} return self def rename_column( @@ -418,7 +435,7 @@ def rename_column( self._secondary_structure_cols = [ new_column_name if i == original_column_name else i for i in self.secondary_structure_cols ] - self._tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()} + self.tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()} return self def process_nan(self, data: Table, nan_process: str | None, fill_value: str | int | float = 0) -> Table: @@ -470,9 +487,18 @@ def secondary_structure_cols(self) -> List: @property def tasks(self) -> NestedDict: if not hasattr(self, "_tasks"): + self._tasks = NestedDict() return self.infer_tasks() return self._tasks + @tasks.setter + def tasks(self, tasks: Mapping): + self._tasks = NestedDict() + for name, task in tasks.items(): + if not isinstance(task, Task): + task = Task(**task) + self._tasks[name] = task + @property def discrete_map(self) -> Mapping: if not hasattr(self, "_discrete_map"): diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py new file mode 100644 index 00000000..7c20e829 --- /dev/null +++ b/multimolecule/data/multitask.py @@ -0,0 +1,246 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from bisect import bisect_right +from collections.abc import Iterator, Mapping, Sequence +from copy import deepcopy +from random import choices + +import torch +from chanfig import NestedDict +from torch import distributed as dist +from torch.utils import data + +from .dataset import Dataset + + +class MultiTaskDataset(data.ConcatDataset): + + datasets: Mapping + dataset_keys: Sequence[str] + dataset_values: Sequence[Dataset] + + def __init__(self, datasets: Mapping) -> None: + for key, dataset in datasets.items(): + if not isinstance(dataset, Dataset): + raise TypeError(f"Dataset {key} should be an instance of Dataset") + self.datasets = datasets + if not len(self.datasets) > 0: + raise ValueError("MultiTaskDataset should contain at least one dataset") + self.dataset_keys, self.dataset_values = zip(*self.datasets.items()) + self.cumulative_sizes = self.cumsum(self.dataset_values) + + def __getitems__(self, key: Sequence[int]) -> Mapping: + dataset_idx = bisect_right(self.cumulative_sizes, key[0]) + if dataset_idx == 0: + sample_idx = key + else: + sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key] + batch = self.dataset_values[dataset_idx][sample_idx] + batch["dataset"] = self.dataset_keys[dataset_idx] + return batch + + @property + def tasks(self) -> NestedDict: + tasks = NestedDict() + for dataset in self.dataset_values: + for n, t in dataset.tasks.items(): + if n not in tasks: + tasks[n] = t + elif tasks[n] != t: + raise ValueError(f"Task {n} has different configurations across datasets") + return tasks + + @property + def dataset_tasks(self) -> NestedDict: + return NestedDict({k: v.tasks for k, v in self.datasets.items()}) + + def __repr__(self) -> str: + return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})" + + +class MultiTaskSampler(data.BatchSampler): + r""" + Ensure all items in a batch comes from the same dataset. + + Arguments: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + """ + + datasets: Sequence[Dataset] + + def __init__( # pylint: disable=super-init-not-called + self, + dataset: MultiTaskDataset, + batch_size: int, + shuffle: bool = True, + drop_last: bool = False, + sampler_cls: type[data.Sampler] | None = None, + weights: list[int] | None = None, + ) -> None: + self.datasets = dataset.dataset_values + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle = shuffle + if sampler_cls is None: + sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler + self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore + self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore + self.cumulative_sizes = dataset.cumulative_sizes + self.num_datasets = len(self.datasets) + self.weights = weights if weights is not None else self.dataset_sizes + + def __iter__(self): + sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)] + sampler_weights = deepcopy(self.weights) + sampler_idx = 0 + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + if self.drop_last: + while sampler_iters: + if self.shuffle: + sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0] + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + try: + batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)] + yield batch + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + else: + while sampler_iters: + if self.shuffle: + sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0] + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + batch = [0] * self.batch_size + idx_in_batch = 0 + try: + for _ in range(self.batch_size): + batch[idx_in_batch] = next(sampler_iter) + cumulative_size + idx_in_batch += 1 + yield batch + idx_in_batch = 0 # noqa: SIM113 + batch = [0] * self.batch_size + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self): + batch_size = self.batch_size + if self.drop_last: + return sum(len(d) // batch_size for d in self.datasets) + return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets) + + +class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods + r""" + Distributed version of MultiTaskSampler, which ensures that all GPUs sample data from the + same sub-dataset in each step without requiring additional communication. + The dataset selection is based on a random seed mechanism that is synchronized across epochs. + + See Also: + [MultiTaskSampler][MultiTaskSampler] + """ + + def __init__( + self, + dataset: MultiTaskDataset, + batch_size: int, + shuffle: bool = True, + drop_last: bool = False, + sampler_cls: type[data.Sampler] = data.RandomSampler, + weights: list[int] | None = None, + seed: int = 0, + ) -> None: + super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights) + self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets] + self.seed = seed + self.epoch = 0 + + def set_epoch(self, epoch: int): + """ + Sets the epoch for deterministic shuffling. + """ + self.epoch = epoch + for sampler in self.samplers: + sampler.set_epoch(epoch) + + def _get_sampler_idx(self, high: int) -> int: + """ + Determines which sampler (i.e., sub-dataset) to use based on the seed and epoch. + """ + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + sampler_idx = torch.randint(low=0, high=high, size=(1,), generator=g).item() + return sampler_idx + + def __iter__(self) -> Iterator: + sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)] + sampler_weights = deepcopy(self.weights) + + if self.drop_last: + while sampler_iters: + # Sample the same sub-dataset across all GPUs using the seeded index + sampler_idx = self._get_sampler_idx(len(sampler_iters)) + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + try: + batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)] + yield batch + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + else: + while sampler_iters: + # Sample the same sub-dataset across all GPUs using the seeded index + sampler_idx = self._get_sampler_idx(len(sampler_iters)) + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + batch = [0] * self.batch_size + idx_in_batch = 0 + try: + for _ in range(self.batch_size): + batch[idx_in_batch] = next(sampler_iter) + cumulative_size + idx_in_batch += 1 + yield batch + idx_in_batch = 0 # noqa: SIM113 + batch = [0] * self.batch_size + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self) -> int: + batch_size = self.batch_size * self.world_size + if self.drop_last: + return sum(len(d) // batch_size for d in self.datasets) + return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets) + + @property + def world_size(self) -> int: + r"""Return the number of processes in the current process group.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return 1 diff --git a/multimolecule/data/utils.py b/multimolecule/data/utils.py index 85bc4423..1afddd1b 100644 --- a/multimolecule/data/utils.py +++ b/multimolecule/data/utils.py @@ -60,7 +60,7 @@ def infer_task( level = TaskLevel.Contact num_labels = len(flattened) // num_contacts elif len(flattened) % num_tokens == 0: - level = TaskLevel.Nucleotide + level = TaskLevel.Token num_labels = len(flattened) // num_tokens elif len(flattened) % num_elem == 0: level = TaskLevel.Sequence @@ -86,7 +86,7 @@ def infer_task( task_type = TaskType.MultiClass if num_labels > 2 else TaskType.Binary num_labels = 1 if task_type == TaskType.Binary else num_labels if num_tokens_flattened == num_tokens: - return Task(task_type, level=TaskLevel.Nucleotide, num_labels=num_labels) + return Task(task_type, level=TaskLevel.Token, num_labels=num_labels) if num_contacts_flattened == num_contacts: return Task(task_type, level=TaskLevel.Contact, num_labels=num_labels) return Task(task_type, level=TaskLevel.Sequence, num_labels=num_labels) @@ -122,7 +122,7 @@ def map_value(value: Any, mapping: dict[str, int] | None) -> Any: def truncate_value(value: Any, max_seq_length: int, level: int | None = None) -> Any: - if level == TaskLevel.Nucleotide: + if level == TaskLevel.Token: return value[:max_seq_length] if level == TaskLevel.Contact: return [i[:max_seq_length] for i in value[:max_seq_length]] diff --git a/multimolecule/defaults.py b/multimolecule/defaults.py index c299ea1a..a908bbdb 100644 --- a/multimolecule/defaults.py +++ b/multimolecule/defaults.py @@ -14,11 +14,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +TRAIN_SPLITS = ("train",) +VALIDATION_SPLITS = ("val", "validation") +TEST_SPLITS = ("test", "eval", "evaluation") +INFERENCE_SPLITS = ("inf", "inference") +DATASET_SPLITS = TRAIN_SPLITS + VALIDATION_SPLITS + TEST_SPLITS + INFERENCE_SPLITS ID_COL_NAMES = ["id", "idx", "index"] SEQUENCE_COL_NAMES = ["input_ids", "sequence", "seq"] SECONDARY_STRUCTURE_COL_NAMES = ["secondary_structure", "ss"] LABEL_COL_NAMES = ["label", "labels"] -SEQUENCE_COL_NAME = "input_ids" +SEQUENCE_COL_NAME = "sequence" LABEL_COL_NAME = "labels" LABLE_TYPE_THRESHOLD = 0.5 TASK_INFERENCE_NUM_ROWS = 100 diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py index 66147616..29d99436 100644 --- a/multimolecule/models/__init__.py +++ b/multimolecule/models/__init__.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from multimolecule.module import HeadConfig from multimolecule.tokenisers import DnaTokenizer, ProteinTokenizer, RnaTokenizer from .calm import ( @@ -127,6 +128,7 @@ __all__ = [ "PreTrainedConfig", + "HeadConfig", "DnaTokenizer", "RnaTokenizer", "ProteinTokenizer", diff --git a/multimolecule/models/calm/configuration_calm.py b/multimolecule/models/calm/configuration_calm.py index c5d73c03..032bda8e 100644 --- a/multimolecule/models/calm/configuration_calm.py +++ b/multimolecule/models/calm/configuration_calm.py @@ -127,5 +127,5 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/calm/modeling_calm.py b/multimolecule/models/calm/modeling_calm.py index 25c1eba4..c8abdffe 100644 --- a/multimolecule/models/calm/modeling_calm.py +++ b/multimolecule/models/calm/modeling_calm.py @@ -270,9 +270,9 @@ class CaLmForSequencePrediction(CaLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: CaLmConfig): @@ -334,9 +334,9 @@ class CaLmForTokenPrediction(CaLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: CaLmConfig): @@ -398,9 +398,9 @@ class CaLmForContactPrediction(CaLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: CaLmConfig): diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py index 2047d671..ce6f10ea 100644 --- a/multimolecule/models/configuration_utils.py +++ b/multimolecule/models/configuration_utils.py @@ -30,7 +30,8 @@ class PreTrainedConfig(PretrainedConfig): Base class for all model configuration classes. """ - head: HeadConfig + head: HeadConfig | None + num_labels: int = 1 hidden_size: int @@ -42,7 +43,15 @@ class PreTrainedConfig(PretrainedConfig): null_token_id: int = 5 def __init__( - self, pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3, mask_token_id=4, null_token_id=5, **kwargs + self, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + unk_token_id: int = 3, + mask_token_id: int = 4, + null_token_id: int = 5, + num_labels: int = 1, + **kwargs, ): super().__init__( pad_token_id=pad_token_id, @@ -51,6 +60,7 @@ def __init__( unk_token_id=unk_token_id, mask_token_id=mask_token_id, null_token_id=null_token_id, + num_labels=num_labels, **kwargs, ) diff --git a/multimolecule/models/ernierna/configuration_ernierna.py b/multimolecule/models/ernierna/configuration_ernierna.py index 0648bb2d..bfd11d51 100644 --- a/multimolecule/models/ernierna/configuration_ernierna.py +++ b/multimolecule/models/ernierna/configuration_ernierna.py @@ -110,5 +110,5 @@ def __init__( self.pairwise_alpha = pairwise_alpha self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/ernierna/modeling_ernierna.py b/multimolecule/models/ernierna/modeling_ernierna.py index 6354a68c..1378e256 100644 --- a/multimolecule/models/ernierna/modeling_ernierna.py +++ b/multimolecule/models/ernierna/modeling_ernierna.py @@ -321,7 +321,7 @@ class ErnieRnaForSequencePrediction(ErnieRnaPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) """ def __init__(self, config: ErnieRnaConfig): @@ -385,9 +385,9 @@ class ErnieRnaForTokenPrediction(ErnieRnaPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: ErnieRnaConfig): @@ -452,9 +452,9 @@ class ErnieRnaForContactPrediction(ErnieRnaPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: ErnieRnaConfig): @@ -1183,11 +1183,8 @@ class ErnieRnaContactClassificationHead(nn.Module): def __init__(self, config: ErnieRnaConfig, head_config: HeadConfig | None = None): super().__init__() if head_config is None: - head_config = config.head + head_config = config.head or HeadConfig() self.config = head_config - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id self.conv1 = nn.Conv2d(1, 8, 7, 1, 3) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout(p=0.3) diff --git a/multimolecule/models/rinalmo/configuration_rinalmo.py b/multimolecule/models/rinalmo/configuration_rinalmo.py index 5e21725d..1cc963b2 100644 --- a/multimolecule/models/rinalmo/configuration_rinalmo.py +++ b/multimolecule/models/rinalmo/configuration_rinalmo.py @@ -125,6 +125,6 @@ def __init__( self.use_cache = use_cache self.learnable_beta = learnable_beta self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None self.emb_layer_norm_before = emb_layer_norm_before diff --git a/multimolecule/models/rinalmo/modeling_rinalmo.py b/multimolecule/models/rinalmo/modeling_rinalmo.py index b45d2823..d0ac6e8c 100644 --- a/multimolecule/models/rinalmo/modeling_rinalmo.py +++ b/multimolecule/models/rinalmo/modeling_rinalmo.py @@ -269,9 +269,9 @@ class RiNALMoForSequencePrediction(RiNALMoPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RiNALMoConfig): @@ -333,9 +333,9 @@ class RiNALMoForTokenPrediction(RiNALMoPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RiNALMoConfig): @@ -397,9 +397,9 @@ class RiNALMoForContactPrediction(RiNALMoPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RiNALMoConfig): diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py index f044ecc7..97632d2e 100644 --- a/multimolecule/models/rnabert/configuration_rnabert.py +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -112,5 +112,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index 74b06cf1..32f7bf01 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -37,7 +37,13 @@ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import logging -from multimolecule.module import ContactPredictionHead, MaskedLMHead, SequencePredictionHead, TokenPredictionHead +from multimolecule.module import ( + ContactPredictionHead, + HeadConfig, + MaskedLMHead, + SequencePredictionHead, + TokenPredictionHead, +) from ..modeling_outputs import ContactPredictorOutput, SequencePredictorOutput, TokenPredictorOutput from .configuration_rnabert import RnaBertConfig @@ -266,9 +272,9 @@ class RnaBertForSequencePrediction(RnaBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaBertConfig): @@ -330,9 +336,9 @@ class RnaBertForTokenPrediction(RnaBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaBertConfig): @@ -394,9 +400,9 @@ class RnaBertForContactPrediction(RnaBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaBertConfig): @@ -1065,7 +1071,7 @@ def __init__(self, config: RnaBertConfig): vocab_size, config.vocab_size = config.vocab_size, config.ss_vocab_size self.predictions_ss = MaskedLMHead(config) config.vocab_size = vocab_size - self.seq_relationship = SequencePredictionHead(config) + self.seq_relationship = SequencePredictionHead(config, HeadConfig(num_labels=2)) def forward( self, diff --git a/multimolecule/models/rnaernie/configuration_rnaernie.py b/multimolecule/models/rnaernie/configuration_rnaernie.py index 2d540c9d..7a788297 100644 --- a/multimolecule/models/rnaernie/configuration_rnaernie.py +++ b/multimolecule/models/rnaernie/configuration_rnaernie.py @@ -108,5 +108,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnaernie/modeling_rnaernie.py b/multimolecule/models/rnaernie/modeling_rnaernie.py index 7e0f4d10..8107ee20 100644 --- a/multimolecule/models/rnaernie/modeling_rnaernie.py +++ b/multimolecule/models/rnaernie/modeling_rnaernie.py @@ -270,9 +270,9 @@ class RnaErnieForSequencePrediction(RnaErniePreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config): @@ -334,9 +334,9 @@ class RnaErnieForTokenPrediction(RnaErniePreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaErnieConfig): @@ -398,9 +398,9 @@ class RnaErnieForContactPrediction(RnaErniePreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaErnieConfig): diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py index 8fdb7f49..ef1f0c18 100644 --- a/multimolecule/models/rnafm/configuration_rnafm.py +++ b/multimolecule/models/rnafm/configuration_rnafm.py @@ -131,5 +131,5 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py index 99f553da..6898da9c 100644 --- a/multimolecule/models/rnafm/modeling_rnafm.py +++ b/multimolecule/models/rnafm/modeling_rnafm.py @@ -272,9 +272,9 @@ class RnaFmForSequencePrediction(RnaFmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaFmConfig): @@ -336,9 +336,9 @@ class RnaFmForTokenPrediction(RnaFmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaFmConfig): @@ -400,9 +400,9 @@ class RnaFmForContactPrediction(RnaFmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaFmConfig): @@ -555,7 +555,7 @@ class RnaFmForPreTraining(RnaFmPreTrainedModel): >>> output["logits"].shape torch.Size([1, 7, 26]) >>> output["contact_map"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) """ _tied_weights_keys = [ diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py index ae914c82..2e8150ba 100644 --- a/multimolecule/models/rnamsm/configuration_rnamsm.py +++ b/multimolecule/models/rnamsm/configuration_rnamsm.py @@ -116,5 +116,5 @@ def __init__( self.attention_type = attention_type self.embed_positions_msa = embed_positions_msa self.attention_bias = attention_bias - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py index 5ed6bf87..0390a129 100644 --- a/multimolecule/models/rnamsm/modeling_rnamsm.py +++ b/multimolecule/models/rnamsm/modeling_rnamsm.py @@ -176,9 +176,9 @@ class RnaMsmForSequencePrediction(RnaMsmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaMsmConfig): @@ -239,9 +239,9 @@ class RnaMsmForTokenPrediction(RnaMsmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaMsmConfig): @@ -302,9 +302,9 @@ class RnaMsmForContactPrediction(RnaMsmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaMsmConfig): @@ -449,7 +449,7 @@ class RnaMsmForPreTraining(RnaMsmPreTrainedModel): >>> output["logits"].shape torch.Size([1, 7, 26]) >>> output["contact_map"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) """ _tied_weights_keys = [ diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py index 66b46a88..f789516d 100644 --- a/multimolecule/models/splicebert/configuration_splicebert.py +++ b/multimolecule/models/splicebert/configuration_splicebert.py @@ -108,5 +108,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py index 1b0fd072..9d129d74 100644 --- a/multimolecule/models/splicebert/modeling_splicebert.py +++ b/multimolecule/models/splicebert/modeling_splicebert.py @@ -274,9 +274,9 @@ class SpliceBertForSequencePrediction(SpliceBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: SpliceBertConfig): @@ -338,9 +338,9 @@ class SpliceBertForTokenPrediction(SpliceBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: SpliceBertConfig): @@ -402,9 +402,9 @@ class SpliceBertForContactPrediction(SpliceBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: SpliceBertConfig): diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py index d032c5ee..5230c04f 100644 --- a/multimolecule/models/utrbert/configuration_utrbert.py +++ b/multimolecule/models/utrbert/configuration_utrbert.py @@ -125,5 +125,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py index 1a5b47f9..688bedbe 100644 --- a/multimolecule/models/utrbert/modeling_utrbert.py +++ b/multimolecule/models/utrbert/modeling_utrbert.py @@ -264,9 +264,9 @@ class UtrBertForSequencePrediction(UtrBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrBertConfig): @@ -328,9 +328,9 @@ class UtrBertForTokenPrediction(UtrBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrBertConfig): @@ -393,9 +393,9 @@ class UtrBertForContactPrediction(UtrBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrBertConfig): diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py index f0f705de..a4f930d7 100644 --- a/multimolecule/models/utrlm/configuration_utrlm.py +++ b/multimolecule/models/utrlm/configuration_utrlm.py @@ -127,7 +127,7 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None self.ss_head = HeadConfig(**ss_head) if ss_head is not None else None self.mfe_head = HeadConfig(**mfe_head) if mfe_head is not None else None diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py index 535f99f0..aae1b593 100644 --- a/multimolecule/models/utrlm/modeling_utrlm.py +++ b/multimolecule/models/utrlm/modeling_utrlm.py @@ -272,9 +272,9 @@ class UtrLmForSequencePrediction(UtrLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrLmConfig): @@ -336,9 +336,9 @@ class UtrLmForTokenPrediction(UtrLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrLmConfig): @@ -400,9 +400,9 @@ class UtrLmForContactPrediction(UtrLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrLmConfig): @@ -555,7 +555,7 @@ class UtrLmForPreTraining(UtrLmPreTrainedModel): >>> output["logits"].shape torch.Size([1, 7, 26]) >>> output["contact_map"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) """ _tied_weights_keys = [ diff --git a/multimolecule/module/__init__.py b/multimolecule/module/__init__.py index 0128fe9b..dbba900b 100644 --- a/multimolecule/module/__init__.py +++ b/multimolecule/module/__init__.py @@ -14,8 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from .criterions import Criterion -from .embeddings import PositionEmbeddingRegistry, PositionEmbeddingRegistryHF, RotaryEmbedding, SinusoidalEmbedding +from .criterions import Criterion, CriterionRegistry +from .embeddings import PositionEmbeddingRegistry, RotaryEmbedding, SinusoidalEmbedding from .heads import ( BaseHeadConfig, ContactPredictionHead, @@ -23,7 +23,6 @@ HeadOutput, HeadRegistry, HeadTransformRegistry, - HeadTransformRegistryHF, IdentityTransform, LinearTransform, MaskedLMHead, @@ -31,15 +30,18 @@ NonLinearTransform, PredictionHead, SequencePredictionHead, - TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead, ) +from .model import MultiMoleculeModel +from .registry import ModelRegistry __all__ = [ + "ModelRegistry", + "MultiMoleculeModel", + "CriterionRegistry", "Criterion", "PositionEmbeddingRegistry", - "PositionEmbeddingRegistryHF", "RotaryEmbedding", "SinusoidalEmbedding", "BaseHeadConfig", @@ -48,14 +50,12 @@ "HeadRegistry", "PredictionHead", "SequencePredictionHead", - "TokenHeadRegistryHF", "TokenPredictionHead", "TokenKMerHead", "ContactPredictionHead", "MaskedLMHead", "HeadOutput", "HeadTransformRegistry", - "HeadTransformRegistryHF", "LinearTransform", "NonLinearTransform", "IdentityTransform", diff --git a/multimolecule/module/backbones/__init__.py b/multimolecule/module/backbones/__init__.py new file mode 100644 index 00000000..d69e6292 --- /dev/null +++ b/multimolecule/module/backbones/__init__.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .registry import BackboneRegistry +from .sequence import SequenceBackbone +from .sequences import SequenceRegistry + +__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"] diff --git a/multimolecule/module/backbones/registry.py b/multimolecule/module/backbones/registry.py new file mode 100644 index 00000000..47be122d --- /dev/null +++ b/multimolecule/module/backbones/registry.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import Registry + +BackboneRegistry = Registry() diff --git a/multimolecule/module/backbones/sequence.py b/multimolecule/module/backbones/sequence.py new file mode 100644 index 00000000..2b0ee0cf --- /dev/null +++ b/multimolecule/module/backbones/sequence.py @@ -0,0 +1,59 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import torch +from chanfig import FlatDict +from danling import NestedTensor +from torch import Tensor, nn + +from .registry import BackboneRegistry +from .sequences import SequenceRegistry + + +@BackboneRegistry.register("sequence", default=True) +class SequenceBackbone(nn.Module): + def __init__(self, sequence) -> None: + super().__init__() + sequence_dropout = sequence.pop("dropout", 0) + self.sequence = SequenceRegistry.build(**sequence) + self.sequence_dropout = nn.Dropout(sequence_dropout) + self.config = self.sequence.config + self.out_channels = self.config.hidden_size + + def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, FlatDict]: + attentions = None + input_ids, attention_mask = sequence.tensor, sequence.mask + sequence_output = self.sequence(input_ids.int(), attention_mask) + if "last_hidden_state" in sequence_output: + sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"]) + elif "logits" in sequence_output: + sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["logits"]) + else: + raise ValueError("No token output") + if "pooler_output" in sequence_output: + sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"]) + elif "logits" in sequence_output: + sequence_output["pooler_output"] = self.sequence_dropout( + sequence_output["logits"].mean(dim=1, keepdim=True) + ) + else: + raise ValueError("No sequence output") + if "attentions" in sequence_output: + attentions = torch.stack(sequence_output["attentions"], dim=1).detach() + + return sequence_output, attentions diff --git a/multimolecule/module/backbones/sequences/__init__.py b/multimolecule/module/backbones/sequences/__init__.py new file mode 100644 index 00000000..e6e5cd08 --- /dev/null +++ b/multimolecule/module/backbones/sequences/__init__.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .onehot import OneHot +from .registry import SequenceRegistry + +__all__ = ["SequenceRegistry", "OneHot"] diff --git a/multimolecule/module/backbones/sequences/onehot.py b/multimolecule/module/backbones/sequences/onehot.py new file mode 100644 index 00000000..bc4c979f --- /dev/null +++ b/multimolecule/module/backbones/sequences/onehot.py @@ -0,0 +1,39 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import torch +from chanfig import FlatDict +from torch import nn +from transformers import AutoConfig + +from .registry import SequenceRegistry + + +@SequenceRegistry.register("onehot") +class OneHot(nn.Module): + def __init__(self, pretrained: str) -> None: + super().__init__() + self.config = AutoConfig.from_pretrained(str(pretrained)) + self.module = nn.Embedding(self.config.vocab_size, self.config.hidden_size) + + def forward(self, input_ids, attn_mask) -> FlatDict: + output = FlatDict() + output["last_hidden_state"] = self.module(input_ids) + valid_length = attn_mask.sum(dim=1) + output["pooler_output"] = torch.stack( + [t[: valid_length[i]].sum(0) for i, t in enumerate(output["last_hidden_state"])] + ) + return output diff --git a/multimolecule/module/backbones/sequences/registry.py b/multimolecule/module/backbones/sequences/registry.py new file mode 100644 index 00000000..c9178231 --- /dev/null +++ b/multimolecule/module/backbones/sequences/registry.py @@ -0,0 +1,66 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import danling as dl +import transformers +from chanfig import Registry as Registry_ +from torch import nn +from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel + + +class Registry(Registry_): # pylint: disable=too-few-public-methods + def build( + self, + type: str | None = None, + name: str | None = None, + use_pretrained: bool = True, + gradient_checkpoint: bool = False, + checkpoint: str | None = None, + *args, + **kwargs, + ) -> nn.Module: + if type is not None: + if type in self: + sequence_cls = self.lookup(type) + sequence = self.init(sequence_cls, *args, **kwargs) + if checkpoint is not None: + sequence.load_state_dict(dl.load(checkpoint)) + elif hasattr(transformers, type + "Model"): + if use_pretrained: + sequence_cls: PreTrainedModel = getattr(transformers, type + "Model") # type: ignore[no-redef] + sequence = sequence_cls.from_pretrained(name, *args, **kwargs) + else: + config_cls: PretrainedConfig = getattr(transformers, type + "Config") + config, kwargs = config_cls.from_pretrained(name, return_unused_kwargs=True, **kwargs) + sequence_cls: PreTrainedModel = getattr(transformers, type + "Model") # type: ignore[no-redef] + sequence = sequence_cls.from_config(config, *args, **kwargs) + else: + raise ValueError(f"Sequence {type} not found in registry or transformers") + else: + if use_pretrained: + sequence = AutoModel.from_pretrained(name, *args, **kwargs) + else: + config, kwargs = AutoConfig.from_pretrained(name, return_unused_kwargs=True, **kwargs) + sequence = AutoModel.from_config(config, *args, **kwargs) + + if gradient_checkpoint: + sequence.gradient_checkpointing_enable() + return sequence + + +SequenceRegistry = Registry() diff --git a/multimolecule/module/criterions/__init__.py b/multimolecule/module/criterions/__init__.py index 104334b5..4b9adf7e 100644 --- a/multimolecule/module/criterions/__init__.py +++ b/multimolecule/module/criterions/__init__.py @@ -14,6 +14,18 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from .binary import BCEWithLogitsLoss from .generic import Criterion +from .multiclass import CrossEntropyLoss +from .multilabel import MultiLabelSoftMarginLoss +from .registry import CriterionRegistry +from .regression import MSELoss -__all__ = ["Criterion"] +__all__ = [ + "CriterionRegistry", + "Criterion", + "MSELoss", + "BCEWithLogitsLoss", + "CrossEntropyLoss", + "MultiLabelSoftMarginLoss", +] diff --git a/multimolecule/module/criterions/binary.py b/multimolecule/module/criterions/binary.py new file mode 100644 index 00000000..0bf53e59 --- /dev/null +++ b/multimolecule/module/criterions/binary.py @@ -0,0 +1,44 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from danling import NestedTensor +from torch import Tensor, nn + +from .registry import CriterionRegistry + +if TYPE_CHECKING: + from ..heads.config import HeadConfig + + +@CriterionRegistry.register("binary") +class BCEWithLogitsLoss(nn.BCEWithLogitsLoss): + def __init__(self, config: HeadConfig) -> None: + super().__init__(**config.get("loss", {})) + self.config = config + + def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor: + if isinstance(input, NestedTensor): + input = torch.cat(input.flatten().storage()) + if isinstance(target, NestedTensor): + target = torch.cat(target.flatten().storage()) + if input.ndim == target.ndim + 1: + input = input.squeeze(-1) + return super().forward(input, target.float()) diff --git a/multimolecule/module/criterions/generic.py b/multimolecule/module/criterions/generic.py index b003c81d..a6731933 100644 --- a/multimolecule/module/criterions/generic.py +++ b/multimolecule/module/criterions/generic.py @@ -17,8 +17,8 @@ from __future__ import annotations from typing import TYPE_CHECKING +from warnings import warn -import torch from danling import NestedTensor from torch import Tensor, nn from torch.nn import functional as F @@ -26,10 +26,13 @@ if TYPE_CHECKING: from ..heads.config import HeadConfig +from .registry import CriterionRegistry + +@CriterionRegistry.register(default=True) class Criterion(nn.Module): - problem_types = ["regression", "single_label_classification", "multi_label_classification"] + problem_types = ["regression", "binary", "multiclass", "multilabel"] def __init__(self, config: HeadConfig) -> None: super().__init__() @@ -41,21 +44,31 @@ def forward(self, logits: Tensor | NestedTensor, labels: Tensor | NestedTensor) if labels is None: return None if self.problem_type is None: - if self.num_labels == 1: + if labels.is_floating_point(): self.problem_type = "regression" - elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int): - self.problem_type = "single_label_classification" + elif self.num_labels == 1: + self.problem_type = "binary" + elif labels.unique().numel() == 2: + self.problem_type = "multilabel" else: - self.problem_type = "multi_label_classification" + self.problem_type = "multiclass" + warn( + f"`problem_type` is not set. Assuming {self.problem_type}. \n" + "This can lead to unexpected behavior. Please set `problem_type` explicitly." + ) self.config.problem_type = self.problem_type if self.problem_type == "regression": labels = labels.to(logits.dtype) if self.num_labels == 1: return F.mse_loss(logits.squeeze(), labels.squeeze()) logits, labels = logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) - return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels)) - if self.problem_type == "single_label_classification": + return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels)) # type: ignore + if self.problem_type == "multiclass": return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) - if self.problem_type == "multi_label_classification": - return F.binary_cross_entropy_with_logits(logits, labels) + if self.problem_type == "binary": + if logits.ndim == labels.ndim + 1: + logits = logits.squeeze(-1) + return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype)) + if self.problem_type == "multilabel": + return F.multilabel_soft_margin_loss(logits, labels.to(logits.dtype)) raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}") diff --git a/multimolecule/module/criterions/multiclass.py b/multimolecule/module/criterions/multiclass.py new file mode 100644 index 00000000..f7070e94 --- /dev/null +++ b/multimolecule/module/criterions/multiclass.py @@ -0,0 +1,44 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from danling import NestedTensor +from torch import Tensor, nn + +if TYPE_CHECKING: + from ..heads.config import HeadConfig + +from .registry import CriterionRegistry + + +@CriterionRegistry.register("multiclass") +class CrossEntropyLoss(nn.CrossEntropyLoss): + def __init__(self, config: HeadConfig) -> None: + super().__init__(**config.get("loss", {})) + self.config = config + + def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor: + if isinstance(input, NestedTensor): + input = torch.cat(input.storage()) + if isinstance(target, NestedTensor): + target = torch.cat(target.storage()) + if input.ndim > 2: + input, target = input.view(-1, input.size(-1)), target.view(-1) + return super().forward(input, target.long()) diff --git a/multimolecule/module/criterions/multilabel.py b/multimolecule/module/criterions/multilabel.py new file mode 100644 index 00000000..c72bb9f9 --- /dev/null +++ b/multimolecule/module/criterions/multilabel.py @@ -0,0 +1,44 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from danling import NestedTensor +from torch import Tensor, nn + +if TYPE_CHECKING: + from ..heads.config import HeadConfig + +from .registry import CriterionRegistry + + +@CriterionRegistry.register("multilabel") +class MultiLabelSoftMarginLoss(nn.MultiLabelSoftMarginLoss): + def __init__(self, config: HeadConfig) -> None: + super().__init__(**config.get("loss", {})) + self.config = config + + def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor: + if isinstance(target, NestedTensor) and target.ndim > 2: + input, target = input.view(-1, input.size(-1)), target.view(-1, target.size(-1)) + if isinstance(input, NestedTensor): + input = torch.cat(input.storage()) + if isinstance(target, NestedTensor): + target = torch.cat(target.storage()) + return super().forward(input, target.float()) diff --git a/multimolecule/module/criterions/registry.py b/multimolecule/module/criterions/registry.py new file mode 100644 index 00000000..856341f7 --- /dev/null +++ b/multimolecule/module/criterions/registry.py @@ -0,0 +1,29 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from chanfig import ConfigRegistry as Registry_ +from torch import nn + + +class Registry(Registry_): # pylint: disable=too-few-public-methods + key = "problem_type" + + def build(self, config) -> nn.Module: # type: ignore[override] + name = getattr(config, self.getattr("key")) + return self.init(self.lookup(name), config) # type: ignore[arg-type] + + +CriterionRegistry = Registry(fallback=True) diff --git a/multimolecule/module/criterions/regression.py b/multimolecule/module/criterions/regression.py new file mode 100644 index 00000000..4f39e0eb --- /dev/null +++ b/multimolecule/module/criterions/regression.py @@ -0,0 +1,44 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from danling import NestedTensor +from torch import Tensor, nn + +if TYPE_CHECKING: + from ..heads.config import HeadConfig + +from .registry import CriterionRegistry + + +@CriterionRegistry.register("regression") +class MSELoss(nn.MSELoss): + def __init__(self, config: HeadConfig) -> None: + super().__init__(**config.get("loss", {})) + self.config = config + + def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor: + if isinstance(input, NestedTensor): + input = torch.cat(input.flatten().storage()) + if isinstance(target, NestedTensor): + target = torch.cat(target.flatten().storage()) + if input.ndim == target.ndim + 1: + target = target.unsqueeze(-1) + return super().forward(input, target.to(input.dtype)) diff --git a/multimolecule/module/heads/__init__.py b/multimolecule/module/heads/__init__.py index 8cc91f29..0e857c5e 100644 --- a/multimolecule/module/heads/__init__.py +++ b/multimolecule/module/heads/__init__.py @@ -21,14 +21,8 @@ from .pretrain import MaskedLMHead from .registry import HeadRegistry from .sequence import SequencePredictionHead -from .token import TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead -from .transform import ( - HeadTransformRegistry, - HeadTransformRegistryHF, - IdentityTransform, - LinearTransform, - NonLinearTransform, -) +from .token import TokenKMerHead, TokenPredictionHead +from .transform import HeadTransformRegistry, IdentityTransform, LinearTransform, NonLinearTransform __all__ = [ "BaseHeadConfig", @@ -37,14 +31,12 @@ "HeadRegistry", "PredictionHead", "SequencePredictionHead", - "TokenHeadRegistryHF", "TokenPredictionHead", "TokenKMerHead", "ContactPredictionHead", "MaskedLMHead", "HeadOutput", "HeadTransformRegistry", - "HeadTransformRegistryHF", "LinearTransform", "NonLinearTransform", "IdentityTransform", diff --git a/multimolecule/module/heads/config.py b/multimolecule/module/heads/config.py index bb0dbba6..3b9ee64b 100644 --- a/multimolecule/module/heads/config.py +++ b/multimolecule/module/heads/config.py @@ -16,15 +16,13 @@ from __future__ import annotations -from collections import OrderedDict -from dataclasses import dataclass +from chanfig import FlatDict -class BaseHeadConfig(OrderedDict): +class BaseHeadConfig(FlatDict): pass -@dataclass class HeadConfig(BaseHeadConfig): r""" Configuration class for a prediction head. @@ -35,8 +33,8 @@ class HeadConfig(BaseHeadConfig): Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`. problem_type: - Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`, - `"single_label_classification"` or `"multi_label_classification"`. + Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`, + `"multiclass"` or `"multilabel"`. Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`. hidden_size: @@ -55,14 +53,18 @@ class HeadConfig(BaseHeadConfig): The activation function of the final prediction output. layer_norm_eps: The epsilon used by the layer normalization layers. - output_name (`str`, *optional*): + output_name: The name of the tensor required in model outputs. If is `None`, will use the default output name of the corresponding head. + type: + The type of the head in the model. + + This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads. """ - num_labels: int = None # type: ignore[assignment] - problem_type: str = None # type: ignore[assignment] + num_labels: int | None = None + problem_type: str | None = None hidden_size: int | None = None dropout: float = 0.0 transform: str | None = None @@ -71,9 +73,9 @@ class HeadConfig(BaseHeadConfig): act: str | None = None layer_norm_eps: float = 1e-12 output_name: str | None = None + type: str | None = None -@dataclass class MaskedLMHeadConfig(BaseHeadConfig): r""" Configuration class for a Masked Language Modeling head. @@ -95,7 +97,7 @@ class MaskedLMHeadConfig(BaseHeadConfig): The activation function of the final prediction output. layer_norm_eps: The epsilon used by the layer normalization layers. - output_name (`str`, *optional*): + output_name: The name of the tensor required in model outputs. If is `None`, will use the default output name of the corresponding head. diff --git a/multimolecule/module/heads/contact.py b/multimolecule/module/heads/contact.py index 50ec4fbb..cdef94d4 100644 --- a/multimolecule/module/heads/contact.py +++ b/multimolecule/module/heads/contact.py @@ -16,10 +16,12 @@ from __future__ import annotations -from typing import Mapping, Tuple +from typing import Callable, Mapping, Tuple, Type import torch from danling import NestedTensor +from danling.modules import MLP +from lazy_imports import try_import from torch import Tensor, nn from transformers.modeling_outputs import ModelOutput from typing_extensions import TYPE_CHECKING @@ -28,13 +30,55 @@ from .generic import PredictionHead from .output import HeadOutput from .registry import HeadRegistry -from .utils import average_product_correct, symmetrize + +with try_import() as tv: + from torchvision.models.resnet import BasicBlock, Bottleneck if TYPE_CHECKING: from multimolecule.models import PreTrainedConfig -@HeadRegistry.register("contact") +@HeadRegistry.contact.register("default", default=True) +class ContactHead(PredictionHead): + + output_name: str = "last_hidden_state" + + def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): + super().__init__(config, head_config) + out_channels: int = self.config.hidden_size # type: ignore[assignment] + self.qk_proj = nn.Linear(out_channels, 2 * out_channels) + self.ffn = MLP(1, out_channels, residual=False) + + def forward( # type: ignore[override] # pylint: disable=arguments-renamed + self, + outputs: ModelOutput | Mapping | Tuple[Tensor, ...], + attention_mask: Tensor | None = None, + input_ids: NestedTensor | Tensor | None = None, + labels: Tensor | None = None, + output_name: str | None = None, + **kwargs, + ) -> HeadOutput: + if isinstance(outputs, (Mapping, ModelOutput)): + output = outputs[output_name or self.output_name] + elif isinstance(outputs, tuple): + output = outputs[0] + else: + raise ValueError(f"Unsupported type for outputs: {type(outputs)}") + + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * attention_mask.unsqueeze(-1) + output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids) + + q, k = self.qk_proj(output).chunk(2, dim=-1) + contact_map = (q @ k.transpose(-2, -1)).unsqueeze(-1) + contact_map = contact_map + self.ffn(contact_map) + if "continuous" in outputs: + contact_map = contact_map * (1 + outputs["continuous"].unsqueeze(dim=-1)) # type: ignore[call-overload] + return super().forward(contact_map, labels) + + +@HeadRegistry.contact.register("attention") class ContactPredictionHead(PredictionHead): r""" Head for tasks in contact-level. @@ -50,13 +94,20 @@ class ContactPredictionHead(PredictionHead): output_name: str = "attentions" r"""The default output to use for the head.""" + requires_attention: bool = True + def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): super().__init__(config, head_config) - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id - self.decoder = nn.Linear( - config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias + self.config.hidden_size = config.num_hidden_layers * config.num_attention_heads + num_layers = self.config.get("num_layers", 16) + num_channels = self.config.get("num_channels", self.config.hidden_size // 10) # type: ignore[operator] + block = self.config.get("block", "auto") + self.decoder = ResNet( + num_layers=num_layers, + hidden_size=self.config.hidden_size, # type: ignore[arg-type] + block=block, + num_channels=num_channels, + num_labels=self.num_labels, ) if head_config is not None and head_config.output_name is not None: self.output_name = head_config.output_name @@ -81,19 +132,6 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed output_name: The name of the output to use. Defaults to `self.output_name`. """ - if attention_mask is None: - if isinstance(input_ids, NestedTensor): - input_ids, attention_mask = input_ids.tensor, input_ids.mask - else: - if input_ids is None: - raise ValueError( - f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." - ) - if self.pad_token_id is None: - raise ValueError( - f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." - ) - attention_mask = input_ids.ne(self.pad_token_id) if isinstance(outputs, (Mapping, ModelOutput)): output = outputs[output_name or self.output_name] @@ -105,13 +143,14 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed # This makes no difference most of the time because the other tokens won't attend to them, # but it does for the contact prediction task, which takes attentions as input, # so we have to mimic that here. + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2) - attentions *= attention_mask[:, None, None, :, :] + attentions = attentions * attention_mask[:, None, None, :, :] # remove cls token attentions if self.bos_token_id is not None: attentions = attentions[..., 1:, 1:] - # process attention_mask and input_ids to make removal of eos token happy attention_mask = attention_mask[..., 1:] if input_ids is not None: input_ids = input_ids[..., 1:] @@ -124,14 +163,172 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed seq_length = attention_mask.size(-1) eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2) - attentions *= eos_mask[:, None, None, :, :] + attentions = attentions * eos_mask[:, None, None, :, :] attentions = attentions[..., :-1, :-1] # features: batch x channels x input_ids x input_ids (symmetric) batch_size, layers, heads, seqlen, _ = attentions.size() attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen) - attentions = attentions.to(self.decoder.weight.device) + attentions = attentions.to(self.decoder.proj.weight.device) attentions = average_product_correct(symmetrize(attentions)) attentions = attentions.permute(0, 2, 3, 1).squeeze(3) return super().forward(attentions, labels, **kwargs) + + +@HeadRegistry.contact.register("logits") +class ContactLogitsHead(PredictionHead): + r""" + Head for tasks in contact-level. + + Performs symmetrization, and average product correct. + + Args: + config: The configuration object for the model. + head_config: The configuration object for the head. + If None, will use configuration from the `config`. + """ + + output_name: str = "last_hidden_state" + r"""The default output to use for the head.""" + + requires_attention: bool = False + + def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): + super().__init__(config, head_config) + num_layers = self.config.get("num_layers", 16) + num_channels = self.config.get("num_channels", self.config.hidden_size // 10) # type: ignore[operator] + block = self.config.get("block", "auto") + self.decoder = ResNet( + num_layers=num_layers, + hidden_size=self.config.hidden_size, # type: ignore[arg-type] + block=block, + num_channels=num_channels, + num_labels=self.num_labels, + ) + if head_config is not None and head_config.output_name is not None: + self.output_name = head_config.output_name + + def forward( # type: ignore[override] # pylint: disable=arguments-renamed + self, + outputs: ModelOutput | Mapping | Tuple[Tensor, ...], + attention_mask: Tensor | None = None, + input_ids: NestedTensor | Tensor | None = None, + labels: Tensor | None = None, + output_name: str | None = None, + **kwargs, + ) -> HeadOutput: + r""" + Forward pass of the ContactPredictionHead. + + Args: + outputs: The outputs of the model. + attention_mask: The attention mask for the inputs. + input_ids: The input ids for the inputs. + labels: The labels for the head. + output_name: The name of the output to use. + Defaults to `self.output_name`. + """ + if isinstance(outputs, (Mapping, ModelOutput)): + output = outputs[output_name or self.output_name] + elif isinstance(outputs, tuple): + output = outputs[0] + else: + raise ValueError(f"Unsupported type for outputs: {type(outputs)}") + + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * attention_mask.unsqueeze(-1) + output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids) + + # make symmetric contact map + contact_map = output.unsqueeze(1) * output.unsqueeze(2) + + return super().forward(contact_map, labels, **kwargs) + + +class ResNet(nn.Module): + def __init__( + self, + num_layers: int, + hidden_size: int, + block: Type[BasicBlock | Bottleneck] | str = "auto", + num_channels: int | None = None, + num_labels: int = 1, + norm_layer: Callable[..., nn.Module] | None = None, + zero_init_residual: bool = True, + ) -> None: + tv.check() + super().__init__() + + if block == "auto": + block = BasicBlock if num_layers < 50 else Bottleneck + elif block in ("basic", "BasicBlock"): + block = BasicBlock + elif block in ("bottleneck", "Bottleneck"): + block = Bottleneck + else: + raise ValueError(f"Unknown block type: {block}") + if num_channels is None: + num_channels = hidden_size // 10 + if norm_layer is None: + norm_layer = LayerNorm2D + + self.proj = nn.Conv2d(hidden_size, num_channels, kernel_size=3, stride=1, padding=1, bias=False) + self.norm = norm_layer(num_channels) + self.relu = nn.ReLU(inplace=True) + self.layers = nn.Sequential( + *[block(num_channels, num_channels, norm_layer=norm_layer) for _ in range(num_layers)] # type: ignore + ) + self.output = nn.Linear(num_channels, num_labels) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck) and m.bn3.weight is not None: + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock) and m.bn2.weight is not None: + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def forward(self, x: Tensor) -> Tensor: + x = self.proj(x.transpose(1, 3)) + x = self.norm(x) + x = self.relu(x) + x = self.layers(x) + x = self.output(x.transpose(1, 3)) + return x + + +class LayerNorm2D(nn.GroupNorm): + def __init__(self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True) -> None: + super().__init__(num_channels=num_features, eps=eps, affine=elementwise_affine, num_groups=1) + self.num_channels = num_features + + def __repr__(self): + return f"{self.__class__.__name__}(num_channels={self.num_channels}, eps={self.eps}, affine={self.affine})" + + +def symmetrize(x): + "Make layer symmetric in final two dimensions, used for contact prediction." + return x + x.transpose(-1, -2) + + +def average_product_correct(x): + "Perform average product correct, used for contact prediction." + a1 = x.sum(-1, keepdims=True) + a2 = x.sum(-2, keepdims=True) + a12 = x.sum((-1, -2), keepdims=True) + + avg = a1 * a2 + avg.div_(a12) # in-place to reduce memory + normalized = x - avg + return normalized diff --git a/multimolecule/module/heads/generic.py b/multimolecule/module/heads/generic.py index d97950a2..ae82e178 100644 --- a/multimolecule/module/heads/generic.py +++ b/multimolecule/module/heads/generic.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple from warnings import warn import torch @@ -24,7 +24,7 @@ from torch import Tensor, nn from transformers.activations import ACT2FN -from ..criterions import Criterion +from ..criterions import CriterionRegistry from .config import HeadConfig from .output import HeadOutput from .transform import HeadTransformRegistryHF @@ -44,24 +44,28 @@ class PredictionHead(nn.Module): """ num_labels: int + requires_attention: bool = False def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): super().__init__() if head_config is None: - head_config = config.head + head_config = config.head or HeadConfig(num_labels=config.num_labels) + elif head_config.num_labels is None: + head_config.num_labels = config.num_labels self.config = head_config if self.config.hidden_size is None: self.config.hidden_size = config.hidden_size - if self.config.num_labels is None: - self.config.num_labels = config.num_labels if self.config.problem_type is None: self.config.problem_type = config.problem_type - self.num_labels = self.config.num_labels + self.bos_token_id = config.bos_token_id + self.eos_token_id = config.eos_token_id + self.pad_token_id = config.pad_token_id + self.num_labels = self.config.num_labels # type: ignore[assignment] self.dropout = nn.Dropout(self.config.dropout) self.transform = HeadTransformRegistryHF.build(self.config) - self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias) + self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias) self.activation = ACT2FN[self.config.act] if self.config.act is not None else None - self.criterion = Criterion(self.config) + self.criterion = CriterionRegistry.build(self.config) def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput: r""" @@ -85,6 +89,42 @@ def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOu if isinstance(labels, NestedTensor): if isinstance(output, Tensor): output = labels.nested_like(output, strict=False) - return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage()))) + return HeadOutput(output, self.criterion(output.concat, labels.concat)) return HeadOutput(output, self.criterion(output, labels)) return HeadOutput(output) + + def _get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor: + if isinstance(input_ids, NestedTensor): + return input_ids.mask + if input_ids is None: + raise ValueError( + f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." + ) + if self.pad_token_id is None: + raise ValueError( + f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." + ) + return input_ids.ne(self.pad_token_id) + + def _remove_special_tokens( + self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None + ) -> Tuple[Tensor, Tensor, Tensor]: + # remove cls token embeddings + if self.bos_token_id is not None: + output = output[..., 1:, :] + attention_mask = attention_mask[..., 1:] + if input_ids is not None: + input_ids = input_ids[..., 1:] + # remove eos token embeddings + if self.eos_token_id is not None: + if input_ids is not None: + eos_mask = input_ids.ne(self.eos_token_id).to(output) + input_ids = input_ids[..., :-1] + else: + last_valid_indices = attention_mask.sum(dim=-1) + seq_length = attention_mask.size(-1) + eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) + output = output * eos_mask[:, :, None] + output = output[..., :-1, :] + attention_mask = attention_mask[..., 1:] + return output, attention_mask, input_ids diff --git a/multimolecule/module/heads/pretrain.py b/multimolecule/module/heads/pretrain.py index 994cb8ca..c6968c4b 100644 --- a/multimolecule/module/heads/pretrain.py +++ b/multimolecule/module/heads/pretrain.py @@ -53,8 +53,8 @@ def __init__( ): super().__init__() if head_config is None: - head_config = config.lm_head if hasattr(config, "lm_head") else config.head # type: ignore[assignment] - self.config: MaskedLMHeadConfig = head_config # type: ignore[assignment] + head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig() + self.config: MaskedLMHeadConfig = head_config if self.config.hidden_size is None: self.config.hidden_size = config.hidden_size self.num_labels = config.vocab_size @@ -97,6 +97,6 @@ def forward( if isinstance(labels, NestedTensor): if isinstance(output, Tensor): output = labels.nested_like(output, strict=False) - return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage()))) + return HeadOutput(output, F.cross_entropy(output.concat, labels.concat)) return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1))) return HeadOutput(output) diff --git a/multimolecule/module/heads/registry.py b/multimolecule/module/heads/registry.py index e5393e4e..6db3b680 100644 --- a/multimolecule/module/heads/registry.py +++ b/multimolecule/module/heads/registry.py @@ -14,6 +14,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from chanfig import Registry +from chanfig import ConfigRegistry as Registry_ +from torch import nn + + +class Registry(Registry_): # pylint: disable=too-few-public-methods + key = "type" + + def build(self, config, head_config) -> nn.Module: # type: ignore[override] + name = getattr(head_config, self.getattr("key")) + return self.init(self.lookup(name), config, head_config) # type: ignore[arg-type] + HeadRegistry = Registry(default_factory=Registry, fallback=True) diff --git a/multimolecule/module/heads/token.py b/multimolecule/module/heads/token.py index dbe6c721..5697d36c 100644 --- a/multimolecule/module/heads/token.py +++ b/multimolecule/module/heads/token.py @@ -19,7 +19,6 @@ from functools import partial from typing import TYPE_CHECKING, Mapping, Tuple -import torch from chanfig import ConfigRegistry from danling import NestedTensor from torch import Tensor @@ -54,9 +53,6 @@ class TokenPredictionHead(PredictionHead): def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): super().__init__(config, head_config) - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id if head_config is not None and head_config.output_name is not None: self.output_name = head_config.output_name @@ -80,45 +76,17 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed output_name: The name of the output to use. Defaults to `self.output_name`. """ - if attention_mask is None: - if isinstance(input_ids, NestedTensor): - input_ids, attention_mask = input_ids.tensor, input_ids.mask - else: - if input_ids is None: - raise ValueError( - f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." - ) - if self.pad_token_id is None: - raise ValueError( - f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." - ) - attention_mask = input_ids.ne(self.pad_token_id) - if isinstance(outputs, (Mapping, ModelOutput)): output = outputs[output_name or self.output_name] elif isinstance(outputs, tuple): output = outputs[0] else: raise ValueError(f"Unsupported type for outputs: {type(outputs)}") - output = output * attention_mask.unsqueeze(-1) - # remove cls token embeddings - if self.bos_token_id is not None: - output = output[..., 1:, :] - # process attention_mask and input_ids to make removal of eos token happy - attention_mask = attention_mask[..., 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token embeddings - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(output) - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) - output = output * eos_mask[:, :, None] - output = output[..., :-1, :] + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * attention_mask.unsqueeze(-1) + output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids) return super().forward(output, labels, **kwargs) @@ -141,9 +109,6 @@ class TokenKMerHead(PredictionHead): def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): super().__init__(config, head_config) self.nmers = config.nmers - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id if head_config is not None and head_config.output_name is not None: self.output_name = head_config.output_name # Do not pass bos_token_id and eos_token_id to unfold_kmer_embeddings @@ -170,46 +135,17 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed output_name: The name of the output to use. Defaults to `self.output_name`. """ - if attention_mask is None: - if isinstance(input_ids, NestedTensor): - input_ids, attention_mask = input_ids.tensor, input_ids.mask - else: - if input_ids is None: - raise ValueError( - f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." - ) - if self.pad_token_id is None: - raise ValueError( - f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." - ) - attention_mask = input_ids.ne(self.pad_token_id) - if isinstance(outputs, (Mapping, ModelOutput)): output = outputs[output_name or self.output_name] elif isinstance(outputs, tuple): output = outputs[0] else: raise ValueError(f"Unsupported type for outputs: {type(outputs)}") - output = output * attention_mask.unsqueeze(-1) - # remove cls token embeddings - if self.bos_token_id is not None: - output = output[..., 1:, :] - attention_mask = attention_mask[..., 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token embeddings - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(output) - input_ids = input_ids[..., :-1] - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) - output = output * eos_mask[:, :, None] - output = output[..., :-1, :] - attention_mask = attention_mask[..., 1:] + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * attention_mask.unsqueeze(-1) + output, attention_mask, _ = self._remove_special_tokens(output, attention_mask, input_ids) output = self.unfold_kmer_embeddings(output, attention_mask) return super().forward(output, labels, **kwargs) diff --git a/multimolecule/module/heads/utils.py b/multimolecule/module/heads/utils.py index c5937c6d..cc1f3654 100644 --- a/multimolecule/module/heads/utils.py +++ b/multimolecule/module/heads/utils.py @@ -119,32 +119,3 @@ def unfold_kmer_embeddings( embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]]) output[index, : seq_len + nmers - 1] = embedding return output - - -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(x, cos, sin): - cos = cos[:, :, : x.shape[-2], :] - sin = sin[:, :, : x.shape[-2], :] - - return (x * cos) + (rotate_half(x) * sin) - - -def symmetrize(x): - "Make layer symmetric in final two dimensions, used for contact prediction." - return x + x.transpose(-1, -2) - - -def average_product_correct(x): - "Perform average product correct, used for contact prediction." - a1 = x.sum(-1, keepdims=True) - a2 = x.sum(-2, keepdims=True) - a12 = x.sum((-1, -2), keepdims=True) - - avg = a1 * a2 - avg.div_(a12) # in-place to reduce memory - normalized = x - avg - return normalized diff --git a/multimolecule/module/model.py b/multimolecule/module/model.py new file mode 100644 index 00000000..256783be --- /dev/null +++ b/multimolecule/module/model.py @@ -0,0 +1,89 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import FlatDict +from danling import NestedTensor +from torch import Tensor, nn + +from .backbones import BackboneRegistry +from .heads import HeadRegistry +from .necks import NeckRegistry +from .registry import ModelRegistry + + +@ModelRegistry.register(default=True) +class MultiMoleculeModel(nn.Module): + def __init__( + self, + backbone: dict, + heads: dict, + neck: dict | None = None, + max_length: int = 1024, + truncation: bool = False, + ): + super().__init__() + + # Backbone + self.backbone = BackboneRegistry.build(**backbone) + backbone = self.backbone.config + out_channels = self.backbone.out_channels + + # Neck + if neck: + num_discrete = self.backbone.num_discrete + num_continuous = self.backbone.num_continuous + embed_dim = self.backbone.sequence.config.hidden_size + attention_heads = self.backbone.sequence.config.num_attention_heads + neck.update( + { + "num_discrete": num_discrete, + "num_continuous": num_continuous, + "embed_dim": embed_dim, + "attention_heads": attention_heads, + "max_length": max_length, + "truncation": truncation, + } + ) + self.neck = NeckRegistry.build(**neck) + out_channels = self.neck.out_channels + else: + self.neck = None + + # Heads + for head in heads.values(): + if "hidden_size" not in head or head["hidden_size"] is None: + head["hidden_size"] = out_channels + self.heads = nn.ModuleDict({name: HeadRegistry.build(backbone, head) for name, head in heads.items()}) + if any(getattr(h, "requires_attention", False) for h in self.heads.values()): + self.backbone.sequence.config.output_attentions = True + + def forward( + self, + sequence: NestedTensor | Tensor, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + dataset: str | None = None, + **labels: NestedTensor | Tensor, + ) -> FlatDict: + ret = FlatDict() + output, _ = self.backbone(sequence, discrete, continuous) + if self.neck is not None: + output = self.neck(**output) + for task, label in labels.items(): + ret[task] = self.heads[task](output, input_ids=sequence, labels=label) + return ret diff --git a/multimolecule/module/necks/__init__.py b/multimolecule/module/necks/__init__.py new file mode 100644 index 00000000..e8f1f7e2 --- /dev/null +++ b/multimolecule/module/necks/__init__.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .bert import BERTNeck +from .cat import CatNeck +from .registry import NeckRegistry + +__all__ = ["NeckRegistry", "CatNeck", "BERTNeck"] diff --git a/multimolecule/module/necks/bert.py b/multimolecule/module/necks/bert.py new file mode 100644 index 00000000..1360f0dd --- /dev/null +++ b/multimolecule/module/necks/bert.py @@ -0,0 +1,102 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import torch +from chanfig import FlatDict +from danling.modules import TransformerEncoder, TransformerEncoderLayer +from torch import Tensor, nn + +from .registry import NeckRegistry + +MAX_LENGTH = 1024 + + +@NeckRegistry.register("bert") +class BERTNeck(nn.Module): + def __init__( # pylint: disable=keyword-arg-before-vararg + self, + num_discrete: int, + num_continuous: int, + embed_dim: int, + attention_heads: int, + num_layers: int = 6, + max_length: int | None = None, + truncation: bool = False, + dropout: float = 0.1, + *args, + **kwargs, + ): + super().__init__() + self.cls_token_dis = nn.Parameter(torch.zeros(embed_dim)) + self.cls_token_con = nn.Parameter(torch.zeros(embed_dim)) + if max_length is None: + if truncation: + max_length = MAX_LENGTH + 1 + num_discrete + 1 + num_continuous + else: + max_length = MAX_LENGTH * 4 + 1 + num_discrete + 1 + num_continuous + self.max_length = max_length + self.pos_embed = nn.Parameter(torch.zeros(1, self.max_length, embed_dim)) + bert_layer = TransformerEncoderLayer( + embed_dim, attention_heads, *args, dropout=dropout, attn_dropout=dropout, ffn_dropout=dropout, **kwargs + ) + self.bert = TransformerEncoder(bert_layer, num_layers) + self.out_channels = embed_dim + nn.init.normal_(self.pos_embed, std=0.02) + nn.init.trunc_normal_(self.cls_token_dis, std=0.2) + nn.init.trunc_normal_(self.cls_token_con, std=0.2) + + def forward( + self, + cls_token: Tensor | None = None, + all_tokens: Tensor | None = None, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + ) -> FlatDict: + ret = FlatDict() + if cls_token is not None: + ret["cls_token"] = self._forward(cls_token, discrete, continuous) + if all_tokens is not None: + ret["all_tokens"] = self._forward(all_tokens, discrete, continuous) + return ret + + def _forward( + self, + sequence: Tensor, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + ) -> Tensor: + if sequence is None: + raise ValueError("sequence should not be None.") + if sequence.dim() == 2: + sequence = sequence[:, None] + batch_size, seq_len, _ = sequence.shape + output = sequence + if discrete is not None: + cls_token_dis = self.cls_token_dis.expand(batch_size, 1, -1) + output = torch.cat((output, cls_token_dis, discrete), dim=1) + if continuous is not None: + cls_token_con = self.cls_token_con.expand(batch_size, -1)[:, None] + output = torch.cat((output, cls_token_con, continuous), dim=1) + all_len = output.shape[1] + if all_len > self.pos_embed.shape[1]: + raise ValueError("sequence length is out of range.") + output = output + self.pos_embed[:, 0:all_len, :] + output = self.bert(output)[0][:, 0:seq_len, :] + if seq_len == 1: + output = output.squeeze(1) + return output diff --git a/multimolecule/module/necks/cat.py b/multimolecule/module/necks/cat.py new file mode 100644 index 00000000..d5165a92 --- /dev/null +++ b/multimolecule/module/necks/cat.py @@ -0,0 +1,43 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import torch +from chanfig import FlatDict +from torch import Tensor + +from .registry import NeckRegistry + + +@NeckRegistry.register("cat") +class CatNeck: # pylint: disable=too-few-public-methods + def __init__(self, embed_dim: int): + self.out_channels = embed_dim * 2 + + def __call__( + self, + cls_token: Tensor | None = None, + all_tokens: Tensor | None = None, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + ) -> FlatDict: + ret = FlatDict() + if cls_token is not None: + ret.cls_token = torch.cat(tuple(i for i in (cls_token, discrete, continuous) if i is not None), -1) + if all_tokens is not None: + ret.all_tokens = torch.cat(tuple(i for i in (all_tokens, discrete, continuous) if i is not None), -1) + return ret diff --git a/multimolecule/module/necks/registry.py b/multimolecule/module/necks/registry.py new file mode 100644 index 00000000..c024227c --- /dev/null +++ b/multimolecule/module/necks/registry.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import Registry + +NeckRegistry = Registry() diff --git a/multimolecule/module/registry.py b/multimolecule/module/registry.py new file mode 100644 index 00000000..b0332463 --- /dev/null +++ b/multimolecule/module/registry.py @@ -0,0 +1,35 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import Registry as Registry_ +from torch import nn + +from .backbones import BackboneRegistry +from .backbones.sequences import SequenceRegistry +from .heads import HeadRegistry +from .necks import NeckRegistry + + +class Registry(Registry_): # pylint: disable=too-few-public-methods + def build(self, *args, **kwargs) -> nn.Module: + return super().build(*args, **kwargs) + + +ModelRegistry = Registry() + +__all__ = ["ModelRegistry", "BackboneRegistry", "SequenceRegistry", "NeckRegistry", "HeadRegistry"] diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md new file mode 100644 index 00000000..bb1000ad --- /dev/null +++ b/multimolecule/runners/README.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# runners + +`runners` provide an easy-to-use interface for running experiments. diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py new file mode 100644 index 00000000..70fa4076 --- /dev/null +++ b/multimolecule/runners/__init__.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .config import MultiMoleculeConfig +from .runner import MultiMoleculeRunner + +__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"] diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py new file mode 100644 index 00000000..ad5b081d --- /dev/null +++ b/multimolecule/runners/base_runner.py @@ -0,0 +1,331 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import math +import os +from functools import cached_property, partial +from typing import Any, Tuple +from warnings import warn + +import danling as dl +import torch +from art import text2art +from chanfig import FlatDict, NestedDict +from danling import MultiTaskMetrics +from datasets import disable_progress_bars, get_dataset_split_names +from torch import nn, optim +from torch.nn import functional as F +from torch.utils import data +from tqdm import tqdm +from transformers import AutoTokenizer + +from multimolecule import defaults +from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler +from multimolecule.module import HeadConfig, ModelRegistry + +from .config import MultiMoleculeConfig +from .metrics import MetricRegistry + +disable_progress_bars() + + +class BaseRunner(dl.BaseRunner): + + all_datasets: NestedDict + + def __init__(self, config: MultiMoleculeConfig): + if config.art: + print(text2art("MultiMolecule", "rand-large")) + super().__init__(config) + self.name = config.name + self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained) + self.build_datasets() + self.build_dataloaders() + self.model = ModelRegistry.build(**self.network) + if self.config.get("checkpoint"): + ckpt = dl.load(self.config.checkpoint) + model = ckpt.get("model", ckpt) + parameters = self.model.load_state_dict(model, strict=False) + if parameters.missing_keys: + raise ValueError(f"Missing keys in model: {parameters.missing_keys}") + if parameters.unexpected_keys: + warn(f"Unexpected keys in model: {parameters.unexpected_keys}") + self.model = self.model.to(self.device) + if self.distributed: + self.model = nn.parallel.DistributedDataParallel( + self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True + ) + if self.config.training: + if self.config.optim and not (self.config.platform == "deepspeed" and self.config.deepspeed.optimizer): + self.optimizer = getattr(optim, self.config.optim.pop("name"))( + params=self.model.parameters(), **self.config.optim + ) + if self.config.sched and not (self.config.platform == "deepspeed" and self.config.deepspeed.scheduler): + self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched) + self.metrics = self.build_metrics() + + def __post_init__(self): + super().__post_init__() + self.yaml(os.path.join(self.dir, "trainer.yaml")) + print(self) + print(self.get_dataset_lengths()) + + def train_step(self, data) -> Tuple[Any, torch.Tensor]: + with self.autocast(), self.accumulate(): + pred = self.model(**data) + loss = self.loss_fn(pred, data) + self.advance(loss) + self.metric_fn(pred, data) + return pred, loss + + def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]: + pred = self.model(**data) + loss = self.loss_fn(pred, data) + self.metric_fn(pred, data) + return pred, loss + + @torch.inference_mode() + def infer(self, split: str = "inf") -> NestedDict | FlatDict | list: + r""" + Perform inference on `split`. + + Args: + split (str): split to run inference + + Return: + Inference outputs. + + - If the model has single output: + + - If labels are available, a [`FlatDict`][chanfig.FlatDict] with keys `predict` and `label` is + returned. + - If labels are not available, a list of predictions is returned. + + - If the model has multiple outputs: + - If labels are available, a [`NestedDict`][chanfig.NestedDict] with keys as task names and values + as dictionaries with keys `predict` and `label` is returned. + - If labels are not available, a [`FlatDict`][chanfig.FlatDict] with keys as task names and values + as lists of predictions is returned. + """ + + self.mode = "inf" # type: ignore + loader = self.dataloaders[split] + preds = FlatDict() + labels = FlatDict() + for _, data in tqdm(enumerate(loader), total=len(loader)): # noqa: F402 + pred = self.model(**data) + for task, p in pred.items(): + preds[task].extend(p["logits"].squeeze(-1).tolist()) + if task in data: + labels[task].extend(data[task].squeeze(-1).tolist()) + + if self.distributed: + torch.cuda.synchronize() + for task in preds.keys(): + preds[task] = self.gather_for_metrics(preds[task]) + for task in labels.keys(): + labels[task] = self.gather_for_metrics(labels[task]) + if labels: + if len(preds) == 1: + return FlatDict(predict=next(iter(preds.values())), label=next(iter(labels.values()))) + return NestedDict({task: {"predict": preds[task], "label": labels[task]} for task in preds}) + if len(preds) == 1: + return next(iter(preds.values())) + return preds + + def loss_fn(self, pred, data): + if self.balance == "rlw": + loss = torch.stack([p["loss"] for p in pred.values()]) + weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1) + return loss.T @ weight + if self.balance == "gls": + return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred)) + if self.balance != "ew": + warn(f"Unknown balance method {self.balance}, using equal weighting.") + return sum(p["loss"] for p in pred.values()) / len(pred) + + def metric_fn(self, pred, data): + metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics + metric.update({t: (p["logits"], data[t]) for t, p in pred.items()}) + + @cached_property + def tasks(self): + if not self.datasets: + raise ValueError("No datasets found") + if "train" in self.datasets: + return self.datasets.train.tasks + return next(iter(self.datasets.values())).tasks + + @cached_property + def dataset_tasks(self): + if not self.datasets: + raise ValueError("No datasets found") + dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values())) + tasks = self.tasks + dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks + for dataset in self.datasets.values(): + if isinstance(dataset, MultiTaskDataset): + for dataset_name, tasks_ in dataset.dataset_tasks.items(): + for task_name, task_ in tasks_.items(): + if task_name not in tasks: + raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined") + task = tasks[task_name] + if task != task_: + warn( + f"Task {task_name} of dataset {dataset_name} has different configurations " + "compared to training data, using training configuration.\n" + "This may lead to unexpected behavior.", + ) + if dataset_name not in dataset_tasks: + dataset_tasks[dataset_name] = NestedDict() + if task_name not in dataset_tasks[dataset_name]: + dataset_tasks[dataset_name][task_name] = task + else: + for task_name, task_ in dataset.tasks.items(): + if task_name not in tasks: + raise ValueError(f"Task {task_name} is not defined") + task = tasks[task_name] + if task != task_: + warn( + f"Task {task_name} has different configurations " + "compared to training data, using training configuration.\n" + "This may lead to unexpected behavior.", + ) + if task_name not in dataset_tasks: + dataset_tasks[task_name] = task + return dataset_tasks + + @cached_property + def network(self): + heads = { + name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level) + for name, task in self.tasks.items() + } + if "heads" not in self.config.network: + self.config.network.heads = NestedDict(heads) + else: + self.config.network.heads.merge(heads, overwrite=False) + return self.config.network + + def build_datasets(self): + if "data" in self.config: + self.datasets = self.all_datasets = self._build_dataset(self.config.data) + return + if "datas" in self.config: + self.all_datasets = NestedDict( + {name: self._build_dataset(config, name) for name, config in self.config.datas.items()} + ) + datasets = { + subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict} + for subkey in {k for v in self.all_datasets.values() for k in v} + } + self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()}) + return + raise ValueError("No data configuration found") + + def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict: + name = name or config.root + print(f"Building dataset {name}") + dataset = NestedDict() + train_splits = [key for key in config.keys() if key.startswith(defaults.TRAIN_SPLITS)] + validation_splits = [key for key in config.keys() if key.startswith(defaults.VALIDATION_SPLITS)] + test_splits = [key for key in config.keys() if key.startswith(defaults.TEST_SPLITS)] + inference_splits = [key for key in config.keys() if key.startswith(defaults.INFERENCE_SPLITS)] + all_splits = train_splits + validation_splits + test_splits + inference_splits + ignored_keys = all_splits + ["root"] + dataset_factory = partial( + Dataset, + tokenizer=self.tokenizer, + **{k: v for k, v in config.items() if k not in ignored_keys}, + ) + if os.path.isdir(config.root): + for split in train_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="train") + for split in validation_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="validation") + for split in test_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split="test") + for split in inference_splits: + dataset[split] = dataset_factory(os.path.join(config.root, config[split]), split=config[split]) + else: + splits = get_dataset_split_names(config.root) + existing_splits = {k for k in defaults.DATASET_SPLITS if config.get(k) is not None} + if not existing_splits: + if "train" in splits: + config.train = "train" + if "validation" in splits: + config.validation = "validation" + if "test" in splits: + config.test = "test" + for split in existing_splits: + dataset[split] = dataset_factory(config.root, split=split) + if not dataset: + raise ValueError(f"No datasets built. This is likely due to missing data paths in {config}.") + return dataset + + def build_dataloaders(self): + datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders} + default_kwargs = self.config.get("dataloader", NestedDict()) + dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs}) + for k, d in datasets.items(): + dataloader_kwargs.setdefault(k, NestedDict()) + dataloader_kwargs[k].merge(default_kwargs, overwrite=False) + batch_size = dataloader_kwargs[k].pop("batch_size") + shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True)) + drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True)) + if isinstance(d, MultiTaskDataset): + batch_sampler = ( + DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last) + if self.distributed + else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last) + ) + else: + sampler = ( + data.distributed.DistributedSampler(d, shuffle=shuffle) + if self.distributed + else data.RandomSampler(d) if shuffle else data.SequentialSampler(d) + ) + batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last) + self.dataloaders[k] = data.DataLoader( + d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k] + ) + + def build_metrics(self) -> MultiTaskMetrics: + return MultiTaskMetrics( + { + name: MetricRegistry.build(type=task.type, num_labels=task.num_labels) + for name, task in self.dataset_tasks.all_items() + } + ) + + def collate_fn(self, batch): + return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()} + + def get_dataset_lengths(self) -> str: + repr = "dataset lengths:\n" + longest_name = max(len(name) for name in self.all_datasets.keys()) + for name, dataset in self.all_datasets.items(): + if isinstance(dataset, NestedDict): + repr += f"{name}:" + if len(name) < longest_name: + repr += " " * (longest_name - len(name)) + repr += "\t\t" + for split, d in dataset.items(): + repr += f" {split}: {len(d)}\t" + else: + repr += f"{name}: {len(dataset)}\t" + repr += "\n" + return repr diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py new file mode 100644 index 00000000..ff941b3e --- /dev/null +++ b/multimolecule/runners/config.py @@ -0,0 +1,87 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import os +from pathlib import Path +from typing import List + +from chanfig import Config +from transformers import PretrainedConfig + + +class DataConfig(Config): + root: str = "." + train: str | None + validation: str | None + test: str | None + feature_cols: List | None = None + label_cols: List | None = None + truncation: bool = True + + +class OptimConfig(Config): + name: str = "AdamW" + lr: float = 1e-3 + weight_decay: float = 1e-2 + + +class MultiMoleculeConfig(Config): + + balance: str = "ew" + platform: str = "torch" + training: bool = True + + pretrained: str + use_pretrained: bool = True + transformers: PretrainedConfig + epoch_end: int = 20 + + data: DataConfig + + tensorboard: bool = True + save_interval: int = 10 + seed: int = 1016 + art: bool = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.datas = Config(default_factory=DataConfig) + self.dataloader.batch_size = 32 + self.optim = OptimConfig() + self.sched.final_lr = 0 + + def post(self): + if "data" in self: + if self.datas: + raise ValueError("Only one of `data` or `datas` can be specified, but not both") + del self.datas + self["network.backbone.sequence.name"] = self.pretrained + self["network.backbone.sequence.use_pretrained"] = self.use_pretrained + pretrained = self.pretrained + if os.path.exists(self.pretrained): + path = Path(pretrained) + if os.path.isfile(pretrained): + pretrained = str(path.relative_to(path.parents[1]).with_suffix("")) + else: + pretrained = path.stem + + self.name = f"{pretrained}" + if "optim" in self: + optim_name = self.optim.get("name", "no") + self.name += f"-{self.optim.lr}@{optim_name}" + self.name += f"-{self.seed}" diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py new file mode 100644 index 00000000..da584cbc --- /dev/null +++ b/multimolecule/runners/metrics.py @@ -0,0 +1,37 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from chanfig import Registry as Registry_ +from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics + + +class Registry(Registry_): + + def build(self, type, num_labels: int | None = None, **kwargs): + if type == "multilabel": + return self.init(self.lookup(type), num_labels=num_labels, **kwargs) + if type == "multiclass": + return self.init(self.lookup(type), num_classes=num_labels, **kwargs) + if type == "regression": + return self.init(self.lookup(type), num_outputs=num_labels, **kwargs) + return self.init(self.lookup(type), **kwargs) + + +MetricRegistry = Registry(key="type") +MetricRegistry.register(binary_metrics, "binary") +MetricRegistry.register(multiclass_metrics, "multiclass") +MetricRegistry.register(multilabel_metrics, "multilabel") +MetricRegistry.register(regression_metrics, "regression") diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py new file mode 100644 index 00000000..97cab4e7 --- /dev/null +++ b/multimolecule/runners/runner.py @@ -0,0 +1,42 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import danling as dl + +from .base_runner import BaseRunner + + +class MultiMoleculeRunner(type): + def __new__(cls, config): + if config.get("platform", "torch") == "torch": + return TorchRunner(config) + if config.platform == "deepspeed": + return DeepSpeedRunner(config) + if config.platform == "accelerate": + return AccelerateRunner(config) + raise ValueError(f"Unsupported platform: {config.platform}") + + +class TorchRunner(BaseRunner, dl.TorchRunner): + pass + + +class DeepSpeedRunner(BaseRunner, dl.DeepSpeedRunner): + pass + + +class AccelerateRunner(BaseRunner, dl.AccelerateRunner): + pass diff --git a/multimolecule/tasks/task.py b/multimolecule/tasks/task.py index e2473ab0..5d435f83 100644 --- a/multimolecule/tasks/task.py +++ b/multimolecule/tasks/task.py @@ -34,9 +34,8 @@ class TaskType(StrEnum): class TaskLevel(StrEnum): Sequence = auto() - Nucleotide = auto() + Token = auto() Contact = auto() - # Token = auto() @dataclass diff --git a/multimolecule/train.py b/multimolecule/train.py new file mode 100644 index 00000000..f146bc97 --- /dev/null +++ b/multimolecule/train.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .apis import train + +if __name__ == "__main__": + train() diff --git a/pyproject.toml b/pyproject.toml index 972b1d8a..ee4c525d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,11 @@ dynamic = [ ] dependencies = [ "accelerate", + "art", "chanfig>=0.0.105", - "danling>=0.3.11", + "danling>=0.4.0b1", "datasets", + 'StrEnum; python_version < "3.11"', "torch", "transformers", ] diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 5ddfb3c4..318e289e 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -126,7 +126,7 @@ def test_spliceai(self, preprocess: bool): feature_cols=feature_cols, label_cols=label_cols, ) - task = Task(type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1) + task = Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1) elem = dataset[0] assert isinstance(elem["sequence"], torch.LongTensor) assert isinstance(elem["splice_ai"], torch.LongTensor) @@ -175,20 +175,18 @@ def test_rna_task_recognition_json(self): assert dataset.tasks["sequence_regression"] == Task( type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=1 ) - assert dataset.tasks["nucleotide_binary"] == Task( - type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1 - ) + assert dataset.tasks["nucleotide_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1) assert dataset.tasks["nucleotide_multiclass"] == Task( - type=TaskType.MultiClass, level=TaskLevel.Nucleotide, num_labels=5 + type=TaskType.MultiClass, level=TaskLevel.Token, num_labels=5 ) assert dataset.tasks["nucleotide_multilabel"] == Task( - type=TaskType.MultiLabel, level=TaskLevel.Nucleotide, num_labels=5 + type=TaskType.MultiLabel, level=TaskLevel.Token, num_labels=5 ) assert dataset.tasks["nucleotide_multireg"] == Task( - type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=5 + type=TaskType.Regression, level=TaskLevel.Token, num_labels=5 ) assert dataset.tasks["nucleotide_regression"] == Task( - type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=1 + type=TaskType.Regression, level=TaskLevel.Token, num_labels=1 ) assert dataset.tasks["contact_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Contact, num_labels=1) assert dataset.tasks["contact_multiclass"] == Task(