diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt index 44c7e9f5..467e5c38 100644 --- a/.codespell-whitelist.txt +++ b/.codespell-whitelist.txt @@ -1,3 +1,4 @@ +datas ser marz manuel diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md new file mode 100644 index 00000000..054c6205 --- /dev/null +++ b/docs/docs/data/multitask.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiTask + +::: multimolecule.data.multitask diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md new file mode 100644 index 00000000..8e188199 --- /dev/null +++ b/docs/docs/runners/config.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiMoleculeConfig + +::: multimolecule.runners.MultiMoleculeConfig diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md new file mode 100644 index 00000000..75a0f528 --- /dev/null +++ b/docs/docs/runners/index.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# runners + +--8<-- "multimolecule/runners/README.md:8:" diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md new file mode 100644 index 00000000..2c93ce1b --- /dev/null +++ b/docs/docs/runners/runner.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# MultiMoleculeRunner + +::: multimolecule.runners.MultiMoleculeRunner diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 93d53a45..9324d9c3 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule nav: - index.md + - runners: + - runners/index.md + - MultiMoleculeRunner: runners/runner.md + - MultiMoleculeConfig: runners/config.md - data: - data/index.md - Dataset: data/dataset.md + - multitask: data/multitask.md - datasets: - datasets/index.md - DNA: diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py index 240e9fcc..c9ee3a87 100644 --- a/multimolecule/__init__.py +++ 
b/multimolecule/__init__.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from .apis import evaluate, infer, train from .data import Dataset from .models import ( AutoModelForContactPrediction, @@ -111,30 +112,33 @@ HeadConfig, HeadRegistry, HeadTransformRegistry, - HeadTransformRegistryHF, IdentityTransform, LinearTransform, MaskedLMHead, MaskedLMHeadConfig, NonLinearTransform, PositionEmbeddingRegistry, - PositionEmbeddingRegistryHF, PredictionHead, RotaryEmbedding, SequencePredictionHead, SinusoidalEmbedding, - TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead, ) +from .runners import MultiMoleculeConfig, MultiMoleculeRunner from .tasks import Task, TaskLevel, TaskType from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer from .utils import count_parameters __all__ = [ + "train", + "evaluate", + "infer", "modeling_auto", "modeling_outputs", "Dataset", + "MultiMoleculeConfig", + "MultiMoleculeRunner", "PreTrainedConfig", "HeadConfig", "BaseHeadConfig", @@ -233,21 +237,15 @@ "HeadRegistry", "PredictionHead", "SequencePredictionHead", - "TokenHeadRegistryHF", "TokenPredictionHead", "TokenKMerHead", - "NucleotideHeadRegistryHF", - "NucleotidePredictionHead", - "NucleotideKMerHead", "ContactPredictionHead", "MaskedLMHead", "HeadTransformRegistry", - "HeadTransformRegistryHF", "LinearTransform", "NonLinearTransform", "IdentityTransform", "PositionEmbeddingRegistry", - "PositionEmbeddingRegistryHF", "RotaryEmbedding", "SinusoidalEmbedding", "Criterion", diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py new file mode 100644 index 00000000..8e3e5b3c --- /dev/null +++ b/multimolecule/apis/__init__.py @@ -0,0 +1,19 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General 
Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .run import evaluate, infer, train + +__all__ = ["train", "evaluate", "infer"] diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py new file mode 100644 index 00000000..1fdb7666 --- /dev/null +++ b/multimolecule/apis/run.py @@ -0,0 +1,115 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
# mypy: disable-error-code="attr-defined"

import atexit
import os
import warnings
from typing import Type

import danling as dl
import torch

from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner

try:
    import nni
except ImportError:
    nni = None


def _prepare_config(config: MultiMoleculeConfig, training: bool) -> MultiMoleculeConfig:
    """Parse and interpolate ``config`` and apply the global torch backend switches.

    Shared by ``train``, ``evaluate`` and ``infer`` so the three entry points
    cannot drift apart.

    Args:
        config: Configuration object to parse (command line / default config file).
        training: Value stored in ``config.training``.

    Returns:
        The parsed configuration.
    """
    config = config.parse(default_config="config", no_default_config_action="warn")
    # NOTE(review): `unsafe_eval=True` presumably evaluates expressions found in the
    # configuration -- only run with configuration files you trust. TODO confirm.
    config.interpolate(unsafe_eval=True)
    config.training = training
    if config.allow_tf32:
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
    if config.reduced_precision_reduction:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
    return config


def train(
    config: MultiMoleculeConfig = None,  # type: ignore
    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
    """Train a model and return the result of ``runner.train()``.

    Args:
        config: Configuration to use; a fresh ``MultiMoleculeConfig()`` when ``None``.
        runner_cls: Runner class instantiated with the parsed configuration.

    Raises:
        ValueError: If ``config.nni`` is truthy but ``nni`` is not installed.
    """
    if config is None:
        config = MultiMoleculeConfig()
    config = _prepare_config(config, training=True)
    if config.get("nni", False):
        if nni is None:
            raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
        config.merge(nni.get_next_parameter())
    with dl.debug(config.get("debug", False)):
        runner = runner_cls(config)
        # atexit handlers run in reverse registration order, so on interpreter exit
        # the checkpoint is saved first, then the result is saved and printed.
        atexit.register(runner.print_result)
        atexit.register(runner.save_result)
        atexit.register(runner.save_checkpoint)
        result = runner.train()
        return result


def evaluate(
    config: MultiMoleculeConfig = None,  # type: ignore
    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
    """Evaluate a checkpoint on the ``evaluation`` split of every dataset.

    Args:
        config: Configuration to use; ``MultiMoleculeConfig.empty()`` when ``None``.
        runner_cls: Runner class instantiated with the parsed configuration.

    Returns:
        The result of ``runner.evaluate_epoch("evaluation")`` (also printed).

    Raises:
        RuntimeError: If ``checkpoint`` is missing / not a string, or any entry of
            ``config.datas`` lacks a string ``evaluation`` field.
    """
    if config is None:
        config = MultiMoleculeConfig.empty()
    config = _prepare_config(config, training=False)
    if "checkpoint" not in config or not isinstance(config.checkpoint, str):
        raise RuntimeError("Please specify `checkpoint` to run evaluate")
    for name, data in config.datas.items():
        # Validate the same key that is tested for membership (`evaluation`);
        # the previous code read `data.evaluate`, which is never the split name.
        if "evaluation" not in data or not isinstance(data.evaluation, str):
            raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}")
    runner = runner_cls(config)
    result = runner.evaluate_epoch("evaluation")
    print(result)
    return result


def infer(
    config: MultiMoleculeConfig = None,  # type: ignore
    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
):
    """Run inference with a checkpoint and save the result to ``config.result_path``.

    Args:
        config: Configuration to use; ``MultiMoleculeConfig.empty()`` when ``None``.
        runner_cls: Runner class instantiated with the parsed configuration.

    Returns:
        The result of ``runner.infer()``.

    Raises:
        RuntimeError: If ``checkpoint`` is missing / not a string, or any entry of
            ``config.datas`` lacks a string ``inference`` field.
    """
    if config is None:
        config = MultiMoleculeConfig.empty()
    config = _prepare_config(config, training=False)
    if "checkpoint" not in config or not isinstance(config.checkpoint, str):
        raise RuntimeError("Please specify `checkpoint` to run infer.")
    for name, data in config.datas.items():
        if "inference" not in data or not isinstance(data.inference, str):
            raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}")
    if "result_path" not in config or not isinstance(config.result_path, str):
        # Fall back to ./result.json rather than failing after a long inference run.
        config.result_path = os.path.join(os.getcwd(), "result.json")
        warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
    runner = runner_cls(config)
    result = runner.infer()
    runner.save(result, config.result_path)
    return result
# Copyright (C) 2024-Present MultiMolecule

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.

# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import os
import shutil
from statistics import mean
from typing import List

import chanfig
import pandas as pd
from chanfig import NestedDict
from tqdm import tqdm

# Bookkeeping entries that are dropped from every collected result.
_IGNORED_KEYS = (
    "validation.time",
    "test.time",
    "validation.loss",
    "test.loss",
    "validation.lr",
    "test.lr",
)


class Result(NestedDict):
    """Summary of the best epoch of a single experiment run."""

    pretrained: str
    id: str
    seed: int
    epoch: int
    validation: NestedDict
    test: NestedDict


def _format_scores(scores) -> NestedDict:
    """Format every score with 8 decimals, averaging scores reported as lists."""
    return NestedDict({k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in scores.items()})


def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]:
    """Collect the best result of every run found under ``experiment_root``.

    A directory counts as a run when it contains ``run.log``. Runs without a
    ``best.json``, or whose ``best.json`` has no ``index``, are skipped.

    Args:
        experiment_root: Directory that is walked recursively.
        remove_empty: When ``True``, **delete** skipped run directories and any
            directories left empty afterwards.

    Returns:
        Results sorted by ``(pretrained, seed, id)``.
    """
    results = []
    for root, dirs, files in tqdm(os.walk(experiment_root)):
        if "run.log" not in files:
            continue
        if "best.json" not in files:
            if remove_empty:
                shutil.rmtree(root)
                # The subtree is gone; prune it so os.walk does not descend into it.
                dirs.clear()
            continue
        best = NestedDict.from_json(os.path.join(root, "best.json"))
        if "index" not in best:
            if remove_empty:
                shutil.rmtree(root)
                dirs.clear()
            continue
        config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
        # Keep the values read from the run's config; they drive the final sort.
        pretrained = config.pretrained.split("/")[-1]
        seed = config.seed
        result = Result(id=best.id, pretrained=pretrained, seed=seed)
        result.validation = _format_scores(best.validation)
        result.test = _format_scores(best.test)
        result.epoch = best.index
        for key in _IGNORED_KEYS:
            result.pop(key, None)
        results.append(result)
    if remove_empty:
        # Walk bottom-up and re-list each directory, so a parent that only contained
        # empty directories is removed in the same pass, whatever the nesting depth.
        for root, _, _ in os.walk(experiment_root, topdown=False):
            if not os.listdir(root):
                os.rmdir(root)
    results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
    return results


def write_result_stat(results: List[Result], path: str):
    """Write ``results`` to a CSV file at ``path``.

    Nested keys are flattened via ``all_items()``. An empty ``comment`` column is
    inserted as the second-to-last column and missing values become empty strings.
    """
    rows = [dict(result.all_items()) for result in results]
    df = pd.DataFrame(rows)
    df.insert(len(df.keys()) - 1, "comment", "")
    # fillna returns a new frame; the result must be kept.
    df = df.fillna("")
    df.to_csv(path, index=False)


class Config(chanfig.Config):
    """Command-line options of the statistics script."""

    experiment_root: str = "experiments"
    out_path: str = "result.csv"


if __name__ == "__main__":
    config = Config().parse()
    result_stat = get_result_stat(config.experiment_root)
    if not result_stat:
        raise ValueError("No results found")
    write_result_stat(result_stat, config.out_path)
from .dataset import Dataset +from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler from .utils import no_collate -__all__ = ["Dataset", "no_collate"] +__all__ = [ + "Dataset", + "PandasDataset", + "MultiTaskDataset", + "MultiTaskSampler", + "DistributedMultiTaskSampler", + "no_collate", +] diff --git a/multimolecule/data/dataset.py b/multimolecule/data/dataset.py index 54565349..2b6fae49 100644 --- a/multimolecule/data/dataset.py +++ b/multimolecule/data/dataset.py @@ -149,8 +149,9 @@ def __init__( fingerprint: str | None = None, ignored_cols: List[str] | None = None, ): + self._tasks = NestedDict() if tasks is not None: - self._tasks = NestedDict(tasks) + self.tasks = tasks if discrete_map is not None: self._discrete_map = discrete_map arrow_table = self.build_table( @@ -187,13 +188,13 @@ def build_table( data = dl.load_pandas(data) if isinstance(data, DataFrame): data = data.loc[:, ~data.columns.str.contains("^Unnamed")] - data = pa.Table.from_pandas(data) + data = pa.Table.from_pandas(data, preserve_index=False) elif isinstance(data, dict): data = pa.Table.from_pydict(data) elif isinstance(data, list): data = pa.Table.from_pylist(data) elif isinstance(data, DataFrame): - data = pa.Table.from_pandas(data) + data = pa.Table.from_pandas(data, preserve_index=False) if feature_cols is not None and label_cols is not None: data = data.select(feature_cols + label_cols) data = self.process_nan(data, nan_process=nan_process, fill_value=fill_value) @@ -250,6 +251,7 @@ def post( self.column_names_map = column_names_map if self.column_names_map: self.rename_columns(self.column_names_map) + self.infer_tasks() if self.preprocess: self.update(self.map(self.tokenization)) @@ -258,7 +260,7 @@ def post( if self.discrete_map: self.update(self.map(self.map_discrete)) fn_kwargs = { - "columns": [name for name, task in self.tasks.items() if task.level in ["nucleotide", "contact"]], + "columns": [name for name, task in self.tasks.items() if task.level 
in ["token", "contact"]], "max_seq_length": self.max_seq_length - self.seq_length_offset, } if self.truncation and 0 < self.max_seq_length < 2**32: @@ -297,20 +299,23 @@ def collate(self, col: str, data: Any) -> Tensor | NestedTensor | None: except ValueError: return NestedTensor(data) - def infer_tasks(self, tasks: Mapping | None = None, sequence_col: str | None = None) -> NestedDict: - self._tasks = tasks or NestedDict() + def infer_tasks(self, sequence_col: str | None = None) -> NestedDict: for col in self.label_cols: - if col not in self.tasks: - if col in self.secondary_structure_cols: - task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1) - self._tasks[col] = task # type: ignore[index] - warn( - f"Secondary structure columns are assumed to be {task}." - " Please explicitly specify the task if this is not the case." - ) - else: - self._tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index] - return self._tasks + if col in self.tasks: + continue + if col in self.secondary_structure_cols: + task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1) + self.tasks[col] = task # type: ignore[index] + warn( + f"Secondary structure columns are assumed to be {task}. " + "Please explicitly specify the task if this is not the case." 
+ ) + else: + try: + self.tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index] + except ValueError: + raise ValueError(f"Unable to infer task for column {col}.") + return self.tasks def infer_task(self, label_col: str, sequence_col: str | None = None) -> Task: if sequence_col is None: @@ -404,7 +409,7 @@ def rename_columns(self, column_mapping: Mapping[str, str], new_fingerprint: str self._label_cols = [column_mapping.get(i, i) for i in self.label_cols] self._sequence_cols = [column_mapping.get(i, i) for i in self.sequence_cols] self._secondary_structure_cols = [column_mapping.get(i, i) for i in self.secondary_structure_cols] - self._tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()} + self.tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()} return self def rename_column( @@ -418,7 +423,7 @@ def rename_column( self._secondary_structure_cols = [ new_column_name if i == original_column_name else i for i in self.secondary_structure_cols ] - self._tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()} + self.tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()} return self def process_nan(self, data: Table, nan_process: str | None, fill_value: str | int | float = 0) -> Table: @@ -470,9 +475,18 @@ def secondary_structure_cols(self) -> List: @property def tasks(self) -> NestedDict: if not hasattr(self, "_tasks"): + self._tasks = NestedDict() return self.infer_tasks() return self._tasks + @tasks.setter + def tasks(self, tasks: Mapping): + self._tasks = NestedDict() + for name, task in tasks.items(): + if not isinstance(task, Task): + task = Task(**task) + self._tasks[name] = task + @property def discrete_map(self) -> Mapping: if not hasattr(self, "_discrete_map"): diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py new file mode 100644 index 00000000..7c20e829 --- /dev/null +++ b/multimolecule/data/multitask.py 
@@ -0,0 +1,246 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from bisect import bisect_right +from collections.abc import Iterator, Mapping, Sequence +from copy import deepcopy +from random import choices + +import torch +from chanfig import NestedDict +from torch import distributed as dist +from torch.utils import data + +from .dataset import Dataset + + +class MultiTaskDataset(data.ConcatDataset): + + datasets: Mapping + dataset_keys: Sequence[str] + dataset_values: Sequence[Dataset] + + def __init__(self, datasets: Mapping) -> None: + for key, dataset in datasets.items(): + if not isinstance(dataset, Dataset): + raise TypeError(f"Dataset {key} should be an instance of Dataset") + self.datasets = datasets + if not len(self.datasets) > 0: + raise ValueError("MultiTaskDataset should contain at least one dataset") + self.dataset_keys, self.dataset_values = zip(*self.datasets.items()) + self.cumulative_sizes = self.cumsum(self.dataset_values) + + def __getitems__(self, key: Sequence[int]) -> Mapping: + dataset_idx = bisect_right(self.cumulative_sizes, key[0]) + if dataset_idx == 0: + sample_idx = key + else: + sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key] + batch = self.dataset_values[dataset_idx][sample_idx] + batch["dataset"] = self.dataset_keys[dataset_idx] + return batch + + 
@property + def tasks(self) -> NestedDict: + tasks = NestedDict() + for dataset in self.dataset_values: + for n, t in dataset.tasks.items(): + if n not in tasks: + tasks[n] = t + elif tasks[n] != t: + raise ValueError(f"Task {n} has different configurations across datasets") + return tasks + + @property + def dataset_tasks(self) -> NestedDict: + return NestedDict({k: v.tasks for k, v in self.datasets.items()}) + + def __repr__(self) -> str: + return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})" + + +class MultiTaskSampler(data.BatchSampler): + r""" + Ensure all items in a batch comes from the same dataset. + + Arguments: + sampler (Sampler): Base sampler. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size`` + """ + + datasets: Sequence[Dataset] + + def __init__( # pylint: disable=super-init-not-called + self, + dataset: MultiTaskDataset, + batch_size: int, + shuffle: bool = True, + drop_last: bool = False, + sampler_cls: type[data.Sampler] | None = None, + weights: list[int] | None = None, + ) -> None: + self.datasets = dataset.dataset_values + self.batch_size = batch_size + self.drop_last = drop_last + self.shuffle = shuffle + if sampler_cls is None: + sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler + self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore + self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore + self.cumulative_sizes = dataset.cumulative_sizes + self.num_datasets = len(self.datasets) + self.weights = weights if weights is not None else self.dataset_sizes + + def __iter__(self): + sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)] + sampler_weights = deepcopy(self.weights) + sampler_idx = 0 + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + if self.drop_last: + while sampler_iters: + if self.shuffle: + sampler_idx 
= choices(range(len(sampler_iters)), weights=sampler_weights)[0] + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + try: + batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)] + yield batch + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + else: + while sampler_iters: + if self.shuffle: + sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0] + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + batch = [0] * self.batch_size + idx_in_batch = 0 + try: + for _ in range(self.batch_size): + batch[idx_in_batch] = next(sampler_iter) + cumulative_size + idx_in_batch += 1 + yield batch + idx_in_batch = 0 # noqa: SIM113 + batch = [0] * self.batch_size + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self): + batch_size = self.batch_size + if self.drop_last: + return sum(len(d) // batch_size for d in self.datasets) + return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets) + + +class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods + r""" + Distributed version of MultiTaskSampler, which ensures that all GPUs sample data from the + same sub-dataset in each step without requiring additional communication. + The dataset selection is based on a random seed mechanism that is synchronized across epochs. 
+ + See Also: + [MultiTaskSampler][MultiTaskSampler] + """ + + def __init__( + self, + dataset: MultiTaskDataset, + batch_size: int, + shuffle: bool = True, + drop_last: bool = False, + sampler_cls: type[data.Sampler] = data.RandomSampler, + weights: list[int] | None = None, + seed: int = 0, + ) -> None: + super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights) + self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets] + self.seed = seed + self.epoch = 0 + + def set_epoch(self, epoch: int): + """ + Sets the epoch for deterministic shuffling. + """ + self.epoch = epoch + for sampler in self.samplers: + sampler.set_epoch(epoch) + + def _get_sampler_idx(self, high: int) -> int: + """ + Determines which sampler (i.e., sub-dataset) to use based on the seed and epoch. + """ + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + sampler_idx = torch.randint(low=0, high=high, size=(1,), generator=g).item() + return sampler_idx + + def __iter__(self) -> Iterator: + sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)] + sampler_weights = deepcopy(self.weights) + + if self.drop_last: + while sampler_iters: + # Sample the same sub-dataset across all GPUs using the seeded index + sampler_idx = self._get_sampler_idx(len(sampler_iters)) + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + try: + batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)] + yield batch + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + else: + while sampler_iters: + # Sample the same sub-dataset across all GPUs using the seeded index + sampler_idx = self._get_sampler_idx(len(sampler_iters)) + sampler_id, sampler_iter = sampler_iters[sampler_idx] + cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0 + batch = [0] * 
self.batch_size + idx_in_batch = 0 + try: + for _ in range(self.batch_size): + batch[idx_in_batch] = next(sampler_iter) + cumulative_size + idx_in_batch += 1 + yield batch + idx_in_batch = 0 # noqa: SIM113 + batch = [0] * self.batch_size + except StopIteration: + sampler_iters.pop(sampler_idx) + sampler_weights.pop(sampler_idx) + if idx_in_batch > 0: + yield batch[:idx_in_batch] + + def __len__(self) -> int: + batch_size = self.batch_size * self.world_size + if self.drop_last: + return sum(len(d) // batch_size for d in self.datasets) + return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets) + + @property + def world_size(self) -> int: + r"""Return the number of processes in the current process group.""" + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return 1 diff --git a/multimolecule/data/utils.py b/multimolecule/data/utils.py index 85bc4423..1afddd1b 100644 --- a/multimolecule/data/utils.py +++ b/multimolecule/data/utils.py @@ -60,7 +60,7 @@ def infer_task( level = TaskLevel.Contact num_labels = len(flattened) // num_contacts elif len(flattened) % num_tokens == 0: - level = TaskLevel.Nucleotide + level = TaskLevel.Token num_labels = len(flattened) // num_tokens elif len(flattened) % num_elem == 0: level = TaskLevel.Sequence @@ -86,7 +86,7 @@ def infer_task( task_type = TaskType.MultiClass if num_labels > 2 else TaskType.Binary num_labels = 1 if task_type == TaskType.Binary else num_labels if num_tokens_flattened == num_tokens: - return Task(task_type, level=TaskLevel.Nucleotide, num_labels=num_labels) + return Task(task_type, level=TaskLevel.Token, num_labels=num_labels) if num_contacts_flattened == num_contacts: return Task(task_type, level=TaskLevel.Contact, num_labels=num_labels) return Task(task_type, level=TaskLevel.Sequence, num_labels=num_labels) @@ -122,7 +122,7 @@ def map_value(value: Any, mapping: dict[str, int] | None) -> Any: def truncate_value(value: Any, max_seq_length: int, level: int | None 
= None) -> Any: - if level == TaskLevel.Nucleotide: + if level == TaskLevel.Token: return value[:max_seq_length] if level == TaskLevel.Contact: return [i[:max_seq_length] for i in value[:max_seq_length]] diff --git a/multimolecule/defaults.py b/multimolecule/defaults.py index c299ea1a..0a718a60 100644 --- a/multimolecule/defaults.py +++ b/multimolecule/defaults.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +DATASET_SPLITS = ["train", "validation", "test", "evaluation", "inference"] ID_COL_NAMES = ["id", "idx", "index"] SEQUENCE_COL_NAMES = ["input_ids", "sequence", "seq"] SECONDARY_STRUCTURE_COL_NAMES = ["secondary_structure", "ss"] diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py index 66147616..29d99436 100644 --- a/multimolecule/models/__init__.py +++ b/multimolecule/models/__init__.py @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . 
+from multimolecule.module import HeadConfig from multimolecule.tokenisers import DnaTokenizer, ProteinTokenizer, RnaTokenizer from .calm import ( @@ -127,6 +128,7 @@ __all__ = [ "PreTrainedConfig", + "HeadConfig", "DnaTokenizer", "RnaTokenizer", "ProteinTokenizer", diff --git a/multimolecule/models/calm/configuration_calm.py b/multimolecule/models/calm/configuration_calm.py index c5d73c03..032bda8e 100644 --- a/multimolecule/models/calm/configuration_calm.py +++ b/multimolecule/models/calm/configuration_calm.py @@ -127,5 +127,5 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/calm/modeling_calm.py b/multimolecule/models/calm/modeling_calm.py index 25c1eba4..c8abdffe 100644 --- a/multimolecule/models/calm/modeling_calm.py +++ b/multimolecule/models/calm/modeling_calm.py @@ -270,9 +270,9 @@ class CaLmForSequencePrediction(CaLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: CaLmConfig): @@ -334,9 +334,9 @@ class CaLmForTokenPrediction(CaLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: CaLmConfig): @@ -398,9 +398,9 @@ class 
CaLmForContactPrediction(CaLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: CaLmConfig): diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py index 2047d671..ce6f10ea 100644 --- a/multimolecule/models/configuration_utils.py +++ b/multimolecule/models/configuration_utils.py @@ -30,7 +30,8 @@ class PreTrainedConfig(PretrainedConfig): Base class for all model configuration classes. """ - head: HeadConfig + head: HeadConfig | None + num_labels: int = 1 hidden_size: int @@ -42,7 +43,15 @@ class PreTrainedConfig(PretrainedConfig): null_token_id: int = 5 def __init__( - self, pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3, mask_token_id=4, null_token_id=5, **kwargs + self, + pad_token_id: int = 0, + bos_token_id: int = 1, + eos_token_id: int = 2, + unk_token_id: int = 3, + mask_token_id: int = 4, + null_token_id: int = 5, + num_labels: int = 1, + **kwargs, ): super().__init__( pad_token_id=pad_token_id, @@ -51,6 +60,7 @@ def __init__( unk_token_id=unk_token_id, mask_token_id=mask_token_id, null_token_id=null_token_id, + num_labels=num_labels, **kwargs, ) diff --git a/multimolecule/models/ernierna/configuration_ernierna.py b/multimolecule/models/ernierna/configuration_ernierna.py index 0648bb2d..bfd11d51 100644 --- a/multimolecule/models/ernierna/configuration_ernierna.py +++ b/multimolecule/models/ernierna/configuration_ernierna.py @@ -110,5 +110,5 @@ def __init__( self.pairwise_alpha = pairwise_alpha self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) 
if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/ernierna/modeling_ernierna.py b/multimolecule/models/ernierna/modeling_ernierna.py index 6354a68c..1378e256 100644 --- a/multimolecule/models/ernierna/modeling_ernierna.py +++ b/multimolecule/models/ernierna/modeling_ernierna.py @@ -321,7 +321,7 @@ class ErnieRnaForSequencePrediction(ErnieRnaPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) """ def __init__(self, config: ErnieRnaConfig): @@ -385,9 +385,9 @@ class ErnieRnaForTokenPrediction(ErnieRnaPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: ErnieRnaConfig): @@ -452,9 +452,9 @@ class ErnieRnaForContactPrediction(ErnieRnaPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: ErnieRnaConfig): @@ -1183,11 +1183,8 @@ class ErnieRnaContactClassificationHead(nn.Module): def __init__(self, config: ErnieRnaConfig, head_config: HeadConfig | None = None): super().__init__() if head_config is None: - head_config = config.head + head_config = config.head or HeadConfig() self.config = head_config - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id self.conv1 = nn.Conv2d(1, 8, 7, 1, 3) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout(p=0.3) 
diff --git a/multimolecule/models/rinalmo/configuration_rinalmo.py b/multimolecule/models/rinalmo/configuration_rinalmo.py index 5e21725d..1cc963b2 100644 --- a/multimolecule/models/rinalmo/configuration_rinalmo.py +++ b/multimolecule/models/rinalmo/configuration_rinalmo.py @@ -125,6 +125,6 @@ def __init__( self.use_cache = use_cache self.learnable_beta = learnable_beta self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None self.emb_layer_norm_before = emb_layer_norm_before diff --git a/multimolecule/models/rinalmo/modeling_rinalmo.py b/multimolecule/models/rinalmo/modeling_rinalmo.py index b45d2823..d0ac6e8c 100644 --- a/multimolecule/models/rinalmo/modeling_rinalmo.py +++ b/multimolecule/models/rinalmo/modeling_rinalmo.py @@ -269,9 +269,9 @@ class RiNALMoForSequencePrediction(RiNALMoPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RiNALMoConfig): @@ -333,9 +333,9 @@ class RiNALMoForTokenPrediction(RiNALMoPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RiNALMoConfig): @@ -397,9 +397,9 @@ class RiNALMoForContactPrediction(RiNALMoPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> 
output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RiNALMoConfig): diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py index f044ecc7..97632d2e 100644 --- a/multimolecule/models/rnabert/configuration_rnabert.py +++ b/multimolecule/models/rnabert/configuration_rnabert.py @@ -112,5 +112,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py index 74b06cf1..32f7bf01 100644 --- a/multimolecule/models/rnabert/modeling_rnabert.py +++ b/multimolecule/models/rnabert/modeling_rnabert.py @@ -37,7 +37,13 @@ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from transformers.utils import logging -from multimolecule.module import ContactPredictionHead, MaskedLMHead, SequencePredictionHead, TokenPredictionHead +from multimolecule.module import ( + ContactPredictionHead, + HeadConfig, + MaskedLMHead, + SequencePredictionHead, + TokenPredictionHead, +) from ..modeling_outputs import ContactPredictorOutput, SequencePredictorOutput, TokenPredictorOutput from .configuration_rnabert import RnaBertConfig @@ -266,9 +272,9 @@ class RnaBertForSequencePrediction(RnaBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + 
torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaBertConfig): @@ -330,9 +336,9 @@ class RnaBertForTokenPrediction(RnaBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaBertConfig): @@ -394,9 +400,9 @@ class RnaBertForContactPrediction(RnaBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaBertConfig): @@ -1065,7 +1071,7 @@ def __init__(self, config: RnaBertConfig): vocab_size, config.vocab_size = config.vocab_size, config.ss_vocab_size self.predictions_ss = MaskedLMHead(config) config.vocab_size = vocab_size - self.seq_relationship = SequencePredictionHead(config) + self.seq_relationship = SequencePredictionHead(config, HeadConfig(num_labels=2)) def forward( self, diff --git a/multimolecule/models/rnaernie/configuration_rnaernie.py b/multimolecule/models/rnaernie/configuration_rnaernie.py index 2d540c9d..7a788297 100644 --- a/multimolecule/models/rnaernie/configuration_rnaernie.py +++ b/multimolecule/models/rnaernie/configuration_rnaernie.py @@ -108,5 +108,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = 
MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnaernie/modeling_rnaernie.py b/multimolecule/models/rnaernie/modeling_rnaernie.py index 7e0f4d10..8107ee20 100644 --- a/multimolecule/models/rnaernie/modeling_rnaernie.py +++ b/multimolecule/models/rnaernie/modeling_rnaernie.py @@ -270,9 +270,9 @@ class RnaErnieForSequencePrediction(RnaErniePreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config): @@ -334,9 +334,9 @@ class RnaErnieForTokenPrediction(RnaErniePreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaErnieConfig): @@ -398,9 +398,9 @@ class RnaErnieForContactPrediction(RnaErniePreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaErnieConfig): diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py index 8fdb7f49..ef1f0c18 100644 --- a/multimolecule/models/rnafm/configuration_rnafm.py +++ b/multimolecule/models/rnafm/configuration_rnafm.py @@ -131,5 +131,5 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - 
self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py index 99f553da..6898da9c 100644 --- a/multimolecule/models/rnafm/modeling_rnafm.py +++ b/multimolecule/models/rnafm/modeling_rnafm.py @@ -272,9 +272,9 @@ class RnaFmForSequencePrediction(RnaFmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaFmConfig): @@ -336,9 +336,9 @@ class RnaFmForTokenPrediction(RnaFmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaFmConfig): @@ -400,9 +400,9 @@ class RnaFmForContactPrediction(RnaFmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaFmConfig): @@ -555,7 +555,7 @@ class RnaFmForPreTraining(RnaFmPreTrainedModel): >>> output["logits"].shape torch.Size([1, 7, 26]) >>> output["contact_map"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) """ _tied_weights_keys = [ diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py index 
ae914c82..2e8150ba 100644 --- a/multimolecule/models/rnamsm/configuration_rnamsm.py +++ b/multimolecule/models/rnamsm/configuration_rnamsm.py @@ -116,5 +116,5 @@ def __init__( self.attention_type = attention_type self.embed_positions_msa = embed_positions_msa self.attention_bias = attention_bias - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py index 5ed6bf87..0390a129 100644 --- a/multimolecule/models/rnamsm/modeling_rnamsm.py +++ b/multimolecule/models/rnamsm/modeling_rnamsm.py @@ -176,9 +176,9 @@ class RnaMsmForSequencePrediction(RnaMsmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaMsmConfig): @@ -239,9 +239,9 @@ class RnaMsmForTokenPrediction(RnaMsmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: RnaMsmConfig): @@ -302,9 +302,9 @@ class RnaMsmForContactPrediction(RnaMsmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, 
config: RnaMsmConfig): @@ -449,7 +449,7 @@ class RnaMsmForPreTraining(RnaMsmPreTrainedModel): >>> output["logits"].shape torch.Size([1, 7, 26]) >>> output["contact_map"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) """ _tied_weights_keys = [ diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py index 66b46a88..f789516d 100644 --- a/multimolecule/models/splicebert/configuration_splicebert.py +++ b/multimolecule/models/splicebert/configuration_splicebert.py @@ -108,5 +108,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py index 1b0fd072..9d129d74 100644 --- a/multimolecule/models/splicebert/modeling_splicebert.py +++ b/multimolecule/models/splicebert/modeling_splicebert.py @@ -274,9 +274,9 @@ class SpliceBertForSequencePrediction(SpliceBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: SpliceBertConfig): @@ -338,9 +338,9 @@ class SpliceBertForTokenPrediction(SpliceBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + 
tensor(..., grad_fn=) """ def __init__(self, config: SpliceBertConfig): @@ -402,9 +402,9 @@ class SpliceBertForContactPrediction(SpliceBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: SpliceBertConfig): diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py index d032c5ee..5230c04f 100644 --- a/multimolecule/models/utrbert/configuration_utrbert.py +++ b/multimolecule/models/utrbert/configuration_utrbert.py @@ -125,5 +125,5 @@ def __init__( self.position_embedding_type = position_embedding_type self.is_decoder = is_decoder self.use_cache = use_cache - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py index 1a5b47f9..688bedbe 100644 --- a/multimolecule/models/utrbert/modeling_utrbert.py +++ b/multimolecule/models/utrbert/modeling_utrbert.py @@ -264,9 +264,9 @@ class UtrBertForSequencePrediction(UtrBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrBertConfig): @@ -328,9 +328,9 @@ class UtrBertForTokenPrediction(UtrBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, 
labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrBertConfig): @@ -393,9 +393,9 @@ class UtrBertForContactPrediction(UtrBertPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrBertConfig): diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py index f0f705de..a4f930d7 100644 --- a/multimolecule/models/utrlm/configuration_utrlm.py +++ b/multimolecule/models/utrlm/configuration_utrlm.py @@ -127,7 +127,7 @@ def __init__( self.use_cache = use_cache self.emb_layer_norm_before = emb_layer_norm_before self.token_dropout = token_dropout - self.head = HeadConfig(**head if head is not None else {}) - self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {}) + self.head = HeadConfig(**head) if head is not None else None + self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None self.ss_head = HeadConfig(**ss_head) if ss_head is not None else None self.mfe_head = HeadConfig(**mfe_head) if mfe_head is not None else None diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py index 535f99f0..aae1b593 100644 --- a/multimolecule/models/utrlm/modeling_utrlm.py +++ b/multimolecule/models/utrlm/modeling_utrlm.py @@ -272,9 +272,9 @@ class UtrLmForSequencePrediction(UtrLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.tensor([[1]])) >>> output["logits"].shape - torch.Size([1, 2]) + torch.Size([1, 1]) >>> output["loss"] # 
doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrLmConfig): @@ -336,9 +336,9 @@ class UtrLmForTokenPrediction(UtrLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5))) >>> output["logits"].shape - torch.Size([1, 5, 2]) + torch.Size([1, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrLmConfig): @@ -400,9 +400,9 @@ class UtrLmForContactPrediction(UtrLmPreTrainedModel): >>> input = tokenizer("ACGUN", return_tensors="pt") >>> output = model(**input, labels=torch.randint(2, (1, 5, 5))) >>> output["logits"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) >>> output["loss"] # doctest:+ELLIPSIS - tensor(..., grad_fn=) + tensor(..., grad_fn=) """ def __init__(self, config: UtrLmConfig): @@ -555,7 +555,7 @@ class UtrLmForPreTraining(UtrLmPreTrainedModel): >>> output["logits"].shape torch.Size([1, 7, 26]) >>> output["contact_map"].shape - torch.Size([1, 5, 5, 2]) + torch.Size([1, 5, 5, 1]) """ _tied_weights_keys = [ diff --git a/multimolecule/module/__init__.py b/multimolecule/module/__init__.py index 0128fe9b..dbba900b 100644 --- a/multimolecule/module/__init__.py +++ b/multimolecule/module/__init__.py @@ -14,8 +14,8 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . 
-from .criterions import Criterion -from .embeddings import PositionEmbeddingRegistry, PositionEmbeddingRegistryHF, RotaryEmbedding, SinusoidalEmbedding +from .criterions import Criterion, CriterionRegistry +from .embeddings import PositionEmbeddingRegistry, RotaryEmbedding, SinusoidalEmbedding from .heads import ( BaseHeadConfig, ContactPredictionHead, @@ -23,7 +23,6 @@ HeadOutput, HeadRegistry, HeadTransformRegistry, - HeadTransformRegistryHF, IdentityTransform, LinearTransform, MaskedLMHead, @@ -31,15 +30,18 @@ NonLinearTransform, PredictionHead, SequencePredictionHead, - TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead, ) +from .model import MultiMoleculeModel +from .registry import ModelRegistry __all__ = [ + "ModelRegistry", + "MultiMoleculeModel", + "CriterionRegistry", "Criterion", "PositionEmbeddingRegistry", - "PositionEmbeddingRegistryHF", "RotaryEmbedding", "SinusoidalEmbedding", "BaseHeadConfig", @@ -48,14 +50,12 @@ "HeadRegistry", "PredictionHead", "SequencePredictionHead", - "TokenHeadRegistryHF", "TokenPredictionHead", "TokenKMerHead", "ContactPredictionHead", "MaskedLMHead", "HeadOutput", "HeadTransformRegistry", - "HeadTransformRegistryHF", "LinearTransform", "NonLinearTransform", "IdentityTransform", diff --git a/multimolecule/module/backbones/__init__.py b/multimolecule/module/backbones/__init__.py new file mode 100644 index 00000000..d69e6292 --- /dev/null +++ b/multimolecule/module/backbones/__init__.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .registry import BackboneRegistry +from .sequence import SequenceBackbone +from .sequences import SequenceRegistry + +__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"] diff --git a/multimolecule/module/backbones/registry.py b/multimolecule/module/backbones/registry.py new file mode 100644 index 00000000..47be122d --- /dev/null +++ b/multimolecule/module/backbones/registry.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import Registry + +BackboneRegistry = Registry() diff --git a/multimolecule/module/backbones/sequence.py b/multimolecule/module/backbones/sequence.py new file mode 100644 index 00000000..a30cbf83 --- /dev/null +++ b/multimolecule/module/backbones/sequence.py @@ -0,0 +1,46 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. 
@BackboneRegistry.register("sequence", default=True)
class SequenceBackbone(nn.Module):
    """Backbone that runs a sequence encoder and applies dropout to its outputs.

    The encoder is created by ``SequenceRegistry.build`` from the ``sequence``
    configuration mapping.

    Args:
        sequence: Mapping of keyword arguments for ``SequenceRegistry.build``.
            An optional ``"dropout"`` entry (default ``0``) is the dropout
            probability applied to the encoder outputs; it is consumed here and
            is *not* forwarded to the encoder builder.
    """

    def __init__(self, sequence) -> None:
        super().__init__()
        # Copy first so the caller's config is left untouched, and remove
        # ``dropout`` *before* building: it configures this wrapper's dropout
        # layer and must not leak into the encoder's constructor kwargs.
        sequence = dict(sequence)
        dropout = sequence.pop("dropout", 0)
        self.sequence = SequenceRegistry.build(**sequence)
        # Out-of-place dropout: the encoder outputs may have been saved for
        # backward by the encoder's own ops, and overwriting them in place
        # would invalidate those saved tensors during training.
        self.sequence_dropout = nn.Dropout(dropout)
        # The encoder's config is re-exposed; ``hidden_size`` is published as
        # the number of output channels of this backbone.
        self.config = self.sequence.config
        self.out_channels = self.config.hidden_size

    def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, Tensor | None]:
        """Encode ``sequence`` and return the encoder output and attentions.

        Args:
            sequence: Input exposing ``.tensor`` (used as the token ids) and
                ``.mask`` (used as the attention mask).
            *args: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A pair ``(sequence_output, attentions)``. ``sequence_output`` is
            the encoder's output with ``"pooler_output"`` and
            ``"last_hidden_state"`` replaced by their dropped-out versions.
            ``attentions`` is the encoder's ``"attentions"`` entry stacked
            along dimension 1 and detached, or ``None`` when the encoder
            output has no such entry.
        """
        input_ids, attention_mask = sequence.tensor, sequence.mask
        sequence_output = self.sequence(input_ids.int(), attention_mask)
        sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"])
        sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"])

        attentions = None
        if "attentions" in sequence_output:
            # Detached: the stacked attention maps are returned for inspection
            # only and do not take part in the backward pass.
            attentions = torch.stack(sequence_output["attentions"], dim=1).detach()

        return sequence_output, attentions
Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .onehot import OneHot +from .registry import SequenceRegistry + +__all__ = ["SequenceRegistry", "OneHot"] diff --git a/multimolecule/module/backbones/sequences/onehot.py b/multimolecule/module/backbones/sequences/onehot.py new file mode 100644 index 00000000..bc4c979f --- /dev/null +++ b/multimolecule/module/backbones/sequences/onehot.py @@ -0,0 +1,39 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
@SequenceRegistry.register("onehot")
class OneHot(nn.Module):
    """Minimal sequence encoder: a single embedding table with sum pooling.

    The vocabulary size and hidden size are read from the configuration
    loaded with ``AutoConfig.from_pretrained``.

    Args:
        pretrained: Identifier or path passed (as ``str``) to
            ``AutoConfig.from_pretrained``.
    """

    def __init__(self, pretrained: str) -> None:
        super().__init__()
        self.config = AutoConfig.from_pretrained(str(pretrained))
        self.module = nn.Embedding(self.config.vocab_size, self.config.hidden_size)

    def forward(self, input_ids, attn_mask) -> FlatDict:
        """Embed ``input_ids`` and pool the embeddings of the attended positions.

        Args:
            input_ids: Token indices looked up in the embedding table.
            attn_mask: Mask with the same leading two dimensions as
                ``input_ids``; non-zero marks positions included in the pooled
                sum. (assumes 0/1 or boolean values)

        Returns:
            ``FlatDict`` with ``"last_hidden_state"`` (the embeddings) and
            ``"pooler_output"`` (per-sample sum of the embeddings at attended
            positions).
        """
        hidden_state = self.module(input_ids)
        # Masked sum over the sequence dimension. For masks that are a prefix
        # of ones this equals summing the first ``attn_mask.sum(dim=1)``
        # positions of each sample, but it needs no Python loop, works for an
        # empty batch, and does not depend on which side the padding is on.
        mask = attn_mask.unsqueeze(-1).to(hidden_state.dtype)

        output = FlatDict()
        output["last_hidden_state"] = hidden_state
        output["pooler_output"] = (hidden_state * mask).sum(dim=1)
        return output
class Registry(Registry_):  # pylint: disable=too-few-public-methods
    """Registry of sequence encoders with a fallback to ``transformers``."""

    def build(
        self,
        type: str | None = None,
        name: str | None = None,
        use_pretrained: bool = True,
        gradient_checkpoint: bool = False,
        checkpoint: str | None = None,
        *args,
        **kwargs,
    ) -> nn.Module:
        """Build a sequence encoder.

        Resolution order:

        1. ``type`` is registered here: the registered class is initialised
           with ``*args``/``**kwargs``; if ``checkpoint`` is given, its state
           dict is loaded with ``danling.load``.
        2. ``transformers`` exposes ``<type>Model``: that class is used, either
           via ``from_pretrained(name)`` or, when ``use_pretrained`` is false,
           instantiated from ``<type>Config.from_pretrained(name)``.
        3. ``type`` is ``None``: ``AutoModel`` is used the same way.

        Args:
            type: Registered name, or prefix of a ``transformers`` model class.
            name: Identifier passed to ``from_pretrained`` (branches 2 and 3).
            use_pretrained: Load pretrained weights (``True``) or only the
                configuration and initialise fresh weights (``False``).
            gradient_checkpoint: Call ``gradient_checkpointing_enable()`` on
                the built module.
            checkpoint: State-dict file; only used for registered classes.
            *args: Extra positional arguments for the constructor / loader.
            **kwargs: Extra keyword arguments for the constructor / loader.

        Returns:
            The constructed module.

        Raises:
            ValueError: If ``type`` is neither registered nor the prefix of a
                ``transformers`` model class.
        """
        if type is None:
            if use_pretrained:
                sequence = AutoModel.from_pretrained(name, *args, **kwargs)
            else:
                # ``return_unused_kwargs`` hands back the kwargs that were not
                # consumed as config attributes; those go on to the model.
                config, kwargs = AutoConfig.from_pretrained(name, return_unused_kwargs=True, **kwargs)
                sequence = AutoModel.from_config(config, *args, **kwargs)
        elif type in self:
            sequence_cls = self.lookup(type)
            sequence = self.init(sequence_cls, *args, **kwargs)
            if checkpoint is not None:
                sequence.load_state_dict(dl.load(checkpoint))
        elif hasattr(transformers, type + "Model"):
            model_cls: type[PreTrainedModel] = getattr(transformers, type + "Model")
            if use_pretrained:
                sequence = model_cls.from_pretrained(name, *args, **kwargs)
            else:
                config_cls: type[PretrainedConfig] = getattr(transformers, type + "Config")
                config, kwargs = config_cls.from_pretrained(name, return_unused_kwargs=True, **kwargs)
                # Concrete model classes have no public ``from_config`` (that
                # is an Auto-class API); instantiate from the config directly.
                sequence = model_cls(config, *args, **kwargs)
        else:
            raise ValueError(f"Sequence {type} not found in registry or transformers")

        if gradient_checkpoint:
            sequence.gradient_checkpointing_enable()
        return sequence


SequenceRegistry = Registry()
You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from .binary import BCEWithLogitsLoss from .generic import Criterion +from .multiclass import CrossEntropyLoss +from .multilabel import MultiLabelSoftMarginLoss +from .registry import CriterionRegistry +from .regression import MSELoss -__all__ = ["Criterion"] +__all__ = [ + "CriterionRegistry", + "Criterion", + "MSELoss", + "BCEWithLogitsLoss", + "CrossEntropyLoss", + "MultiLabelSoftMarginLoss", +] diff --git a/multimolecule/module/criterions/binary.py b/multimolecule/module/criterions/binary.py new file mode 100644 index 00000000..0bf53e59 --- /dev/null +++ b/multimolecule/module/criterions/binary.py @@ -0,0 +1,44 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
@CriterionRegistry.register("binary")
class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
    r"""
    Binary cross entropy with logits, registered for the ``"binary"`` problem type.

    Args:
        config: Head configuration; its optional `loss` entry is forwarded to `nn.BCEWithLogitsLoss`.
    """

    def __init__(self, config: HeadConfig) -> None:
        loss_kwargs = config.get("loss", {})
        super().__init__(**loss_kwargs)
        self.config = config

    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
        r"""
        Compute the loss, accepting `NestedTensor`s and logits with a trailing extra dimension.

        Args:
            input: Predicted logits.
            target: Ground-truth labels; converted to floating point before the loss is computed.
        """
        # NestedTensors are flattened and concatenated into one dense tensor.
        logits = torch.cat(input.flatten().storage()) if isinstance(input, NestedTensor) else input
        labels = torch.cat(target.flatten().storage()) if isinstance(target, NestedTensor) else target
        # Drop the last dimension of the logits when it is the only rank difference to the labels.
        if logits.ndim - labels.ndim == 1:
            logits = logits.squeeze(-1)
        return super().forward(logits, labels.float())
self.num_labels > 1 and labels.dtype in (torch.long, torch.int): - self.problem_type = "single_label_classification" + elif self.num_labels == 1: + self.problem_type = "binary" + elif labels.unique().numel() == 2: + self.problem_type = "multilabel" else: - self.problem_type = "multi_label_classification" + self.problem_type = "multiclass" + warn( + f"`problem_type` is not set. Assuming {self.problem_type}. \n" + "This can lead to unexpected behavior. Please set `problem_type` explicitly." + ) self.config.problem_type = self.problem_type if self.problem_type == "regression": labels = labels.to(logits.dtype) if self.num_labels == 1: return F.mse_loss(logits.squeeze(), labels.squeeze()) logits, labels = logits.view(-1, self.num_labels), labels.view(-1, self.num_labels) - return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels)) - if self.problem_type == "single_label_classification": + return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels)) # type: ignore + if self.problem_type == "multiclass": return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1)) - if self.problem_type == "multi_label_classification": - return F.binary_cross_entropy_with_logits(logits, labels) + if self.problem_type == "binary": + if logits.ndim == labels.ndim + 1: + logits = logits.squeeze(-1) + return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype)) + if self.problem_type == "multilabel": + return F.multilabel_soft_margin_loss(logits, labels.to(logits.dtype)) raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}") diff --git a/multimolecule/module/criterions/multiclass.py b/multimolecule/module/criterions/multiclass.py new file mode 100644 index 00000000..f7070e94 --- /dev/null +++ b/multimolecule/module/criterions/multiclass.py @@ -0,0 +1,44 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can 
@CriterionRegistry.register("multiclass")
class CrossEntropyLoss(nn.CrossEntropyLoss):
    r"""
    Cross entropy loss, registered for the ``"multiclass"`` problem type.

    Args:
        config: Head configuration; its optional `loss` entry is forwarded to `nn.CrossEntropyLoss`.
    """

    def __init__(self, config: HeadConfig) -> None:
        loss_kwargs = config.get("loss", {})
        super().__init__(**loss_kwargs)
        self.config = config

    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
        r"""
        Compute the loss, accepting `NestedTensor`s and inputs with more than two dimensions.

        Args:
            input: Predicted logits with the classes in the last dimension.
            target: Class indices; converted to `long` before the loss is computed.
        """
        logits = torch.cat(input.storage()) if isinstance(input, NestedTensor) else input
        labels = torch.cat(target.storage()) if isinstance(target, NestedTensor) else target
        if logits.ndim > 2:
            # Collapse every leading dimension so that each row is one prediction.
            logits = logits.view(-1, logits.size(-1))
            labels = labels.view(-1)
        return super().forward(logits, labels.long())
@CriterionRegistry.register("multilabel")
class MultiLabelSoftMarginLoss(nn.MultiLabelSoftMarginLoss):
    r"""
    Multi-label soft margin loss, registered for the ``"multilabel"`` problem type.

    Args:
        config: Head configuration; its optional `loss` entry is forwarded to `nn.MultiLabelSoftMarginLoss`.
    """

    def __init__(self, config: HeadConfig) -> None:
        super().__init__(**config.get("loss", {}))
        self.config = config

    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
        r"""
        Compute the loss, accepting `NestedTensor`s and inputs with more than two dimensions.

        Args:
            input: Predicted logits with the labels in the last dimension.
            target: Multi-hot labels; converted to floating point before the loss is computed.
        """
        # Unpack NestedTensors first so that the flattening below always operates on dense tensors.
        if isinstance(input, NestedTensor):
            input = torch.cat(input.storage())
        if isinstance(target, NestedTensor):
            target = torch.cat(target.storage())
        # Flatten every leading dimension, mirroring `CrossEntropyLoss`: the parent loss averages over
        # dim 1, which is only the label dimension when the tensors are 2-D.
        if input.ndim > 2:
            input = input.reshape(-1, input.size(-1))
            target = target.reshape(-1, target.size(-1))
        return super().forward(input, target.float())
class Registry(Registry_):  # pylint: disable=too-few-public-methods
    r"""
    Registry of criterions, looked up through the `problem_type` attribute of a config.
    """

    key = "problem_type"

    def build(self, config) -> nn.Module:  # type: ignore[override]
        r"""
        Instantiate the criterion registered for `config`.

        Args:
            config: Object whose attribute named by `key` selects the criterion class; it is also passed to
                that class as the only constructor argument.
        """
        key_name = self.getattr("key")
        criterion_cls = self.lookup(getattr(config, key_name))
        return self.init(criterion_cls, config)  # type: ignore[arg-type]


CriterionRegistry = Registry(fallback=True)
@CriterionRegistry.register("regression")
class MSELoss(nn.MSELoss):
    r"""
    Mean squared error loss, registered for the ``"regression"`` problem type.

    Args:
        config: Head configuration; its optional `loss` entry is forwarded to `nn.MSELoss`.
    """

    def __init__(self, config: HeadConfig) -> None:
        loss_kwargs = config.get("loss", {})
        super().__init__(**loss_kwargs)
        self.config = config

    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
        r"""
        Compute the loss, accepting `NestedTensor`s and targets that lack the trailing dimension.

        Args:
            input: Predicted values.
            target: Ground-truth values; cast to the dtype of `input` before the loss is computed.
        """
        # NestedTensors are flattened and concatenated into one dense tensor.
        prediction = torch.cat(input.flatten().storage()) if isinstance(input, NestedTensor) else input
        reference = torch.cat(target.flatten().storage()) if isinstance(target, NestedTensor) else target
        # Give the target a trailing dimension when that is the only rank difference to the prediction.
        if prediction.ndim - reference.ndim == 1:
            reference = reference.unsqueeze(-1)
        return super().forward(prediction, reference.to(prediction.dtype))
b/multimolecule/module/heads/config.py index bb0dbba6..3b9ee64b 100644 --- a/multimolecule/module/heads/config.py +++ b/multimolecule/module/heads/config.py @@ -16,15 +16,13 @@ from __future__ import annotations -from collections import OrderedDict -from dataclasses import dataclass +from chanfig import FlatDict -class BaseHeadConfig(OrderedDict): +class BaseHeadConfig(FlatDict): pass -@dataclass class HeadConfig(BaseHeadConfig): r""" Configuration class for a prediction head. @@ -35,8 +33,8 @@ class HeadConfig(BaseHeadConfig): Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`. problem_type: - Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`, - `"single_label_classification"` or `"multi_label_classification"`. + Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`, + `"multiclass"` or `"multilabel"`. Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`. hidden_size: @@ -55,14 +53,18 @@ class HeadConfig(BaseHeadConfig): The activation function of the final prediction output. layer_norm_eps: The epsilon used by the layer normalization layers. - output_name (`str`, *optional*): + output_name: The name of the tensor required in model outputs. If is `None`, will use the default output name of the corresponding head. + type: + The type of the head in the model. + + This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads. 
""" - num_labels: int = None # type: ignore[assignment] - problem_type: str = None # type: ignore[assignment] + num_labels: int | None = None + problem_type: str | None = None hidden_size: int | None = None dropout: float = 0.0 transform: str | None = None @@ -71,9 +73,9 @@ class HeadConfig(BaseHeadConfig): act: str | None = None layer_norm_eps: float = 1e-12 output_name: str | None = None + type: str | None = None -@dataclass class MaskedLMHeadConfig(BaseHeadConfig): r""" Configuration class for a Masked Language Modeling head. @@ -95,7 +97,7 @@ class MaskedLMHeadConfig(BaseHeadConfig): The activation function of the final prediction output. layer_norm_eps: The epsilon used by the layer normalization layers. - output_name (`str`, *optional*): + output_name: The name of the tensor required in model outputs. If is `None`, will use the default output name of the corresponding head. diff --git a/multimolecule/module/heads/contact.py b/multimolecule/module/heads/contact.py index 50ec4fbb..cdef94d4 100644 --- a/multimolecule/module/heads/contact.py +++ b/multimolecule/module/heads/contact.py @@ -16,10 +16,12 @@ from __future__ import annotations -from typing import Mapping, Tuple +from typing import Callable, Mapping, Tuple, Type import torch from danling import NestedTensor +from danling.modules import MLP +from lazy_imports import try_import from torch import Tensor, nn from transformers.modeling_outputs import ModelOutput from typing_extensions import TYPE_CHECKING @@ -28,13 +30,55 @@ from .generic import PredictionHead from .output import HeadOutput from .registry import HeadRegistry -from .utils import average_product_correct, symmetrize + +with try_import() as tv: + from torchvision.models.resnet import BasicBlock, Bottleneck if TYPE_CHECKING: from multimolecule.models import PreTrainedConfig -@HeadRegistry.register("contact") +@HeadRegistry.contact.register("default", default=True) +class ContactHead(PredictionHead): + + output_name: str = "last_hidden_state" + + 
def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): + super().__init__(config, head_config) + out_channels: int = self.config.hidden_size # type: ignore[assignment] + self.qk_proj = nn.Linear(out_channels, 2 * out_channels) + self.ffn = MLP(1, out_channels, residual=False) + + def forward( # type: ignore[override] # pylint: disable=arguments-renamed + self, + outputs: ModelOutput | Mapping | Tuple[Tensor, ...], + attention_mask: Tensor | None = None, + input_ids: NestedTensor | Tensor | None = None, + labels: Tensor | None = None, + output_name: str | None = None, + **kwargs, + ) -> HeadOutput: + if isinstance(outputs, (Mapping, ModelOutput)): + output = outputs[output_name or self.output_name] + elif isinstance(outputs, tuple): + output = outputs[0] + else: + raise ValueError(f"Unsupported type for outputs: {type(outputs)}") + + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * attention_mask.unsqueeze(-1) + output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids) + + q, k = self.qk_proj(output).chunk(2, dim=-1) + contact_map = (q @ k.transpose(-2, -1)).unsqueeze(-1) + contact_map = contact_map + self.ffn(contact_map) + if "continuous" in outputs: + contact_map = contact_map * (1 + outputs["continuous"].unsqueeze(dim=-1)) # type: ignore[call-overload] + return super().forward(contact_map, labels) + + +@HeadRegistry.contact.register("attention") class ContactPredictionHead(PredictionHead): r""" Head for tasks in contact-level. 
@@ -50,13 +94,20 @@ class ContactPredictionHead(PredictionHead): output_name: str = "attentions" r"""The default output to use for the head.""" + requires_attention: bool = True + def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): super().__init__(config, head_config) - self.bos_token_id = config.bos_token_id - self.eos_token_id = config.eos_token_id - self.pad_token_id = config.pad_token_id - self.decoder = nn.Linear( - config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias + self.config.hidden_size = config.num_hidden_layers * config.num_attention_heads + num_layers = self.config.get("num_layers", 16) + num_channels = self.config.get("num_channels", self.config.hidden_size // 10) # type: ignore[operator] + block = self.config.get("block", "auto") + self.decoder = ResNet( + num_layers=num_layers, + hidden_size=self.config.hidden_size, # type: ignore[arg-type] + block=block, + num_channels=num_channels, + num_labels=self.num_labels, ) if head_config is not None and head_config.output_name is not None: self.output_name = head_config.output_name @@ -81,19 +132,6 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed output_name: The name of the output to use. Defaults to `self.output_name`. """ - if attention_mask is None: - if isinstance(input_ids, NestedTensor): - input_ids, attention_mask = input_ids.tensor, input_ids.mask - else: - if input_ids is None: - raise ValueError( - f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." - ) - if self.pad_token_id is None: - raise ValueError( - f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." 
@HeadRegistry.contact.register("logits")
class ContactLogitsHead(PredictionHead):
    r"""
    Head for tasks in contact-level that works on token representations.

    After special tokens are removed, the token representations are combined pairwise through an
    element-wise product, which gives a symmetric feature map. That map is passed to
    `PredictionHead.forward`; `self.decoder` of this head is a [`ResNet`][multimolecule.module.heads.contact.ResNet].

    Args:
        config: The configuration object for the model.
        head_config: The configuration object for the head.
            If None, will use configuration from the `config`.
    """

    output_name: str = "last_hidden_state"
    r"""The default output to use for the head."""

    requires_attention: bool = False

    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
        super().__init__(config, head_config)
        num_layers = self.config.get("num_layers", 16)
        num_channels = self.config.get("num_channels", self.config.hidden_size // 10)  # type: ignore[operator]
        block = self.config.get("block", "auto")
        self.decoder = ResNet(
            num_layers=num_layers,
            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
            block=block,
            num_channels=num_channels,
            num_labels=self.num_labels,
        )
        if head_config is not None and head_config.output_name is not None:
            self.output_name = head_config.output_name

    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
        self,
        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
        attention_mask: Tensor | None = None,
        input_ids: NestedTensor | Tensor | None = None,
        labels: Tensor | None = None,
        output_name: str | None = None,
        **kwargs,
    ) -> HeadOutput:
        r"""
        Forward pass of the ContactLogitsHead.

        Args:
            outputs: The outputs of the model.
            attention_mask: The attention mask for the inputs.
            input_ids: The input ids for the inputs.
            labels: The labels for the head.
            output_name: The name of the output to use.
                Defaults to `self.output_name`.

        Raises:
            ValueError: If `outputs` is neither a mapping, a `ModelOutput` nor a tuple.
        """
        if isinstance(outputs, (Mapping, ModelOutput)):
            output = outputs[output_name or self.output_name]
        elif isinstance(outputs, tuple):
            output = outputs[0]
        else:
            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")

        if attention_mask is None:
            attention_mask = self._get_attention_mask(input_ids)
        # zero out the representations of masked positions
        output = output * attention_mask.unsqueeze(-1)
        output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids)

        # make symmetric contact map
        contact_map = output.unsqueeze(1) * output.unsqueeze(2)

        return super().forward(contact_map, labels, **kwargs)


class ResNet(nn.Module):
    r"""
    Stack of torchvision residual blocks applied to a channels-last pairwise feature map.

    All blocks keep the spatial size and the number of channels, so the output map has the same first three
    dimensions as the input.

    Args:
        num_layers: Number of residual blocks.
        hidden_size: Size of the last dimension of the input.
        block: Residual block to use: `"auto"` (`BasicBlock` for fewer than 50 layers, `Bottleneck`
            otherwise), `"basic"` / `"BasicBlock"`, `"bottleneck"` / `"Bottleneck"`, or a block class.
        num_channels: Number of channels inside the residual stack. Defaults to `hidden_size // 10`.
        num_labels: Size of the last dimension of the output.
        norm_layer: Factory of normalisation layers. Defaults to `LayerNorm2D`.
        zero_init_residual: Zero-initialise the weight of the last normalisation layer of every block.

    Raises:
        ValueError: If `block` is not recognised, or if `num_channels` is not divisible by the expansion
            factor of the chosen block.
    """

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        block: Type[BasicBlock | Bottleneck] | str = "auto",
        num_channels: int | None = None,
        num_labels: int = 1,
        norm_layer: Callable[..., nn.Module] | None = None,
        zero_init_residual: bool = True,
    ) -> None:
        tv.check()
        super().__init__()

        if isinstance(block, str):
            if block == "auto":
                block = BasicBlock if num_layers < 50 else Bottleneck
            elif block in ("basic", "BasicBlock"):
                block = BasicBlock
            elif block in ("bottleneck", "Bottleneck"):
                block = Bottleneck
            else:
                raise ValueError(f"Unknown block type: {block}")
        elif not isinstance(block, type):
            # Block classes are accepted as-is; anything else is rejected.
            raise ValueError(f"Unknown block type: {block}")
        if num_channels is None:
            num_channels = hidden_size // 10
        if norm_layer is None:
            norm_layer = LayerNorm2D

        # A block outputs `planes * expansion` channels (4 for Bottleneck, 1 for BasicBlock). No downsample
        # branch is used, so the output must have exactly `num_channels` channels for the residual addition
        # and for chaining the blocks.
        expansion = getattr(block, "expansion", 1)
        if num_channels % expansion != 0:
            raise ValueError(
                f"num_channels ({num_channels}) must be divisible by the expansion ({expansion}) "
                f"of {block.__name__}"
            )
        planes = num_channels // expansion

        self.proj = nn.Conv2d(hidden_size, num_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.norm = norm_layer(num_channels)
        self.relu = nn.ReLU(inplace=True)
        self.layers = nn.Sequential(
            *[block(num_channels, planes, norm_layer=norm_layer) for _ in range(num_layers)]  # type: ignore
        )
        self.output = nn.Linear(num_channels, num_labels)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Apply the residual stack.

        Args:
            x: Channels-last feature map; dimensions 1 and 3 are swapped before the convolutions and
                swapped back before the output projection.
        """
        x = self.proj(x.transpose(1, 3))
        x = self.norm(x)
        x = self.relu(x)
        x = self.layers(x)
        x = self.output(x.transpose(1, 3))
        return x


class LayerNorm2D(nn.GroupNorm):
    r"""
    `nn.GroupNorm` with a single group, exposed with a layer-norm style constructor.

    Args:
        num_features: Number of channels of the input.
        eps: Value added to the denominator for numerical stability.
        elementwise_affine: Whether to learn a per-channel affine transform.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
        super().__init__(num_channels=num_features, eps=eps, affine=elementwise_affine, num_groups=1)
        self.num_channels = num_features

    def __repr__(self):
        return f"{self.__class__.__name__}(num_channels={self.num_channels}, eps={self.eps}, affine={self.affine})"


def symmetrize(x):
    """Make layer symmetric in final two dimensions, used for contact prediction."""
    return x + x.transpose(-1, -2)


def average_product_correct(x):
    """
    Perform average product correct, used for contact prediction.

    Subtracts, from every entry, the product of its row sum and its column sum divided by the total sum,
    all taken over the final two dimensions.
    """
    a1 = x.sum(-1, keepdims=True)
    a2 = x.sum(-2, keepdims=True)
    a12 = x.sum((-1, -2), keepdims=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized
self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias) self.activation = ACT2FN[self.config.act] if self.config.act is not None else None - self.criterion = Criterion(self.config) + self.criterion = CriterionRegistry.build(self.config) def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput: r""" @@ -85,6 +89,42 @@ def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOu if isinstance(labels, NestedTensor): if isinstance(output, Tensor): output = labels.nested_like(output, strict=False) - return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage()))) + return HeadOutput(output, self.criterion(output.concat, labels.concat)) return HeadOutput(output, self.criterion(output, labels)) return HeadOutput(output) + + def _get_attention_mask(self, input_ids: NestedTensor | Tensor) -> Tensor: + if isinstance(input_ids, NestedTensor): + return input_ids.mask + if input_ids is None: + raise ValueError( + f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." + ) + if self.pad_token_id is None: + raise ValueError( + f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." 
+ ) + return input_ids.ne(self.pad_token_id) + + def _remove_special_tokens( + self, output: Tensor, attention_mask: Tensor, input_ids: Tensor | None + ) -> Tuple[Tensor, Tensor, Tensor]: + # remove cls token embeddings + if self.bos_token_id is not None: + output = output[..., 1:, :] + attention_mask = attention_mask[..., 1:] + if input_ids is not None: + input_ids = input_ids[..., 1:] + # remove eos token embeddings + if self.eos_token_id is not None: + if input_ids is not None: + eos_mask = input_ids.ne(self.eos_token_id).to(output) + input_ids = input_ids[..., :-1] + else: + last_valid_indices = attention_mask.sum(dim=-1) + seq_length = attention_mask.size(-1) + eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) + output = output * eos_mask[:, :, None] + output = output[..., :-1, :] + attention_mask = attention_mask[..., 1:] + return output, attention_mask, input_ids diff --git a/multimolecule/module/heads/pretrain.py b/multimolecule/module/heads/pretrain.py index 994cb8ca..c6968c4b 100644 --- a/multimolecule/module/heads/pretrain.py +++ b/multimolecule/module/heads/pretrain.py @@ -53,8 +53,8 @@ def __init__( ): super().__init__() if head_config is None: - head_config = config.lm_head if hasattr(config, "lm_head") else config.head # type: ignore[assignment] - self.config: MaskedLMHeadConfig = head_config # type: ignore[assignment] + head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig() + self.config: MaskedLMHeadConfig = head_config if self.config.hidden_size is None: self.config.hidden_size = config.hidden_size self.num_labels = config.vocab_size @@ -97,6 +97,6 @@ def forward( if isinstance(labels, NestedTensor): if isinstance(output, Tensor): output = labels.nested_like(output, strict=False) - return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage()))) + return HeadOutput(output, F.cross_entropy(output.concat, labels.concat)) 
class Registry(Registry_):  # pylint: disable=too-few-public-methods
    r"""
    Registry of prediction heads, looked up through the `type` attribute of a head config.
    """

    key = "type"

    def build(self, config, head_config) -> nn.Module:  # type: ignore[override]
        r"""
        Instantiate the head registered for `head_config`.

        Args:
            config: Model configuration, passed to the head as the first constructor argument.
            head_config: Head configuration; its attribute named by `key` selects the head class and it is
                passed to the head as the second constructor argument.
        """
        key_name = self.getattr("key")
        head_cls = self.lookup(getattr(head_config, key_name))
        return self.init(head_cls, config, head_config)  # type: ignore[arg-type]


HeadRegistry = Registry(default_factory=Registry, fallback=True)
""" - if attention_mask is None: - if isinstance(input_ids, NestedTensor): - input_ids, attention_mask = input_ids.tensor, input_ids.mask - else: - if input_ids is None: - raise ValueError( - f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." - ) - if self.pad_token_id is None: - raise ValueError( - f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." - ) - attention_mask = input_ids.ne(self.pad_token_id) - if isinstance(outputs, (Mapping, ModelOutput)): output = outputs[output_name or self.output_name] elif isinstance(outputs, tuple): output = outputs[0] else: raise ValueError(f"Unsupported type for outputs: {type(outputs)}") - output = output * attention_mask.unsqueeze(-1) - # remove cls token embeddings - if self.bos_token_id is not None: - output = output[..., 1:, :] - # process attention_mask and input_ids to make removal of eos token happy - attention_mask = attention_mask[..., 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token embeddings - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(output) - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) - output = output * eos_mask[:, :, None] - output = output[..., :-1, :] + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * attention_mask.unsqueeze(-1) + output, _, _ = self._remove_special_tokens(output, attention_mask, input_ids) return super().forward(output, labels, **kwargs) @@ -141,9 +109,6 @@ class TokenKMerHead(PredictionHead): def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None): super().__init__(config, head_config) self.nmers = config.nmers - self.bos_token_id = config.bos_token_id - self.eos_token_id = 
config.eos_token_id - self.pad_token_id = config.pad_token_id if head_config is not None and head_config.output_name is not None: self.output_name = head_config.output_name # Do not pass bos_token_id and eos_token_id to unfold_kmer_embeddings @@ -170,46 +135,17 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed output_name: The name of the output to use. Defaults to `self.output_name`. """ - if attention_mask is None: - if isinstance(input_ids, NestedTensor): - input_ids, attention_mask = input_ids.tensor, input_ids.mask - else: - if input_ids is None: - raise ValueError( - f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work." - ) - if self.pad_token_id is None: - raise ValueError( - f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}." - ) - attention_mask = input_ids.ne(self.pad_token_id) - if isinstance(outputs, (Mapping, ModelOutput)): output = outputs[output_name or self.output_name] elif isinstance(outputs, tuple): output = outputs[0] else: raise ValueError(f"Unsupported type for outputs: {type(outputs)}") - output = output * attention_mask.unsqueeze(-1) - # remove cls token embeddings - if self.bos_token_id is not None: - output = output[..., 1:, :] - attention_mask = attention_mask[..., 1:] - if input_ids is not None: - input_ids = input_ids[..., 1:] - # remove eos token embeddings - if self.eos_token_id is not None: - if input_ids is not None: - eos_mask = input_ids.ne(self.eos_token_id).to(output) - input_ids = input_ids[..., :-1] - else: - last_valid_indices = attention_mask.sum(dim=-1) - seq_length = attention_mask.size(-1) - eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1) - output = output * eos_mask[:, :, None] - output = output[..., :-1, :] - attention_mask = attention_mask[..., 1:] + if attention_mask is None: + attention_mask = self._get_attention_mask(input_ids) + output = output * 
attention_mask.unsqueeze(-1) + output, attention_mask, _ = self._remove_special_tokens(output, attention_mask, input_ids) output = self.unfold_kmer_embeddings(output, attention_mask) return super().forward(output, labels, **kwargs) diff --git a/multimolecule/module/heads/utils.py b/multimolecule/module/heads/utils.py index c5937c6d..cc1f3654 100644 --- a/multimolecule/module/heads/utils.py +++ b/multimolecule/module/heads/utils.py @@ -119,32 +119,3 @@ def unfold_kmer_embeddings( embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]]) output[index, : seq_len + nmers - 1] = embedding return output - - -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(x, cos, sin): - cos = cos[:, :, : x.shape[-2], :] - sin = sin[:, :, : x.shape[-2], :] - - return (x * cos) + (rotate_half(x) * sin) - - -def symmetrize(x): - "Make layer symmetric in final two dimensions, used for contact prediction." - return x + x.transpose(-1, -2) - - -def average_product_correct(x): - "Perform average product correct, used for contact prediction." - a1 = x.sum(-1, keepdims=True) - a2 = x.sum(-2, keepdims=True) - a12 = x.sum((-1, -2), keepdims=True) - - avg = a1 * a2 - avg.div_(a12) # in-place to reduce memory - normalized = x - avg - return normalized diff --git a/multimolecule/module/model.py b/multimolecule/module/model.py new file mode 100644 index 00000000..256783be --- /dev/null +++ b/multimolecule/module/model.py @@ -0,0 +1,89 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import FlatDict +from danling import NestedTensor +from torch import Tensor, nn + +from .backbones import BackboneRegistry +from .heads import HeadRegistry +from .necks import NeckRegistry +from .registry import ModelRegistry + + +@ModelRegistry.register(default=True) +class MultiMoleculeModel(nn.Module): + def __init__( + self, + backbone: dict, + heads: dict, + neck: dict | None = None, + max_length: int = 1024, + truncation: bool = False, + ): + super().__init__() + + # Backbone + self.backbone = BackboneRegistry.build(**backbone) + backbone = self.backbone.config + out_channels = self.backbone.out_channels + + # Neck + if neck: + num_discrete = self.backbone.num_discrete + num_continuous = self.backbone.num_continuous + embed_dim = self.backbone.sequence.config.hidden_size + attention_heads = self.backbone.sequence.config.num_attention_heads + neck.update( + { + "num_discrete": num_discrete, + "num_continuous": num_continuous, + "embed_dim": embed_dim, + "attention_heads": attention_heads, + "max_length": max_length, + "truncation": truncation, + } + ) + self.neck = NeckRegistry.build(**neck) + out_channels = self.neck.out_channels + else: + self.neck = None + + # Heads + for head in heads.values(): + if "hidden_size" not in head or head["hidden_size"] is None: + head["hidden_size"] = out_channels + self.heads = nn.ModuleDict({name: HeadRegistry.build(backbone, head) for name, head in heads.items()}) + if any(getattr(h, "requires_attention", False) for h in self.heads.values()): + self.backbone.sequence.config.output_attentions = True + + def forward( + self, + sequence: NestedTensor | Tensor, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + dataset: str | None = None, + **labels: NestedTensor | Tensor, + 
) -> FlatDict: + ret = FlatDict() + output, _ = self.backbone(sequence, discrete, continuous) + if self.neck is not None: + output = self.neck(**output) + for task, label in labels.items(): + ret[task] = self.heads[task](output, input_ids=sequence, labels=label) + return ret diff --git a/multimolecule/module/necks/__init__.py b/multimolecule/module/necks/__init__.py new file mode 100644 index 00000000..e8f1f7e2 --- /dev/null +++ b/multimolecule/module/necks/__init__.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .bert import BERTNeck +from .cat import CatNeck +from .registry import NeckRegistry + +__all__ = ["NeckRegistry", "CatNeck", "BERTNeck"] diff --git a/multimolecule/module/necks/bert.py b/multimolecule/module/necks/bert.py new file mode 100644 index 00000000..1360f0dd --- /dev/null +++ b/multimolecule/module/necks/bert.py @@ -0,0 +1,102 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. 
+ +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import torch +from chanfig import FlatDict +from danling.modules import TransformerEncoder, TransformerEncoderLayer +from torch import Tensor, nn + +from .registry import NeckRegistry + +MAX_LENGTH = 1024 + + +@NeckRegistry.register("bert") +class BERTNeck(nn.Module): + def __init__( # pylint: disable=keyword-arg-before-vararg + self, + num_discrete: int, + num_continuous: int, + embed_dim: int, + attention_heads: int, + num_layers: int = 6, + max_length: int | None = None, + truncation: bool = False, + dropout: float = 0.1, + *args, + **kwargs, + ): + super().__init__() + self.cls_token_dis = nn.Parameter(torch.zeros(embed_dim)) + self.cls_token_con = nn.Parameter(torch.zeros(embed_dim)) + if max_length is None: + if truncation: + max_length = MAX_LENGTH + 1 + num_discrete + 1 + num_continuous + else: + max_length = MAX_LENGTH * 4 + 1 + num_discrete + 1 + num_continuous + self.max_length = max_length + self.pos_embed = nn.Parameter(torch.zeros(1, self.max_length, embed_dim)) + bert_layer = TransformerEncoderLayer( + embed_dim, attention_heads, *args, dropout=dropout, attn_dropout=dropout, ffn_dropout=dropout, **kwargs + ) + self.bert = TransformerEncoder(bert_layer, num_layers) + self.out_channels = embed_dim + nn.init.normal_(self.pos_embed, std=0.02) + nn.init.trunc_normal_(self.cls_token_dis, std=0.2) + nn.init.trunc_normal_(self.cls_token_con, std=0.2) + + def forward( + self, + cls_token: Tensor | None = None, + all_tokens: Tensor | None = None, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + ) -> FlatDict: + ret = FlatDict() + 
if cls_token is not None: + ret["cls_token"] = self._forward(cls_token, discrete, continuous) + if all_tokens is not None: + ret["all_tokens"] = self._forward(all_tokens, discrete, continuous) + return ret + + def _forward( + self, + sequence: Tensor, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + ) -> Tensor: + if sequence is None: + raise ValueError("sequence should not be None.") + if sequence.dim() == 2: + sequence = sequence[:, None] + batch_size, seq_len, _ = sequence.shape + output = sequence + if discrete is not None: + cls_token_dis = self.cls_token_dis.expand(batch_size, 1, -1) + output = torch.cat((output, cls_token_dis, discrete), dim=1) + if continuous is not None: + cls_token_con = self.cls_token_con.expand(batch_size, -1)[:, None] + output = torch.cat((output, cls_token_con, continuous), dim=1) + all_len = output.shape[1] + if all_len > self.pos_embed.shape[1]: + raise ValueError("sequence length is out of range.") + output = output + self.pos_embed[:, 0:all_len, :] + output = self.bert(output)[0][:, 0:seq_len, :] + if seq_len == 1: + output = output.squeeze(1) + return output diff --git a/multimolecule/module/necks/cat.py b/multimolecule/module/necks/cat.py new file mode 100644 index 00000000..d5165a92 --- /dev/null +++ b/multimolecule/module/necks/cat.py @@ -0,0 +1,43 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. 
If not, see . + +from __future__ import annotations + +import torch +from chanfig import FlatDict +from torch import Tensor + +from .registry import NeckRegistry + + +@NeckRegistry.register("cat") +class CatNeck: # pylint: disable=too-few-public-methods + def __init__(self, embed_dim: int): + self.out_channels = embed_dim * 2 + + def __call__( + self, + cls_token: Tensor | None = None, + all_tokens: Tensor | None = None, + discrete: Tensor | None = None, + continuous: Tensor | None = None, + ) -> FlatDict: + ret = FlatDict() + if cls_token is not None: + ret.cls_token = torch.cat(tuple(i for i in (cls_token, discrete, continuous) if i is not None), -1) + if all_tokens is not None: + ret.all_tokens = torch.cat(tuple(i for i in (all_tokens, discrete, continuous) if i is not None), -1) + return ret diff --git a/multimolecule/module/necks/registry.py b/multimolecule/module/necks/registry.py new file mode 100644 index 00000000..c024227c --- /dev/null +++ b/multimolecule/module/necks/registry.py @@ -0,0 +1,21 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
+ +from __future__ import annotations + +from chanfig import Registry + +NeckRegistry = Registry() diff --git a/multimolecule/module/registry.py b/multimolecule/module/registry.py new file mode 100644 index 00000000..b0332463 --- /dev/null +++ b/multimolecule/module/registry.py @@ -0,0 +1,35 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +from chanfig import Registry as Registry_ +from torch import nn + +from .backbones import BackboneRegistry +from .backbones.sequences import SequenceRegistry +from .heads import HeadRegistry +from .necks import NeckRegistry + + +class Registry(Registry_): # pylint: disable=too-few-public-methods + def build(self, *args, **kwargs) -> nn.Module: + return super().build(*args, **kwargs) + + +ModelRegistry = Registry() + +__all__ = ["ModelRegistry", "BackboneRegistry", "SequenceRegistry", "NeckRegistry", "HeadRegistry"] diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md new file mode 100644 index 00000000..bb1000ad --- /dev/null +++ b/multimolecule/runners/README.md @@ -0,0 +1,9 @@ +--- +authors: + - Zhiyuan Chen +date: 2024-05-04 +--- + +# runners + +`runners` provide an easy-to-use interface for running experiments. 
diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py new file mode 100644 index 00000000..70fa4076 --- /dev/null +++ b/multimolecule/runners/__init__.py @@ -0,0 +1,20 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from .config import MultiMoleculeConfig +from .runner import MultiMoleculeRunner + +__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"] diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py new file mode 100644 index 00000000..79881213 --- /dev/null +++ b/multimolecule/runners/base_runner.py @@ -0,0 +1,286 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . 
+ +import math +import os +from functools import cached_property, partial +from typing import Any, Tuple +from warnings import warn + +import danling as dl +import torch +from art import text2art +from chanfig import NestedDict +from danling import MultiTaskMetrics +from datasets import disable_progress_bars, get_dataset_split_names +from torch import nn, optim +from torch.nn import functional as F +from torch.utils import data +from transformers import AutoTokenizer + +from multimolecule import defaults +from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler +from multimolecule.module import HeadConfig, ModelRegistry + +from .config import MultiMoleculeConfig +from .metrics import MetricRegistry + +disable_progress_bars() + + +class BaseRunner(dl.BaseRunner): + + all_datasets: NestedDict + + def __init__(self, config: MultiMoleculeConfig): + if config.art: + print(text2art("MultiMolecule", "rand-large")) + super().__init__(config) + self.name = config.name + self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained) + self.build_datasets() + self.build_dataloaders() + self.model = ModelRegistry.build(**self.network) + if self.config.get("checkpoint"): + ckpt = dl.load(self.config.checkpoint) + model = ckpt.get("model", ckpt) + parameters = self.model.load_state_dict(model, strict=False) + if parameters.missing_keys: + raise ValueError(f"Missing keys in model: {parameters.missing_keys}") + if parameters.unexpected_keys: + warn(f"Unexpected keys in model: {parameters.unexpected_keys}") + self.model = self.model.to(self.device) + if self.distributed: + self.model = nn.parallel.DistributedDataParallel( + self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True + ) + if self.config.training: + if self.config.optim and not (self.config.platform == "deepspeed" and self.config.deepspeed.optimizer): + self.optimizer = getattr(optim, self.config.optim.pop("name"))( + 
params=self.model.parameters(), **self.config.optim + ) + if self.config.sched and not (self.config.platform == "deepspeed" and self.config.deepspeed.scheduler): + self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched) + self.metrics = self.build_metrics() + + def __post_init__(self): + super().__post_init__() + self.yaml(os.path.join(self.dir, "trainer.yaml")) + print(self) + print(self.get_dataset_lengths()) + + def train_step(self, data) -> Tuple[Any, torch.Tensor]: + with self.autocast(), self.accumulate(): + pred = self.model(**data) + loss = self.loss_fn(pred, data) + self.advance(loss) + self.metric_fn(pred, data) + return pred, loss + + def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]: + pred = self.model(**data) + loss = self.loss_fn(pred, data) + self.metric_fn(pred, data) + return pred, loss + + def loss_fn(self, pred, data): + if self.balance == "rlw": + loss = torch.stack([p["loss"] for p in pred.values()]) + weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1) + return loss.T @ weight + if self.balance == "gls": + return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred)) + if self.balance != "ew": + warn(f"Unknown balance method {self.balance}, using equal weighting.") + return sum(p["loss"] for p in pred.values()) / len(pred) + + def metric_fn(self, pred, data): + metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics + metric.update({t: (p["logits"], data[t]) for t, p in pred.items()}) + + @cached_property + def tasks(self): + if not self.datasets: + raise ValueError("No datasets found") + if "train" in self.datasets: + return self.datasets.train.tasks + return next(iter(self.datasets.values())).tasks + + @cached_property + def dataset_tasks(self): + if not self.datasets: + raise ValueError("No datasets found") + dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values())) + tasks = 
self.tasks + dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks + for dataset in self.datasets.values(): + if isinstance(dataset, MultiTaskDataset): + for dataset_name, tasks_ in dataset.dataset_tasks.items(): + for task_name, task_ in tasks_.items(): + if task_name not in tasks: + raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined") + task = tasks[task_name] + if task != task_: + warn( + f"Task {task_name} of dataset {dataset_name} has different configurations " + "compared to training data, using training configuration.\n" + "This may lead to unexpected behavior.", + ) + if dataset_name not in dataset_tasks: + dataset_tasks[dataset_name] = NestedDict() + if task_name not in dataset_tasks[dataset_name]: + dataset_tasks[dataset_name][task_name] = task + else: + for task_name, task_ in dataset.tasks.items(): + if task_name not in tasks: + raise ValueError(f"Task {task_name} is not defined") + task = tasks[task_name] + if task != task_: + warn( + f"Task {task_name} has different configurations " + "compared to training data, using training configuration.\n" + "This may lead to unexpected behavior.", + ) + if task_name not in dataset_tasks: + dataset_tasks[task_name] = task + return dataset_tasks + + @cached_property + def network(self): + heads = { + name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level) + for name, task in self.tasks.items() + } + if "heads" not in self.config.network: + self.config.network.heads = NestedDict(heads) + else: + self.config.network.heads.merge(heads, overwrite=False) + return self.config.network + + def build_datasets(self): + if "data" in self.config: + self.datasets = self.all_datasets = self._build_dataset(self.config.data) + return + if "datas" in self.config: + self.all_datasets = NestedDict( + {name: self._build_dataset(config, name) for name, config in self.config.datas.items()} + ) + datasets = { + subkey: {key: 
subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict} + for subkey in {k for v in self.all_datasets.values() for k in v} + } + self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()}) + return + raise ValueError("No data configuration found") + + def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict: + name = name or config.root + print(f"Building dataset {name}") + dataset = NestedDict() + ignored_key = defaults.DATASET_SPLITS + ["root"] + dataset_factory = partial( + Dataset, + tokenizer=self.tokenizer, + **{k: v for k, v in config.items() if k not in ignored_key}, + ) + if os.path.isdir(config.root): + if "train" in config: + dataset.train = dataset_factory(os.path.join(config.root, config.train), split="train") + if "validation" in config: + dataset.validation = dataset_factory(os.path.join(config.root, config.validation), split="validation") + if "test" in config: + dataset.test = dataset_factory(os.path.join(config.root, config.test), split="test") + if "evaluation" in config: + dataset.evaluation = dataset_factory(os.path.join(config.root, config.evaluation), split="evaluation") + if "inference" in config: + dataset.inference = dataset_factory(os.path.join(config.root, config.inference), split="inference") + else: + splits = get_dataset_split_names(config.root) + existing_splits = {k for k in defaults.DATASET_SPLITS if config.get(k) is not None} + if not existing_splits: + if "train" in splits: + config.train = "train" + if "validation" in splits: + config.validation = "validation" + if "test" in splits: + config.test = "test" + if config.get("train") is not None: + dataset.train = dataset_factory(config.root, split="train") + if config.get("validation") is not None: + dataset.validation = dataset_factory(config.root, split="validation") + if config.get("test") is not None: + dataset.test = dataset_factory(config.root, split="test") + if 
config.get("evaluation") is not None: + dataset.evaluation = dataset_factory(config.root, split=config.get("evaluation")) + if config.get("inference") is not None: + dataset.inference = dataset_factory(config.root, split=config.get("inference")) + if not dataset: + raise ValueError(f"No datasets built. This is likely due to missing data paths in {config}.") + return dataset + + def build_dataloaders(self): + datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders} + default_kwargs = self.config.get("dataloader", NestedDict()) + dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs}) + for k, d in datasets.items(): + dataloader_kwargs.setdefault(k, NestedDict()) + dataloader_kwargs[k].merge(default_kwargs, overwrite=False) + batch_size = dataloader_kwargs[k].pop("batch_size") + shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True)) + drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True)) + if isinstance(d, MultiTaskDataset): + batch_sampler = ( + DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last) + if self.distributed + else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last) + ) + else: + sampler = ( + data.distributed.DistributedSampler(d, shuffle=shuffle) + if self.distributed + else data.RandomSampler(d) if shuffle else data.SequentialSampler(d) + ) + batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last) + self.dataloaders[k] = data.DataLoader( + d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k] + ) + + def build_metrics(self) -> MultiTaskMetrics: + return MultiTaskMetrics( + { + name: MetricRegistry.build(type=task.type, num_labels=task.num_labels) + for name, task in self.dataset_tasks.all_items() + } + ) + + def collate_fn(self, batch): + return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()} + + def 
get_dataset_lengths(self) -> str: + repr = "dataset lengths:\n" + longest_name = max(len(name) for name in self.all_datasets.keys()) + for name, dataset in self.all_datasets.items(): + if isinstance(dataset, NestedDict): + repr += f"{name}:" + if len(name) < longest_name: + repr += " " * (longest_name - len(name)) + repr += "\t\t" + for split, d in dataset.items(): + repr += f" {split}: {len(d)}\t" + else: + repr += f"{name}: {len(dataset)}\t" + repr += "\n" + return repr diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py new file mode 100644 index 00000000..08b992d1 --- /dev/null +++ b/multimolecule/runners/config.py @@ -0,0 +1,83 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. + +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. + +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import os +from pathlib import Path +from typing import List + +from chanfig import Config +from transformers import PretrainedConfig + + +class DataConfig(Config): + root: str = "." 
+ train: str | None + validation: str | None + test: str | None + feature_cols: List | None = None + label_cols: List | None = None + truncation: bool = True + + +class OptimConfig(Config): + name: str = "AdamW" + lr: float = 1e-3 + weight_decay: float = 1e-2 + + +class MultiMoleculeConfig(Config): + + balance: str = "ew" + platform: str = "torch" + training: bool = True + + pretrained: str + use_pretrained: bool = True + transformers: PretrainedConfig + epoch_end: int = 20 + + data: DataConfig + + tensorboard: bool = True + save_interval: int = 10 + seed: int = 1016 + art: bool = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.datas = Config(default_factory=DataConfig) + self.dataloader.batch_size = 32 + self.optim = OptimConfig() + self.sched.final_lr = 0 + + def post(self): + if "data" in self: + if self.datas: + raise ValueError("Only one of `data` or `datas` can be specified, but not both") + del self.datas + self["network.backbone.sequence.name"] = self.pretrained + self["network.backbone.sequence.use_pretrained"] = self.use_pretrained + pretrained = self.pretrained + if os.path.exists(self.pretrained): + path = Path(pretrained) + if os.path.isfile(pretrained): + pretrained = str(path.relative_to(path.parents[1]).with_suffix("")) + else: + pretrained = path.stem + + self.name = f"{pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}" diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py new file mode 100644 index 00000000..da584cbc --- /dev/null +++ b/multimolecule/runners/metrics.py @@ -0,0 +1,37 @@ +# MultiMolecule +# Copyright (C) 2024-Present MultiMolecule + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# any later version. 
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+# Keeps the `int | None` annotation below from being evaluated at definition
+# time, which raises TypeError on Python < 3.10.
+from __future__ import annotations
+
+from chanfig import Registry as Registry_
+from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics
+
+
+class Registry(Registry_):
+    """Registry of metric factories that maps `num_labels` to each factory's keyword."""
+
+    def build(self, type, num_labels: int | None = None, **kwargs):
+        """Build the metrics registered under `type`.
+
+        `num_labels` is forwarded under a task-specific keyword:
+        `num_labels` for "multilabel", `num_classes` for "multiclass" and
+        `num_outputs` for "regression". For any other `type` it is not
+        forwarded at all.
+
+        Args:
+            type: Registered name of the metric factory. The name shadows the
+                builtin but is kept so keyword callers keep working.
+            num_labels: Size passed to the factory for the three types above.
+            **kwargs: Passed through to the factory unchanged.
+
+        Returns:
+            Whatever `self.init` returns for the looked-up factory.
+        """
+        if type == "multilabel":
+            return self.init(self.lookup(type), num_labels=num_labels, **kwargs)
+        if type == "multiclass":
+            return self.init(self.lookup(type), num_classes=num_labels, **kwargs)
+        if type == "regression":
+            return self.init(self.lookup(type), num_outputs=num_labels, **kwargs)
+        return self.init(self.lookup(type), **kwargs)
+
+
+# Module-level registry, populated with the four danling metric factories.
+MetricRegistry = Registry(key="type")
+MetricRegistry.register(binary_metrics, "binary")
+MetricRegistry.register(multiclass_metrics, "multiclass")
+MetricRegistry.register(multilabel_metrics, "multilabel")
+MetricRegistry.register(regression_metrics, "regression")
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
new file mode 100644
index 00000000..97cab4e7
--- /dev/null
+++ b/multimolecule/runners/runner.py
@@ -0,0 +1,42 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+import danling as dl
+
+from .base_runner import BaseRunner
+
+
+class MultiMoleculeRunner(type):
+    """Factory that returns the runner matching `config.platform`.
+
+    `MultiMoleculeRunner(config)` yields a `TorchRunner`, `DeepSpeedRunner` or
+    `AccelerateRunner` instance, never an instance of this class.
+    """
+
+    def __new__(cls, config):
+        """Pick the runner class for the configured platform and instantiate it.
+
+        A missing `platform` key is treated as "torch".
+
+        Raises:
+            ValueError: If the platform is not torch, deepspeed or accelerate.
+        """
+        if config.get("platform", "torch") == "torch":
+            runner_cls = TorchRunner
+        elif config.platform == "deepspeed":
+            runner_cls = DeepSpeedRunner
+        elif config.platform == "accelerate":
+            runner_cls = AccelerateRunner
+        else:
+            raise ValueError(f"Unsupported platform: {config.platform}")
+        return runner_cls(config)
+
+
+class TorchRunner(BaseRunner, dl.TorchRunner):
+    """`BaseRunner` combined with danling's `TorchRunner`."""
+
+
+class DeepSpeedRunner(BaseRunner, dl.DeepSpeedRunner):
+    """`BaseRunner` combined with danling's `DeepSpeedRunner`."""
+
+
+class AccelerateRunner(BaseRunner, dl.AccelerateRunner):
+    """`BaseRunner` combined with danling's `AccelerateRunner`."""
diff --git a/multimolecule/tasks/task.py b/multimolecule/tasks/task.py index e2473ab0..5d435f83 100644 --- a/multimolecule/tasks/task.py +++ b/multimolecule/tasks/task.py @@ -34,9 +34,8 @@ class TaskType(StrEnum): class TaskLevel(StrEnum): Sequence = auto() - Nucleotide = auto() + Token = auto() Contact = auto() - # Token = auto() @dataclass
diff --git a/multimolecule/train.py b/multimolecule/train.py
new file mode 100644
index 00000000..f146bc97
--- /dev/null
+++ b/multimolecule/train.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from .apis import train
+
+# Run the `train` API when this file is executed as a script. The relative
+# import above needs package context, so the module form
+# (`python -m multimolecule.train`) is presumably the intended invocation -- confirm.
+if __name__ == "__main__":
+    train()
diff --git a/pyproject.toml b/pyproject.toml index 972b1d8a..ee4c525d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,9 +45,11 @@ dynamic = [ ] dependencies = [ "accelerate", + "art", "chanfig>=0.0.105", - "danling>=0.3.11", + "danling>=0.4.0b1", "datasets", + 'StrEnum; python_version < "3.11"', "torch", "transformers", ] diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 5ddfb3c4..318e289e 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -126,7 +126,7 @@ def test_spliceai(self, preprocess: bool): feature_cols=feature_cols, label_cols=label_cols, ) - task = Task(type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1) + task = Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1) elem = dataset[0] assert isinstance(elem["sequence"], torch.LongTensor) assert isinstance(elem["splice_ai"], torch.LongTensor) @@ -175,20 +175,18 @@ def test_rna_task_recognition_json(self): assert dataset.tasks["sequence_regression"] == Task( type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=1 ) - assert dataset.tasks["nucleotide_binary"] == Task( - type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1 - ) + assert dataset.tasks["nucleotide_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1) assert dataset.tasks["nucleotide_multiclass"] == Task( - type=TaskType.MultiClass, level=TaskLevel.Nucleotide, num_labels=5 + type=TaskType.MultiClass, level=TaskLevel.Token, num_labels=5 ) assert dataset.tasks["nucleotide_multilabel"] == Task( - type=TaskType.MultiLabel, level=TaskLevel.Nucleotide, num_labels=5 + type=TaskType.MultiLabel, level=TaskLevel.Token, num_labels=5 ) assert dataset.tasks["nucleotide_multireg"] == Task( - type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=5 + type=TaskType.Regression, level=TaskLevel.Token, num_labels=5 ) assert dataset.tasks["nucleotide_regression"] == Task(
- type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=1 + type=TaskType.Regression, level=TaskLevel.Token, num_labels=1 ) assert dataset.tasks["contact_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Contact, num_labels=1) assert dataset.tasks["contact_multiclass"] == Task(