diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt
index 44c7e9f5..467e5c38 100644
--- a/.codespell-whitelist.txt
+++ b/.codespell-whitelist.txt
@@ -1,3 +1,4 @@
+datas
ser
marz
manuel
diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md
new file mode 100644
index 00000000..054c6205
--- /dev/null
+++ b/docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiTask
+
+::: multimolecule.data.multitask
diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md
new file mode 100644
index 00000000..8e188199
--- /dev/null
+++ b/docs/docs/runners/config.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeConfig
+
+::: multimolecule.runners.MultiMoleculeConfig
diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md
new file mode 100644
index 00000000..75a0f528
--- /dev/null
+++ b/docs/docs/runners/index.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+--8<-- "multimolecule/runners/README.md:8:"
diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md
new file mode 100644
index 00000000..2c93ce1b
--- /dev/null
+++ b/docs/docs/runners/runner.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeRunner
+
+::: multimolecule.runners.MultiMoleculeRunner
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 93d53a45..9324d9c3 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule
nav:
- index.md
+ - runners:
+ - runners/index.md
+ - MultiMoleculeRunner: runners/runner.md
+ - MultiMoleculeConfig: runners/config.md
- data:
- data/index.md
- Dataset: data/dataset.md
+ - multitask: data/multitask.md
- datasets:
- datasets/index.md
- DNA:
diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py
index 240e9fcc..c9ee3a87 100644
--- a/multimolecule/__init__.py
+++ b/multimolecule/__init__.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from .apis import evaluate, infer, train
from .data import Dataset
from .models import (
AutoModelForContactPrediction,
@@ -111,30 +112,33 @@
HeadConfig,
HeadRegistry,
HeadTransformRegistry,
- HeadTransformRegistryHF,
IdentityTransform,
LinearTransform,
MaskedLMHead,
MaskedLMHeadConfig,
NonLinearTransform,
PositionEmbeddingRegistry,
- PositionEmbeddingRegistryHF,
PredictionHead,
RotaryEmbedding,
SequencePredictionHead,
SinusoidalEmbedding,
- TokenHeadRegistryHF,
TokenKMerHead,
TokenPredictionHead,
)
+from .runners import MultiMoleculeConfig, MultiMoleculeRunner
from .tasks import Task, TaskLevel, TaskType
from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
from .utils import count_parameters
__all__ = [
+ "train",
+ "evaluate",
+ "infer",
"modeling_auto",
"modeling_outputs",
"Dataset",
+ "MultiMoleculeConfig",
+ "MultiMoleculeRunner",
"PreTrainedConfig",
"HeadConfig",
"BaseHeadConfig",
@@ -233,21 +237,15 @@
"HeadRegistry",
"PredictionHead",
"SequencePredictionHead",
- "TokenHeadRegistryHF",
"TokenPredictionHead",
"TokenKMerHead",
- "NucleotideHeadRegistryHF",
- "NucleotidePredictionHead",
- "NucleotideKMerHead",
"ContactPredictionHead",
"MaskedLMHead",
"HeadTransformRegistry",
- "HeadTransformRegistryHF",
"LinearTransform",
"NonLinearTransform",
"IdentityTransform",
"PositionEmbeddingRegistry",
- "PositionEmbeddingRegistryHF",
"RotaryEmbedding",
"SinusoidalEmbedding",
"Criterion",
diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py
new file mode 100644
index 00000000..8e3e5b3c
--- /dev/null
+++ b/multimolecule/apis/__init__.py
@@ -0,0 +1,19 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .run import evaluate, infer, train
+
+__all__ = ["train", "evaluate", "infer"]
diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py
new file mode 100644
index 00000000..1fdb7666
--- /dev/null
+++ b/multimolecule/apis/run.py
@@ -0,0 +1,115 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+# mypy: disable-error-code="attr-defined"
+
+import atexit
+import os
+import warnings
+from typing import Type
+
+import danling as dl
+import torch
+
+from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner
+
+try:
+ import nni
+except ImportError:
+ nni = None
+
+
+def train(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ config.training = True
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if config.get("nni", False):
+ if nni is None:
+ raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
+ config.merge(nni.get_next_parameter())
+ with dl.debug(config.get("debug", False)):
+ runner = runner_cls(config)
+ atexit.register(runner.print_result)
+ atexit.register(runner.save_result)
+ atexit.register(runner.save_checkpoint)
+ result = runner.train()
+ return result
+
+
+def evaluate(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig.empty()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ config.training = False
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+ raise RuntimeError("Please specify `checkpoint` to run evaluate")
+ for name, data in config.datas.items():
+ if "evaluation" not in data or not isinstance(data.evaluate, str):
+ raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}")
+ runner = runner_cls(config)
+ result = runner.evaluate_epoch("evaluation")
+ print(result)
+ return result
+
+
+def infer(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig.empty()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ config.training = False
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+ raise RuntimeError("Please specify `checkpoint` to run infer.")
+ for name, data in config.datas.items():
+ if "inference" not in data or not isinstance(data.inference, str):
+ raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}")
+ if "result_path" not in config or not isinstance(config.result_path, str):
+ config.result_path = os.path.join(os.getcwd(), "result.json")
+ warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
+ runner = runner_cls(config)
+ result = runner.infer()
+ runner.save(result, config.result_path)
+ return result
diff --git a/multimolecule/apis/stat.py b/multimolecule/apis/stat.py
new file mode 100644
index 00000000..5e525d55
--- /dev/null
+++ b/multimolecule/apis/stat.py
@@ -0,0 +1,99 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import os
+import shutil
+from statistics import mean
+from typing import List
+
+import chanfig
+import pandas as pd
+from chanfig import NestedDict
+from tqdm import tqdm
+
+
+class Result(NestedDict):
+ pretrained: str
+ id: str
+ seed: int
+ epoch: int
+ validation: NestedDict
+ test: NestedDict
+
+
+def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]:
+ results = []
+ for root, _, files in tqdm(os.walk(experiment_root)):
+ if "run.log" in files:
+ if "best.json" not in files:
+ if remove_empty:
+ shutil.rmtree(root)
+ continue
+ best = NestedDict.from_json(os.path.join(root, "best.json"))
+ if "index" not in best:
+ if remove_empty:
+ shutil.rmtree(root)
+ continue
+ config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
+ pretrained = config.pretrained.split("/")[-1]
+ seed = config.seed
+ pretrained, seed = "", 1
+ result = Result(id=best.id, pretrained=pretrained, seed=seed)
+ result.validation = NestedDict(
+ {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()}
+ )
+ result.test = NestedDict(
+ {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()}
+ )
+ result.epoch = best.index
+ result.pop("validation.time", None)
+ result.pop("test.time", None)
+ result.pop("validation.loss", None)
+ result.pop("test.loss", None)
+ result.pop("validation.lr", None)
+ result.pop("test.lr", None)
+ results.append(result)
+ # Remove empty directories, perform twice to remove all empty directories
+ if remove_empty:
+ for root, dirs, files in os.walk(experiment_root):
+ if not files and not dirs:
+ os.rmdir(root)
+ for root, dirs, files in os.walk(experiment_root):
+ if not files and not dirs:
+ os.rmdir(root)
+ results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
+ return results
+
+
+def write_result_stat(results: List[Result], path: str):
+ results = [dict(result.all_items()) for result in results] # type: ignore[misc]
+ df = pd.DataFrame.from_dict(results)
+ df.insert(len(df.keys()) - 1, "comment", "")
+ df.fillna("")
+ df.to_csv(path, index=False)
+
+
+class Config(chanfig.Config):
+ experiment_root: str = "experiments"
+ out_path: str = "result.csv"
+
+
+if __name__ == "__main__":
+ config = Config().parse()
+ result_stat = get_result_stat(config.experiment_root)
+ if not len(result_stat) > 0:
+ raise ValueError("No results found")
+ write_result_stat(result_stat, config.out_path)
diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py
index 62196c10..f2366d77 100644
--- a/multimolecule/data/__init__.py
+++ b/multimolecule/data/__init__.py
@@ -15,6 +15,14 @@
# along with this program. If not, see .
from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
from .utils import no_collate
-__all__ = ["Dataset", "no_collate"]
+__all__ = [
+ "Dataset",
+ "PandasDataset",
+ "MultiTaskDataset",
+ "MultiTaskSampler",
+ "DistributedMultiTaskSampler",
+ "no_collate",
+]
diff --git a/multimolecule/data/dataset.py b/multimolecule/data/dataset.py
index 54565349..2b6fae49 100644
--- a/multimolecule/data/dataset.py
+++ b/multimolecule/data/dataset.py
@@ -149,8 +149,9 @@ def __init__(
fingerprint: str | None = None,
ignored_cols: List[str] | None = None,
):
+ self._tasks = NestedDict()
if tasks is not None:
- self._tasks = NestedDict(tasks)
+ self.tasks = tasks
if discrete_map is not None:
self._discrete_map = discrete_map
arrow_table = self.build_table(
@@ -187,13 +188,13 @@ def build_table(
data = dl.load_pandas(data)
if isinstance(data, DataFrame):
data = data.loc[:, ~data.columns.str.contains("^Unnamed")]
- data = pa.Table.from_pandas(data)
+ data = pa.Table.from_pandas(data, preserve_index=False)
elif isinstance(data, dict):
data = pa.Table.from_pydict(data)
elif isinstance(data, list):
data = pa.Table.from_pylist(data)
elif isinstance(data, DataFrame):
- data = pa.Table.from_pandas(data)
+ data = pa.Table.from_pandas(data, preserve_index=False)
if feature_cols is not None and label_cols is not None:
data = data.select(feature_cols + label_cols)
data = self.process_nan(data, nan_process=nan_process, fill_value=fill_value)
@@ -250,6 +251,7 @@ def post(
self.column_names_map = column_names_map
if self.column_names_map:
self.rename_columns(self.column_names_map)
+ self.infer_tasks()
if self.preprocess:
self.update(self.map(self.tokenization))
@@ -258,7 +260,7 @@ def post(
if self.discrete_map:
self.update(self.map(self.map_discrete))
fn_kwargs = {
- "columns": [name for name, task in self.tasks.items() if task.level in ["nucleotide", "contact"]],
+ "columns": [name for name, task in self.tasks.items() if task.level in ["token", "contact"]],
"max_seq_length": self.max_seq_length - self.seq_length_offset,
}
if self.truncation and 0 < self.max_seq_length < 2**32:
@@ -297,20 +299,23 @@ def collate(self, col: str, data: Any) -> Tensor | NestedTensor | None:
except ValueError:
return NestedTensor(data)
- def infer_tasks(self, tasks: Mapping | None = None, sequence_col: str | None = None) -> NestedDict:
- self._tasks = tasks or NestedDict()
+ def infer_tasks(self, sequence_col: str | None = None) -> NestedDict:
for col in self.label_cols:
- if col not in self.tasks:
- if col in self.secondary_structure_cols:
- task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
- self._tasks[col] = task # type: ignore[index]
- warn(
- f"Secondary structure columns are assumed to be {task}."
- " Please explicitly specify the task if this is not the case."
- )
- else:
- self._tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index]
- return self._tasks
+ if col in self.tasks:
+ continue
+ if col in self.secondary_structure_cols:
+ task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
+ self.tasks[col] = task # type: ignore[index]
+ warn(
+ f"Secondary structure columns are assumed to be {task}. "
+ "Please explicitly specify the task if this is not the case."
+ )
+ else:
+ try:
+ self.tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index]
+ except ValueError:
+ raise ValueError(f"Unable to infer task for column {col}.")
+ return self.tasks
def infer_task(self, label_col: str, sequence_col: str | None = None) -> Task:
if sequence_col is None:
@@ -404,7 +409,7 @@ def rename_columns(self, column_mapping: Mapping[str, str], new_fingerprint: str
self._label_cols = [column_mapping.get(i, i) for i in self.label_cols]
self._sequence_cols = [column_mapping.get(i, i) for i in self.sequence_cols]
self._secondary_structure_cols = [column_mapping.get(i, i) for i in self.secondary_structure_cols]
- self._tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()}
+ self.tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()}
return self
def rename_column(
@@ -418,7 +423,7 @@ def rename_column(
self._secondary_structure_cols = [
new_column_name if i == original_column_name else i for i in self.secondary_structure_cols
]
- self._tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()}
+ self.tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()}
return self
def process_nan(self, data: Table, nan_process: str | None, fill_value: str | int | float = 0) -> Table:
@@ -470,9 +475,18 @@ def secondary_structure_cols(self) -> List:
@property
def tasks(self) -> NestedDict:
if not hasattr(self, "_tasks"):
+ self._tasks = NestedDict()
return self.infer_tasks()
return self._tasks
+ @tasks.setter
+ def tasks(self, tasks: Mapping):
+ self._tasks = NestedDict()
+ for name, task in tasks.items():
+ if not isinstance(task, Task):
+ task = Task(**task)
+ self._tasks[name] = task
+
@property
def discrete_map(self) -> Mapping:
if not hasattr(self, "_discrete_map"):
diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py
new file mode 100644
index 00000000..7c20e829
--- /dev/null
+++ b/multimolecule/data/multitask.py
@@ -0,0 +1,246 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from bisect import bisect_right
+from collections.abc import Iterator, Mapping, Sequence
+from copy import deepcopy
+from random import choices
+
+import torch
+from chanfig import NestedDict
+from torch import distributed as dist
+from torch.utils import data
+
+from .dataset import Dataset
+
+
+class MultiTaskDataset(data.ConcatDataset):
+
+ datasets: Mapping
+ dataset_keys: Sequence[str]
+ dataset_values: Sequence[Dataset]
+
+ def __init__(self, datasets: Mapping) -> None:
+ for key, dataset in datasets.items():
+ if not isinstance(dataset, Dataset):
+ raise TypeError(f"Dataset {key} should be an instance of Dataset")
+ self.datasets = datasets
+ if not len(self.datasets) > 0:
+ raise ValueError("MultiTaskDataset should contain at least one dataset")
+ self.dataset_keys, self.dataset_values = zip(*self.datasets.items())
+ self.cumulative_sizes = self.cumsum(self.dataset_values)
+
+ def __getitems__(self, key: Sequence[int]) -> Mapping:
+ dataset_idx = bisect_right(self.cumulative_sizes, key[0])
+ if dataset_idx == 0:
+ sample_idx = key
+ else:
+ sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key]
+ batch = self.dataset_values[dataset_idx][sample_idx]
+ batch["dataset"] = self.dataset_keys[dataset_idx]
+ return batch
+
+ @property
+ def tasks(self) -> NestedDict:
+ tasks = NestedDict()
+ for dataset in self.dataset_values:
+ for n, t in dataset.tasks.items():
+ if n not in tasks:
+ tasks[n] = t
+ elif tasks[n] != t:
+ raise ValueError(f"Task {n} has different configurations across datasets")
+ return tasks
+
+ @property
+ def dataset_tasks(self) -> NestedDict:
+ return NestedDict({k: v.tasks for k, v in self.datasets.items()})
+
+ def __repr__(self) -> str:
+ return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"
+
+
+class MultiTaskSampler(data.BatchSampler):
+ r"""
+ Ensure all items in a batch comes from the same dataset.
+
+ Arguments:
+ sampler (Sampler): Base sampler.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
+ """
+
+ datasets: Sequence[Dataset]
+
+ def __init__( # pylint: disable=super-init-not-called
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] | None = None,
+ weights: list[int] | None = None,
+ ) -> None:
+ self.datasets = dataset.dataset_values
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+ if sampler_cls is None:
+ sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
+ self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore
+ self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore
+ self.cumulative_sizes = dataset.cumulative_sizes
+ self.num_datasets = len(self.datasets)
+ self.weights = weights if weights is not None else self.dataset_sizes
+
+ def __iter__(self):
+ sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+ sampler_weights = deepcopy(self.weights)
+ sampler_idx = 0
+ # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
+ if self.drop_last:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ try:
+ batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+ yield batch
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ else:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ batch = [0] * self.batch_size
+ idx_in_batch = 0
+ try:
+ for _ in range(self.batch_size):
+ batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+ idx_in_batch += 1
+ yield batch
+ idx_in_batch = 0 # noqa: SIM113
+ batch = [0] * self.batch_size
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ if idx_in_batch > 0:
+ yield batch[:idx_in_batch]
+
+ def __len__(self):
+ batch_size = self.batch_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+
+class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods
+ r"""
+ Distributed version of MultiTaskSampler, which ensures that all GPUs sample data from the
+ same sub-dataset in each step without requiring additional communication.
+ The dataset selection is based on a random seed mechanism that is synchronized across epochs.
+
+ See Also:
+ [MultiTaskSampler][MultiTaskSampler]
+ """
+
+ def __init__(
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] = data.RandomSampler,
+ weights: list[int] | None = None,
+ seed: int = 0,
+ ) -> None:
+ super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
+ self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]
+ self.seed = seed
+ self.epoch = 0
+
+ def set_epoch(self, epoch: int):
+ """
+ Sets the epoch for deterministic shuffling.
+ """
+ self.epoch = epoch
+ for sampler in self.samplers:
+ sampler.set_epoch(epoch)
+
+ def _get_sampler_idx(self, high: int) -> int:
+ """
+ Determines which sampler (i.e., sub-dataset) to use based on the seed and epoch.
+ """
+ g = torch.Generator()
+ g.manual_seed(self.seed + self.epoch)
+ sampler_idx = torch.randint(low=0, high=high, size=(1,), generator=g).item()
+ return sampler_idx
+
+ def __iter__(self) -> Iterator:
+ sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+ sampler_weights = deepcopy(self.weights)
+
+ if self.drop_last:
+ while sampler_iters:
+ # Sample the same sub-dataset across all GPUs using the seeded index
+ sampler_idx = self._get_sampler_idx(len(sampler_iters))
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ try:
+ batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+ yield batch
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ else:
+ while sampler_iters:
+ # Sample the same sub-dataset across all GPUs using the seeded index
+ sampler_idx = self._get_sampler_idx(len(sampler_iters))
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ batch = [0] * self.batch_size
+ idx_in_batch = 0
+ try:
+ for _ in range(self.batch_size):
+ batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+ idx_in_batch += 1
+ yield batch
+ idx_in_batch = 0 # noqa: SIM113
+ batch = [0] * self.batch_size
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ if idx_in_batch > 0:
+ yield batch[:idx_in_batch]
+
+ def __len__(self) -> int:
+ batch_size = self.batch_size * self.world_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+ @property
+ def world_size(self) -> int:
+ r"""Return the number of processes in the current process group."""
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size()
+ return 1
diff --git a/multimolecule/data/utils.py b/multimolecule/data/utils.py
index 85bc4423..1afddd1b 100644
--- a/multimolecule/data/utils.py
+++ b/multimolecule/data/utils.py
@@ -60,7 +60,7 @@ def infer_task(
level = TaskLevel.Contact
num_labels = len(flattened) // num_contacts
elif len(flattened) % num_tokens == 0:
- level = TaskLevel.Nucleotide
+ level = TaskLevel.Token
num_labels = len(flattened) // num_tokens
elif len(flattened) % num_elem == 0:
level = TaskLevel.Sequence
@@ -86,7 +86,7 @@ def infer_task(
task_type = TaskType.MultiClass if num_labels > 2 else TaskType.Binary
num_labels = 1 if task_type == TaskType.Binary else num_labels
if num_tokens_flattened == num_tokens:
- return Task(task_type, level=TaskLevel.Nucleotide, num_labels=num_labels)
+ return Task(task_type, level=TaskLevel.Token, num_labels=num_labels)
if num_contacts_flattened == num_contacts:
return Task(task_type, level=TaskLevel.Contact, num_labels=num_labels)
return Task(task_type, level=TaskLevel.Sequence, num_labels=num_labels)
@@ -122,7 +122,7 @@ def map_value(value: Any, mapping: dict[str, int] | None) -> Any:
def truncate_value(value: Any, max_seq_length: int, level: int | None = None) -> Any:
- if level == TaskLevel.Nucleotide:
+ if level == TaskLevel.Token:
return value[:max_seq_length]
if level == TaskLevel.Contact:
return [i[:max_seq_length] for i in value[:max_seq_length]]
diff --git a/multimolecule/defaults.py b/multimolecule/defaults.py
index c299ea1a..0a718a60 100644
--- a/multimolecule/defaults.py
+++ b/multimolecule/defaults.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+DATASET_SPLITS = ["train", "validation", "test", "evaluation", "inference"]
ID_COL_NAMES = ["id", "idx", "index"]
SEQUENCE_COL_NAMES = ["input_ids", "sequence", "seq"]
SECONDARY_STRUCTURE_COL_NAMES = ["secondary_structure", "ss"]
diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py
index 66147616..29d99436 100644
--- a/multimolecule/models/__init__.py
+++ b/multimolecule/models/__init__.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from multimolecule.module import HeadConfig
from multimolecule.tokenisers import DnaTokenizer, ProteinTokenizer, RnaTokenizer
from .calm import (
@@ -127,6 +128,7 @@
__all__ = [
"PreTrainedConfig",
+ "HeadConfig",
"DnaTokenizer",
"RnaTokenizer",
"ProteinTokenizer",
diff --git a/multimolecule/models/calm/configuration_calm.py b/multimolecule/models/calm/configuration_calm.py
index c5d73c03..032bda8e 100644
--- a/multimolecule/models/calm/configuration_calm.py
+++ b/multimolecule/models/calm/configuration_calm.py
@@ -127,5 +127,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/calm/modeling_calm.py b/multimolecule/models/calm/modeling_calm.py
index 25c1eba4..c8abdffe 100644
--- a/multimolecule/models/calm/modeling_calm.py
+++ b/multimolecule/models/calm/modeling_calm.py
@@ -270,9 +270,9 @@ class CaLmForSequencePrediction(CaLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
@@ -334,9 +334,9 @@ class CaLmForTokenPrediction(CaLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
@@ -398,9 +398,9 @@ class CaLmForContactPrediction(CaLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py
index 2047d671..ce6f10ea 100644
--- a/multimolecule/models/configuration_utils.py
+++ b/multimolecule/models/configuration_utils.py
@@ -30,7 +30,8 @@ class PreTrainedConfig(PretrainedConfig):
Base class for all model configuration classes.
"""
- head: HeadConfig
+ head: HeadConfig | None
+ num_labels: int = 1
hidden_size: int
@@ -42,7 +43,15 @@ class PreTrainedConfig(PretrainedConfig):
null_token_id: int = 5
def __init__(
- self, pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3, mask_token_id=4, null_token_id=5, **kwargs
+ self,
+ pad_token_id: int = 0,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ unk_token_id: int = 3,
+ mask_token_id: int = 4,
+ null_token_id: int = 5,
+ num_labels: int = 1,
+ **kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
@@ -51,6 +60,7 @@ def __init__(
unk_token_id=unk_token_id,
mask_token_id=mask_token_id,
null_token_id=null_token_id,
+ num_labels=num_labels,
**kwargs,
)
diff --git a/multimolecule/models/ernierna/configuration_ernierna.py b/multimolecule/models/ernierna/configuration_ernierna.py
index 0648bb2d..bfd11d51 100644
--- a/multimolecule/models/ernierna/configuration_ernierna.py
+++ b/multimolecule/models/ernierna/configuration_ernierna.py
@@ -110,5 +110,5 @@ def __init__(
self.pairwise_alpha = pairwise_alpha
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/ernierna/modeling_ernierna.py b/multimolecule/models/ernierna/modeling_ernierna.py
index 563249e7..3f2c96a0 100644
--- a/multimolecule/models/ernierna/modeling_ernierna.py
+++ b/multimolecule/models/ernierna/modeling_ernierna.py
@@ -321,7 +321,7 @@ class ErnieRnaForSequencePrediction(ErnieRnaPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input)
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
"""
def __init__(self, config: ErnieRnaConfig):
@@ -385,9 +385,9 @@ class ErnieRnaForTokenPrediction(ErnieRnaPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: ErnieRnaConfig):
@@ -452,9 +452,9 @@ class ErnieRnaForContactPrediction(ErnieRnaPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: ErnieRnaConfig):
@@ -1183,7 +1183,7 @@ class ErnieRnaContactClassificationHead(nn.Module):
def __init__(self, config: ErnieRnaConfig, head_config: HeadConfig | None = None):
super().__init__()
if head_config is None:
- head_config = config.head
+ head_config = config.head or HeadConfig()
self.config = head_config
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
diff --git a/multimolecule/models/rinalmo/configuration_rinalmo.py b/multimolecule/models/rinalmo/configuration_rinalmo.py
index 5e21725d..1cc963b2 100644
--- a/multimolecule/models/rinalmo/configuration_rinalmo.py
+++ b/multimolecule/models/rinalmo/configuration_rinalmo.py
@@ -125,6 +125,6 @@ def __init__(
self.use_cache = use_cache
self.learnable_beta = learnable_beta
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
self.emb_layer_norm_before = emb_layer_norm_before
diff --git a/multimolecule/models/rinalmo/modeling_rinalmo.py b/multimolecule/models/rinalmo/modeling_rinalmo.py
index b45d2823..d0ac6e8c 100644
--- a/multimolecule/models/rinalmo/modeling_rinalmo.py
+++ b/multimolecule/models/rinalmo/modeling_rinalmo.py
@@ -269,9 +269,9 @@ class RiNALMoForSequencePrediction(RiNALMoPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
@@ -333,9 +333,9 @@ class RiNALMoForTokenPrediction(RiNALMoPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
@@ -397,9 +397,9 @@ class RiNALMoForContactPrediction(RiNALMoPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py
index f044ecc7..97632d2e 100644
--- a/multimolecule/models/rnabert/configuration_rnabert.py
+++ b/multimolecule/models/rnabert/configuration_rnabert.py
@@ -112,5 +112,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py
index 74b06cf1..32f7bf01 100644
--- a/multimolecule/models/rnabert/modeling_rnabert.py
+++ b/multimolecule/models/rnabert/modeling_rnabert.py
@@ -37,7 +37,13 @@
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from transformers.utils import logging
-from multimolecule.module import ContactPredictionHead, MaskedLMHead, SequencePredictionHead, TokenPredictionHead
+from multimolecule.module import (
+ ContactPredictionHead,
+ HeadConfig,
+ MaskedLMHead,
+ SequencePredictionHead,
+ TokenPredictionHead,
+)
from ..modeling_outputs import ContactPredictorOutput, SequencePredictorOutput, TokenPredictorOutput
from .configuration_rnabert import RnaBertConfig
@@ -266,9 +272,9 @@ class RnaBertForSequencePrediction(RnaBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -330,9 +336,9 @@ class RnaBertForTokenPrediction(RnaBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -394,9 +400,9 @@ class RnaBertForContactPrediction(RnaBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -1065,7 +1071,7 @@ def __init__(self, config: RnaBertConfig):
vocab_size, config.vocab_size = config.vocab_size, config.ss_vocab_size
self.predictions_ss = MaskedLMHead(config)
config.vocab_size = vocab_size
- self.seq_relationship = SequencePredictionHead(config)
+ self.seq_relationship = SequencePredictionHead(config, HeadConfig(num_labels=2))
def forward(
self,
diff --git a/multimolecule/models/rnaernie/configuration_rnaernie.py b/multimolecule/models/rnaernie/configuration_rnaernie.py
index 2d540c9d..7a788297 100644
--- a/multimolecule/models/rnaernie/configuration_rnaernie.py
+++ b/multimolecule/models/rnaernie/configuration_rnaernie.py
@@ -108,5 +108,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnaernie/modeling_rnaernie.py b/multimolecule/models/rnaernie/modeling_rnaernie.py
index 8d79a760..8136f673 100644
--- a/multimolecule/models/rnaernie/modeling_rnaernie.py
+++ b/multimolecule/models/rnaernie/modeling_rnaernie.py
@@ -270,9 +270,9 @@ class RnaErnieForSequencePrediction(RnaErniePreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config):
@@ -334,9 +334,9 @@ class RnaErnieForTokenPrediction(RnaErniePreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaErnieConfig):
@@ -398,9 +398,9 @@ class RnaErnieForContactPrediction(RnaErniePreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaErnieConfig):
diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py
index 8fdb7f49..ef1f0c18 100644
--- a/multimolecule/models/rnafm/configuration_rnafm.py
+++ b/multimolecule/models/rnafm/configuration_rnafm.py
@@ -131,5 +131,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py
index 99f553da..6898da9c 100644
--- a/multimolecule/models/rnafm/modeling_rnafm.py
+++ b/multimolecule/models/rnafm/modeling_rnafm.py
@@ -272,9 +272,9 @@ class RnaFmForSequencePrediction(RnaFmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -336,9 +336,9 @@ class RnaFmForTokenPrediction(RnaFmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -400,9 +400,9 @@ class RnaFmForContactPrediction(RnaFmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -555,7 +555,7 @@ class RnaFmForPreTraining(RnaFmPreTrainedModel):
>>> output["logits"].shape
torch.Size([1, 7, 26])
>>> output["contact_map"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
"""
_tied_weights_keys = [
diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py
index ae914c82..2e8150ba 100644
--- a/multimolecule/models/rnamsm/configuration_rnamsm.py
+++ b/multimolecule/models/rnamsm/configuration_rnamsm.py
@@ -116,5 +116,5 @@ def __init__(
self.attention_type = attention_type
self.embed_positions_msa = embed_positions_msa
self.attention_bias = attention_bias
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py
index 5ed6bf87..0390a129 100644
--- a/multimolecule/models/rnamsm/modeling_rnamsm.py
+++ b/multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -176,9 +176,9 @@ class RnaMsmForSequencePrediction(RnaMsmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -239,9 +239,9 @@ class RnaMsmForTokenPrediction(RnaMsmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -302,9 +302,9 @@ class RnaMsmForContactPrediction(RnaMsmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -449,7 +449,7 @@ class RnaMsmForPreTraining(RnaMsmPreTrainedModel):
>>> output["logits"].shape
torch.Size([1, 7, 26])
>>> output["contact_map"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
"""
_tied_weights_keys = [
diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py
index 66b46a88..f789516d 100644
--- a/multimolecule/models/splicebert/configuration_splicebert.py
+++ b/multimolecule/models/splicebert/configuration_splicebert.py
@@ -108,5 +108,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py
index 1b0fd072..9d129d74 100644
--- a/multimolecule/models/splicebert/modeling_splicebert.py
+++ b/multimolecule/models/splicebert/modeling_splicebert.py
@@ -274,9 +274,9 @@ class SpliceBertForSequencePrediction(SpliceBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
@@ -338,9 +338,9 @@ class SpliceBertForTokenPrediction(SpliceBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
@@ -402,9 +402,9 @@ class SpliceBertForContactPrediction(SpliceBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py
index d032c5ee..5230c04f 100644
--- a/multimolecule/models/utrbert/configuration_utrbert.py
+++ b/multimolecule/models/utrbert/configuration_utrbert.py
@@ -125,5 +125,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py
index 1a5b47f9..688bedbe 100644
--- a/multimolecule/models/utrbert/modeling_utrbert.py
+++ b/multimolecule/models/utrbert/modeling_utrbert.py
@@ -264,9 +264,9 @@ class UtrBertForSequencePrediction(UtrBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
@@ -328,9 +328,9 @@ class UtrBertForTokenPrediction(UtrBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
@@ -393,9 +393,9 @@ class UtrBertForContactPrediction(UtrBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py
index f0f705de..a4f930d7 100644
--- a/multimolecule/models/utrlm/configuration_utrlm.py
+++ b/multimolecule/models/utrlm/configuration_utrlm.py
@@ -127,7 +127,7 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
self.ss_head = HeadConfig(**ss_head) if ss_head is not None else None
self.mfe_head = HeadConfig(**mfe_head) if mfe_head is not None else None
diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py
index 535f99f0..aae1b593 100644
--- a/multimolecule/models/utrlm/modeling_utrlm.py
+++ b/multimolecule/models/utrlm/modeling_utrlm.py
@@ -272,9 +272,9 @@ class UtrLmForSequencePrediction(UtrLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -336,9 +336,9 @@ class UtrLmForTokenPrediction(UtrLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -400,9 +400,9 @@ class UtrLmForContactPrediction(UtrLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -555,7 +555,7 @@ class UtrLmForPreTraining(UtrLmPreTrainedModel):
>>> output["logits"].shape
torch.Size([1, 7, 26])
>>> output["contact_map"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
"""
_tied_weights_keys = [
diff --git a/multimolecule/module/__init__.py b/multimolecule/module/__init__.py
index 0128fe9b..dbba900b 100644
--- a/multimolecule/module/__init__.py
+++ b/multimolecule/module/__init__.py
@@ -14,8 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from .criterions import Criterion
-from .embeddings import PositionEmbeddingRegistry, PositionEmbeddingRegistryHF, RotaryEmbedding, SinusoidalEmbedding
+from .criterions import Criterion, CriterionRegistry
+from .embeddings import PositionEmbeddingRegistry, RotaryEmbedding, SinusoidalEmbedding
from .heads import (
BaseHeadConfig,
ContactPredictionHead,
@@ -23,7 +23,6 @@
HeadOutput,
HeadRegistry,
HeadTransformRegistry,
- HeadTransformRegistryHF,
IdentityTransform,
LinearTransform,
MaskedLMHead,
@@ -31,15 +30,18 @@
NonLinearTransform,
PredictionHead,
SequencePredictionHead,
- TokenHeadRegistryHF,
TokenKMerHead,
TokenPredictionHead,
)
+from .model import MultiMoleculeModel
+from .registry import ModelRegistry
__all__ = [
+ "ModelRegistry",
+ "MultiMoleculeModel",
+ "CriterionRegistry",
"Criterion",
"PositionEmbeddingRegistry",
- "PositionEmbeddingRegistryHF",
"RotaryEmbedding",
"SinusoidalEmbedding",
"BaseHeadConfig",
@@ -48,14 +50,12 @@
"HeadRegistry",
"PredictionHead",
"SequencePredictionHead",
- "TokenHeadRegistryHF",
"TokenPredictionHead",
"TokenKMerHead",
"ContactPredictionHead",
"MaskedLMHead",
"HeadOutput",
"HeadTransformRegistry",
- "HeadTransformRegistryHF",
"LinearTransform",
"NonLinearTransform",
"IdentityTransform",
diff --git a/multimolecule/module/backbones/__init__.py b/multimolecule/module/backbones/__init__.py
new file mode 100644
index 00000000..d69e6292
--- /dev/null
+++ b/multimolecule/module/backbones/__init__.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .registry import BackboneRegistry
+from .sequence import SequenceBackbone
+from .sequences import SequenceRegistry
+
+__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"]
diff --git a/multimolecule/module/backbones/registry.py b/multimolecule/module/backbones/registry.py
new file mode 100644
index 00000000..47be122d
--- /dev/null
+++ b/multimolecule/module/backbones/registry.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import Registry
+
+BackboneRegistry = Registry()
diff --git a/multimolecule/module/backbones/sequence.py b/multimolecule/module/backbones/sequence.py
new file mode 100644
index 00000000..a30cbf83
--- /dev/null
+++ b/multimolecule/module/backbones/sequence.py
@@ -0,0 +1,46 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .registry import BackboneRegistry
+from .sequences import SequenceRegistry
+
+
+@BackboneRegistry.register("sequence", default=True)
+class SequenceBackbone(nn.Module):
+ def __init__(self, sequence) -> None:
+ super().__init__()
+ self.sequence = SequenceRegistry.build(**sequence)
+ self.sequence_dropout = nn.Dropout(sequence.pop("dropout", 0), inplace=True)
+ self.config = self.sequence.config
+ self.out_channels = self.config.hidden_size
+
+ def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, FlatDict]:
+ attentions = None
+ input_ids, attention_mask = sequence.tensor, sequence.mask
+ sequence_output = self.sequence(input_ids.int(), attention_mask)
+ sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"])
+ sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"])
+ if "attentions" in sequence_output:
+ attentions = torch.stack(sequence_output["attentions"], dim=1).detach()
+
+ return sequence_output, attentions
diff --git a/multimolecule/module/backbones/sequences/__init__.py b/multimolecule/module/backbones/sequences/__init__.py
new file mode 100644
index 00000000..e6e5cd08
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .onehot import OneHot
+from .registry import SequenceRegistry
+
+__all__ = ["SequenceRegistry", "OneHot"]
diff --git a/multimolecule/module/backbones/sequences/onehot.py b/multimolecule/module/backbones/sequences/onehot.py
new file mode 100644
index 00000000..bc4c979f
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/onehot.py
@@ -0,0 +1,39 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import torch
+from chanfig import FlatDict
+from torch import nn
+from transformers import AutoConfig
+
+from .registry import SequenceRegistry
+
+
+@SequenceRegistry.register("onehot")
+class OneHot(nn.Module):
+ def __init__(self, pretrained: str) -> None:
+ super().__init__()
+ self.config = AutoConfig.from_pretrained(str(pretrained))
+ self.module = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+
+ def forward(self, input_ids, attn_mask) -> FlatDict:
+ output = FlatDict()
+ output["last_hidden_state"] = self.module(input_ids)
+ valid_length = attn_mask.sum(dim=1)
+ output["pooler_output"] = torch.stack(
+ [t[: valid_length[i]].sum(0) for i, t in enumerate(output["last_hidden_state"])]
+ )
+ return output
diff --git a/multimolecule/module/backbones/sequences/registry.py b/multimolecule/module/backbones/sequences/registry.py
new file mode 100644
index 00000000..c9178231
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/registry.py
@@ -0,0 +1,66 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import danling as dl
+import transformers
+from chanfig import Registry as Registry_
+from torch import nn
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ def build(
+ self,
+ type: str | None = None,
+ name: str | None = None,
+ use_pretrained: bool = True,
+ gradient_checkpoint: bool = False,
+ checkpoint: str | None = None,
+ *args,
+ **kwargs,
+ ) -> nn.Module:
+ if type is not None:
+ if type in self:
+ sequence_cls = self.lookup(type)
+ sequence = self.init(sequence_cls, *args, **kwargs)
+ if checkpoint is not None:
+ sequence.load_state_dict(dl.load(checkpoint))
+ elif hasattr(transformers, type + "Model"):
+ if use_pretrained:
+ sequence_cls: PreTrainedModel = getattr(transformers, type + "Model") # type: ignore[no-redef]
+ sequence = sequence_cls.from_pretrained(name, *args, **kwargs)
+ else:
+ config_cls: PretrainedConfig = getattr(transformers, type + "Config")
+ config, kwargs = config_cls.from_pretrained(name, return_unused_kwargs=True, **kwargs)
+ sequence_cls: PreTrainedModel = getattr(transformers, type + "Model") # type: ignore[no-redef]
+ sequence = sequence_cls.from_config(config, *args, **kwargs)
+ else:
+ raise ValueError(f"Sequence {type} not found in registry or transformers")
+ else:
+ if use_pretrained:
+ sequence = AutoModel.from_pretrained(name, *args, **kwargs)
+ else:
+ config, kwargs = AutoConfig.from_pretrained(name, return_unused_kwargs=True, **kwargs)
+ sequence = AutoModel.from_config(config, *args, **kwargs)
+
+ if gradient_checkpoint:
+ sequence.gradient_checkpointing_enable()
+ return sequence
+
+
+SequenceRegistry = Registry()
diff --git a/multimolecule/module/criterions/__init__.py b/multimolecule/module/criterions/__init__.py
index 104334b5..4b9adf7e 100644
--- a/multimolecule/module/criterions/__init__.py
+++ b/multimolecule/module/criterions/__init__.py
@@ -14,6 +14,18 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from .binary import BCEWithLogitsLoss
from .generic import Criterion
+from .multiclass import CrossEntropyLoss
+from .multilabel import MultiLabelSoftMarginLoss
+from .registry import CriterionRegistry
+from .regression import MSELoss
-__all__ = ["Criterion"]
+__all__ = [
+ "CriterionRegistry",
+ "Criterion",
+ "MSELoss",
+ "BCEWithLogitsLoss",
+ "CrossEntropyLoss",
+ "MultiLabelSoftMarginLoss",
+]
diff --git a/multimolecule/module/criterions/binary.py b/multimolecule/module/criterions/binary.py
new file mode 100644
index 00000000..0bf53e59
--- /dev/null
+++ b/multimolecule/module/criterions/binary.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .registry import CriterionRegistry
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+
+@CriterionRegistry.register("binary")
+class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.flatten().storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.flatten().storage())
+ if input.ndim == target.ndim + 1:
+ input = input.squeeze(-1)
+ return super().forward(input, target.float())
diff --git a/multimolecule/module/criterions/generic.py b/multimolecule/module/criterions/generic.py
index b003c81d..a6731933 100644
--- a/multimolecule/module/criterions/generic.py
+++ b/multimolecule/module/criterions/generic.py
@@ -17,8 +17,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING
+from warnings import warn
-import torch
from danling import NestedTensor
from torch import Tensor, nn
from torch.nn import functional as F
@@ -26,10 +26,13 @@
if TYPE_CHECKING:
from ..heads.config import HeadConfig
+from .registry import CriterionRegistry
+
+@CriterionRegistry.register(default=True)
class Criterion(nn.Module):
- problem_types = ["regression", "single_label_classification", "multi_label_classification"]
+ problem_types = ["regression", "binary", "multiclass", "multilabel"]
def __init__(self, config: HeadConfig) -> None:
super().__init__()
@@ -41,21 +44,31 @@ def forward(self, logits: Tensor | NestedTensor, labels: Tensor | NestedTensor)
if labels is None:
return None
if self.problem_type is None:
- if self.num_labels == 1:
+ if labels.is_floating_point():
self.problem_type = "regression"
- elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
- self.problem_type = "single_label_classification"
+ elif self.num_labels == 1:
+ self.problem_type = "binary"
+ elif labels.unique().numel() == 2:
+ self.problem_type = "multilabel"
else:
- self.problem_type = "multi_label_classification"
+ self.problem_type = "multiclass"
+ warn(
+ f"`problem_type` is not set. Assuming {self.problem_type}. \n"
+ "This can lead to unexpected behavior. Please set `problem_type` explicitly."
+ )
self.config.problem_type = self.problem_type
if self.problem_type == "regression":
labels = labels.to(logits.dtype)
if self.num_labels == 1:
return F.mse_loss(logits.squeeze(), labels.squeeze())
logits, labels = logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
- return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels))
- if self.problem_type == "single_label_classification":
+ return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels)) # type: ignore
+ if self.problem_type == "multiclass":
return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
- if self.problem_type == "multi_label_classification":
- return F.binary_cross_entropy_with_logits(logits, labels)
+ if self.problem_type == "binary":
+ if logits.ndim == labels.ndim + 1:
+ logits = logits.squeeze(-1)
+ return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))
+ if self.problem_type == "multilabel":
+ return F.multilabel_soft_margin_loss(logits, labels.to(logits.dtype))
raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}")
diff --git a/multimolecule/module/criterions/multiclass.py b/multimolecule/module/criterions/multiclass.py
new file mode 100644
index 00000000..f7070e94
--- /dev/null
+++ b/multimolecule/module/criterions/multiclass.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("multiclass")
+class CrossEntropyLoss(nn.CrossEntropyLoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.storage())
+ if input.ndim > 2:
+ input, target = input.view(-1, input.size(-1)), target.view(-1)
+ return super().forward(input, target.long())
diff --git a/multimolecule/module/criterions/multilabel.py b/multimolecule/module/criterions/multilabel.py
new file mode 100644
index 00000000..c72bb9f9
--- /dev/null
+++ b/multimolecule/module/criterions/multilabel.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("multilabel")
+class MultiLabelSoftMarginLoss(nn.MultiLabelSoftMarginLoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(target, NestedTensor) and target.ndim > 2:
+ input, target = input.view(-1, input.size(-1)), target.view(-1, target.size(-1))
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.storage())
+ return super().forward(input, target.float())
diff --git a/multimolecule/module/criterions/registry.py b/multimolecule/module/criterions/registry.py
new file mode 100644
index 00000000..856341f7
--- /dev/null
+++ b/multimolecule/module/criterions/registry.py
@@ -0,0 +1,29 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from chanfig import ConfigRegistry as Registry_
+from torch import nn
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ key = "problem_type"
+
+ def build(self, config) -> nn.Module: # type: ignore[override]
+ name = getattr(config, self.getattr("key"))
+ return self.init(self.lookup(name), config) # type: ignore[arg-type]
+
+
+CriterionRegistry = Registry(fallback=True)
diff --git a/multimolecule/module/criterions/regression.py b/multimolecule/module/criterions/regression.py
new file mode 100644
index 00000000..4f39e0eb
--- /dev/null
+++ b/multimolecule/module/criterions/regression.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("regression")
+class MSELoss(nn.MSELoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.flatten().storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.flatten().storage())
+ if input.ndim == target.ndim + 1:
+ target = target.unsqueeze(-1)
+ return super().forward(input, target.to(input.dtype))
diff --git a/multimolecule/module/heads/__init__.py b/multimolecule/module/heads/__init__.py
index 8cc91f29..0e857c5e 100644
--- a/multimolecule/module/heads/__init__.py
+++ b/multimolecule/module/heads/__init__.py
@@ -21,14 +21,8 @@
from .pretrain import MaskedLMHead
from .registry import HeadRegistry
from .sequence import SequencePredictionHead
-from .token import TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead
-from .transform import (
- HeadTransformRegistry,
- HeadTransformRegistryHF,
- IdentityTransform,
- LinearTransform,
- NonLinearTransform,
-)
+from .token import TokenKMerHead, TokenPredictionHead
+from .transform import HeadTransformRegistry, IdentityTransform, LinearTransform, NonLinearTransform
__all__ = [
"BaseHeadConfig",
@@ -37,14 +31,12 @@
"HeadRegistry",
"PredictionHead",
"SequencePredictionHead",
- "TokenHeadRegistryHF",
"TokenPredictionHead",
"TokenKMerHead",
"ContactPredictionHead",
"MaskedLMHead",
"HeadOutput",
"HeadTransformRegistry",
- "HeadTransformRegistryHF",
"LinearTransform",
"NonLinearTransform",
"IdentityTransform",
diff --git a/multimolecule/module/heads/config.py b/multimolecule/module/heads/config.py
index bb0dbba6..3b9ee64b 100644
--- a/multimolecule/module/heads/config.py
+++ b/multimolecule/module/heads/config.py
@@ -16,15 +16,13 @@
from __future__ import annotations
-from collections import OrderedDict
-from dataclasses import dataclass
+from chanfig import FlatDict
-class BaseHeadConfig(OrderedDict):
+class BaseHeadConfig(FlatDict):
pass
-@dataclass
class HeadConfig(BaseHeadConfig):
r"""
Configuration class for a prediction head.
@@ -35,8 +33,8 @@ class HeadConfig(BaseHeadConfig):
Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`.
problem_type:
- Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`,
- `"single_label_classification"` or `"multi_label_classification"`.
+ Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`,
+ `"multiclass"` or `"multilabel"`.
Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`.
hidden_size:
@@ -55,14 +53,18 @@ class HeadConfig(BaseHeadConfig):
The activation function of the final prediction output.
layer_norm_eps:
The epsilon used by the layer normalization layers.
- output_name (`str`, *optional*):
+ output_name:
The name of the tensor required in model outputs.
If is `None`, will use the default output name of the corresponding head.
+ type:
+ The type of the head in the model.
+
+ This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads.
"""
- num_labels: int = None # type: ignore[assignment]
- problem_type: str = None # type: ignore[assignment]
+ num_labels: int | None = None
+ problem_type: str | None = None
hidden_size: int | None = None
dropout: float = 0.0
transform: str | None = None
@@ -71,9 +73,9 @@ class HeadConfig(BaseHeadConfig):
act: str | None = None
layer_norm_eps: float = 1e-12
output_name: str | None = None
+ type: str | None = None
-@dataclass
class MaskedLMHeadConfig(BaseHeadConfig):
r"""
Configuration class for a Masked Language Modeling head.
@@ -95,7 +97,7 @@ class MaskedLMHeadConfig(BaseHeadConfig):
The activation function of the final prediction output.
layer_norm_eps:
The epsilon used by the layer normalization layers.
- output_name (`str`, *optional*):
+ output_name:
The name of the tensor required in model outputs.
If is `None`, will use the default output name of the corresponding head.
diff --git a/multimolecule/module/heads/contact.py b/multimolecule/module/heads/contact.py
index 50ec4fbb..866e2250 100644
--- a/multimolecule/module/heads/contact.py
+++ b/multimolecule/module/heads/contact.py
@@ -16,10 +16,12 @@
from __future__ import annotations
-from typing import Mapping, Tuple
+from typing import Callable, Mapping, Tuple, Type
import torch
from danling import NestedTensor
+from danling.modules import MLP
+from lazy_imports import try_import
from torch import Tensor, nn
from transformers.modeling_outputs import ModelOutput
from typing_extensions import TYPE_CHECKING
@@ -28,13 +30,86 @@
from .generic import PredictionHead
from .output import HeadOutput
from .registry import HeadRegistry
-from .utils import average_product_correct, symmetrize
+
+with try_import() as tv:
+ from torchvision.models.resnet import BasicBlock, Bottleneck
if TYPE_CHECKING:
from multimolecule.models import PreTrainedConfig
-@HeadRegistry.register("contact")
+@HeadRegistry.contact.register("default", default=True)
+class ContactHead(PredictionHead):
+
+ output_name: str = "last_hidden_state"
+
+ def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
+ super().__init__(config, head_config)
+ self._bos_token_id = config.bos_token_id
+ self._eos_token_id = config.eos_token_id
+ self.pad_token_id = config.pad_token_id
+ out_channels: int = self.config.hidden_size # type: ignore[assignment]
+ self.qk_proj = nn.Linear(out_channels, 2 * out_channels)
+ self.ffn = MLP(1, out_channels, residual=False)
+
+ def forward( # type: ignore[override] # pylint: disable=arguments-renamed
+ self,
+ outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
+ attention_mask: Tensor | None = None,
+ input_ids: NestedTensor | Tensor | None = None,
+ labels: Tensor | None = None,
+ output_name: str | None = None,
+ **kwargs,
+ ) -> HeadOutput:
+ if attention_mask is None:
+ if isinstance(input_ids, NestedTensor):
+ input_ids, attention_mask = input_ids.tensor, input_ids.mask
+ else:
+ if input_ids is None:
+ raise ValueError(
+ f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
+ )
+ if self.pad_token_id is None:
+ raise ValueError(
+ f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
+ )
+ attention_mask = input_ids.ne(self.pad_token_id)
+
+ if isinstance(outputs, (Mapping, ModelOutput)):
+ output = outputs[output_name or self.output_name]
+ elif isinstance(outputs, tuple):
+ output = outputs[0]
+ else:
+ raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
+ output = output * attention_mask.unsqueeze(-1)
+
+ # remove cls token embeddings
+ if self.bos_token_id is not None:
+ output = output[..., 1:, :]
+ # process attention_mask and input_ids to make removal of eos token happy
+ attention_mask = attention_mask[..., 1:]
+ if input_ids is not None:
+ input_ids = input_ids[..., 1:]
+ # remove eos token embeddings
+ if self.eos_token_id is not None:
+ if input_ids is not None:
+ eos_mask = input_ids.ne(self.eos_token_id).to(output)
+ else:
+ last_valid_indices = attention_mask.sum(dim=-1)
+ seq_length = attention_mask.size(-1)
+ eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
+ output = output * eos_mask[:, :, None]
+ output = output[..., :-1, :]
+
+ q, k = self.qk_proj(output).chunk(2, dim=-1)
+ contact_map = (q @ k.transpose(-2, -1)).unsqueeze(-1)
+ contact_map = contact_map + self.ffn(contact_map)
+ if "continuous" in outputs:
+ contact_map = contact_map * (1 + outputs["continuous"].unsqueeze(dim=-1)) # type: ignore[call-overload]
+ return super().forward(contact_map, labels)
+
+
+@HeadRegistry.contact.register("attention")
class ContactPredictionHead(PredictionHead):
r"""
Head for tasks in contact-level.
@@ -50,13 +125,23 @@ class ContactPredictionHead(PredictionHead):
output_name: str = "attentions"
r"""The default output to use for the head."""
+ requires_attention: bool = True
+
def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
super().__init__(config, head_config)
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
self.pad_token_id = config.pad_token_id
- self.decoder = nn.Linear(
- config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias
+ self.config.hidden_size = config.num_hidden_layers * config.num_attention_heads
+ num_layers = self.config.get("num_layers", 16)
+ num_channels = self.config.get("num_channels", self.config.hidden_size // 10) # type: ignore[operator]
+ block = self.config.get("block", "auto")
+ self.decoder = ResNet(
+ num_layers=num_layers,
+ hidden_size=self.config.hidden_size, # type: ignore[arg-type]
+ block=block,
+ num_channels=num_channels,
+ num_labels=self.num_labels,
)
if head_config is not None and head_config.output_name is not None:
self.output_name = head_config.output_name
@@ -106,7 +191,7 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed
# but it does for the contact prediction task, which takes attentions as input,
# so we have to mimic that here.
attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
- attentions *= attention_mask[:, None, None, :, :]
+ attentions = attentions * attention_mask[:, None, None, :, :]
# remove cls token attentions
if self.bos_token_id is not None:
@@ -124,14 +209,203 @@ def forward( # type: ignore[override] # pylint: disable=arguments-renamed
seq_length = attention_mask.size(-1)
eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
- attentions *= eos_mask[:, None, None, :, :]
+ attentions = attentions * eos_mask[:, None, None, :, :]
attentions = attentions[..., :-1, :-1]
# features: batch x channels x input_ids x input_ids (symmetric)
batch_size, layers, heads, seqlen, _ = attentions.size()
attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
- attentions = attentions.to(self.decoder.weight.device)
+ attentions = attentions.to(self.decoder.proj.weight.device)
attentions = average_product_correct(symmetrize(attentions))
attentions = attentions.permute(0, 2, 3, 1).squeeze(3)
return super().forward(attentions, labels, **kwargs)
+
+
+@HeadRegistry.contact.register("logits")
+class ContactLogitsHead(PredictionHead):
+ r"""
+ Head for tasks in contact-level.
+
+ Performs symmetrization, and average product correct.
+
+ Args:
+ config: The configuration object for the model.
+ head_config: The configuration object for the head.
+ If None, will use configuration from the `config`.
+ """
+
+ output_name: str = "last_hidden_state"
+ r"""The default output to use for the head."""
+
+ requires_attention: bool = False
+
+ def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
+ super().__init__(config, head_config)
+ self.bos_token_id = config.bos_token_id
+ self.eos_token_id = config.eos_token_id
+ self.pad_token_id = config.pad_token_id
+ num_layers = self.config.get("num_layers", 16)
+ num_channels = self.config.get("num_channels", self.config.hidden_size // 10) # type: ignore[operator]
+ block = self.config.get("block", "auto")
+ self.decoder = ResNet(
+ num_layers=num_layers,
+ hidden_size=self.config.hidden_size, # type: ignore[arg-type]
+ block=block,
+ num_channels=num_channels,
+ num_labels=self.num_labels,
+ )
+ if head_config is not None and head_config.output_name is not None:
+ self.output_name = head_config.output_name
+
+ def forward( # type: ignore[override] # pylint: disable=arguments-renamed
+ self,
+ outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
+ attention_mask: Tensor | None = None,
+ input_ids: NestedTensor | Tensor | None = None,
+ labels: Tensor | None = None,
+ output_name: str | None = None,
+ **kwargs,
+ ) -> HeadOutput:
+ r"""
+ Forward pass of the ContactPredictionHead.
+
+ Args:
+ outputs: The outputs of the model.
+ attention_mask: The attention mask for the inputs.
+ input_ids: The input ids for the inputs.
+ labels: The labels for the head.
+ output_name: The name of the output to use.
+ Defaults to `self.output_name`.
+ """
+ if attention_mask is None:
+ if isinstance(input_ids, NestedTensor):
+ input_ids, attention_mask = input_ids.tensor, input_ids.mask
+ else:
+ if input_ids is None:
+ raise ValueError(
+ f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
+ )
+ if self.pad_token_id is None:
+ raise ValueError(
+ f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
+ )
+ attention_mask = input_ids.ne(self.pad_token_id)
+
+ if isinstance(outputs, (Mapping, ModelOutput)):
+ output = outputs[output_name or self.output_name]
+ elif isinstance(outputs, tuple):
+ output = outputs[0]
+ else:
+ raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
+ output = output * attention_mask.unsqueeze(-1)
+
+ # remove cls token embeddings
+ if self.bos_token_id is not None:
+ output = output[..., 1:, :]
+ # process attention_mask and input_ids to make removal of eos token happy
+ attention_mask = attention_mask[..., 1:]
+ if input_ids is not None:
+ input_ids = input_ids[..., 1:]
+ # remove eos token embeddings
+ if self.eos_token_id is not None:
+ if input_ids is not None:
+ eos_mask = input_ids.ne(self.eos_token_id).to(output)
+ else:
+ last_valid_indices = attention_mask.sum(dim=-1)
+ seq_length = attention_mask.size(-1)
+ eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
+ output = output * eos_mask[:, :, None]
+ output = output[..., :-1, :]
+
+ # make symmetric contact map
+ contact_map = output.unsqueeze(1) * output.unsqueeze(2)
+
+ return super().forward(contact_map, labels, **kwargs)
+
+
+class ResNet(nn.Module):
+ def __init__(
+ self,
+ num_layers: int,
+ hidden_size: int,
+ block: Type[BasicBlock | Bottleneck] | str = "auto",
+ num_channels: int | None = None,
+ num_labels: int = 1,
+ norm_layer: Callable[..., nn.Module] | None = None,
+ zero_init_residual: bool = True,
+ ) -> None:
+ tv.check()
+ super().__init__()
+
+ if block == "auto":
+ block = BasicBlock if num_layers < 50 else Bottleneck
+ elif block in ("basic", "BasicBlock"):
+ block = BasicBlock
+ elif block in ("bottleneck", "Bottleneck"):
+ block = Bottleneck
+ else:
+ raise ValueError(f"Unknown block type: {block}")
+ if num_channels is None:
+ num_channels = hidden_size // 10
+ if norm_layer is None:
+ norm_layer = LayerNorm2D
+
+ self.proj = nn.Conv2d(hidden_size, num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+ self.norm = norm_layer(num_channels)
+ self.relu = nn.ReLU(inplace=True)
+ self.layers = nn.Sequential(
+ *[block(num_channels, num_channels, norm_layer=norm_layer) for _ in range(num_layers)] # type: ignore
+ )
+ self.output = nn.Linear(num_channels, num_labels)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.proj(x.transpose(1, 3))
+ x = self.norm(x)
+ x = self.relu(x)
+ x = self.layers(x)
+ x = self.output(x.transpose(1, 3))
+ return x
+
+
+class LayerNorm2D(nn.GroupNorm):
+ def __init__(self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
+ super().__init__(num_channels=num_features, eps=eps, affine=elementwise_affine, num_groups=1)
+ self.num_channels = num_features
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(num_channels={self.num_channels}, eps={self.eps}, affine={self.affine})"
+
+
+def symmetrize(x):
+ "Make layer symmetric in final two dimensions, used for contact prediction."
+ return x + x.transpose(-1, -2)
+
+
+def average_product_correct(x):
+ "Perform average product correct, used for contact prediction."
+ a1 = x.sum(-1, keepdims=True)
+ a2 = x.sum(-2, keepdims=True)
+ a12 = x.sum((-1, -2), keepdims=True)
+
+ avg = a1 * a2
+ avg.div_(a12) # in-place to reduce memory
+ normalized = x - avg
+ return normalized
diff --git a/multimolecule/module/heads/generic.py b/multimolecule/module/heads/generic.py
index d97950a2..ef0fa520 100644
--- a/multimolecule/module/heads/generic.py
+++ b/multimolecule/module/heads/generic.py
@@ -19,12 +19,11 @@
from typing import TYPE_CHECKING
from warnings import warn
-import torch
from danling import NestedTensor
from torch import Tensor, nn
from transformers.activations import ACT2FN
-from ..criterions import Criterion
+from ..criterions import CriterionRegistry
from .config import HeadConfig
from .output import HeadOutput
from .transform import HeadTransformRegistryHF
@@ -44,24 +43,25 @@ class PredictionHead(nn.Module):
"""
num_labels: int
+ requires_attention: bool = False
def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
super().__init__()
if head_config is None:
- head_config = config.head
+ head_config = config.head or HeadConfig(num_labels=config.num_labels)
+ elif head_config.num_labels is None:
+ head_config.num_labels = config.num_labels
self.config = head_config
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
- if self.config.num_labels is None:
- self.config.num_labels = config.num_labels
if self.config.problem_type is None:
self.config.problem_type = config.problem_type
- self.num_labels = self.config.num_labels
+ self.num_labels = self.config.num_labels # type: ignore[assignment]
self.dropout = nn.Dropout(self.config.dropout)
self.transform = HeadTransformRegistryHF.build(self.config)
- self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias)
+ self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
- self.criterion = Criterion(self.config)
+ self.criterion = CriterionRegistry.build(self.config)
def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
r"""
@@ -85,6 +85,6 @@ def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOu
if isinstance(labels, NestedTensor):
if isinstance(output, Tensor):
output = labels.nested_like(output, strict=False)
- return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage())))
+ return HeadOutput(output, self.criterion(output.concat, labels.concat))
return HeadOutput(output, self.criterion(output, labels))
return HeadOutput(output)
diff --git a/multimolecule/module/heads/pretrain.py b/multimolecule/module/heads/pretrain.py
index 994cb8ca..c6968c4b 100644
--- a/multimolecule/module/heads/pretrain.py
+++ b/multimolecule/module/heads/pretrain.py
@@ -53,8 +53,8 @@ def __init__(
):
super().__init__()
if head_config is None:
- head_config = config.lm_head if hasattr(config, "lm_head") else config.head # type: ignore[assignment]
- self.config: MaskedLMHeadConfig = head_config # type: ignore[assignment]
+ head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig()
+ self.config: MaskedLMHeadConfig = head_config
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
self.num_labels = config.vocab_size
@@ -97,6 +97,6 @@ def forward(
if isinstance(labels, NestedTensor):
if isinstance(output, Tensor):
output = labels.nested_like(output, strict=False)
- return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage())))
+ return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
return HeadOutput(output)
diff --git a/multimolecule/module/heads/registry.py b/multimolecule/module/heads/registry.py
index e5393e4e..6db3b680 100644
--- a/multimolecule/module/heads/registry.py
+++ b/multimolecule/module/heads/registry.py
@@ -14,6 +14,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from chanfig import Registry
+from chanfig import ConfigRegistry as Registry_
+from torch import nn
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ key = "type"
+
+ def build(self, config, head_config) -> nn.Module: # type: ignore[override]
+ name = getattr(head_config, self.getattr("key"))
+ return self.init(self.lookup(name), config, head_config) # type: ignore[arg-type]
+
HeadRegistry = Registry(default_factory=Registry, fallback=True)
diff --git a/multimolecule/module/heads/utils.py b/multimolecule/module/heads/utils.py
index c5937c6d..cc1f3654 100644
--- a/multimolecule/module/heads/utils.py
+++ b/multimolecule/module/heads/utils.py
@@ -119,32 +119,3 @@ def unfold_kmer_embeddings(
embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]])
output[index, : seq_len + nmers - 1] = embedding
return output
-
-
-def rotate_half(x):
- x1, x2 = x.chunk(2, dim=-1)
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(x, cos, sin):
- cos = cos[:, :, : x.shape[-2], :]
- sin = sin[:, :, : x.shape[-2], :]
-
- return (x * cos) + (rotate_half(x) * sin)
-
-
-def symmetrize(x):
- "Make layer symmetric in final two dimensions, used for contact prediction."
- return x + x.transpose(-1, -2)
-
-
-def average_product_correct(x):
- "Perform average product correct, used for contact prediction."
- a1 = x.sum(-1, keepdims=True)
- a2 = x.sum(-2, keepdims=True)
- a12 = x.sum((-1, -2), keepdims=True)
-
- avg = a1 * a2
- avg.div_(a12) # in-place to reduce memory
- normalized = x - avg
- return normalized
diff --git a/multimolecule/module/model.py b/multimolecule/module/model.py
new file mode 100644
index 00000000..256783be
--- /dev/null
+++ b/multimolecule/module/model.py
@@ -0,0 +1,89 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import FlatDict
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .backbones import BackboneRegistry
+from .heads import HeadRegistry
+from .necks import NeckRegistry
+from .registry import ModelRegistry
+
+
+@ModelRegistry.register(default=True)
+class MultiMoleculeModel(nn.Module):
+ def __init__(
+ self,
+ backbone: dict,
+ heads: dict,
+ neck: dict | None = None,
+ max_length: int = 1024,
+ truncation: bool = False,
+ ):
+ super().__init__()
+
+ # Backbone
+ self.backbone = BackboneRegistry.build(**backbone)
+ backbone = self.backbone.config
+ out_channels = self.backbone.out_channels
+
+ # Neck
+ if neck:
+ num_discrete = self.backbone.num_discrete
+ num_continuous = self.backbone.num_continuous
+ embed_dim = self.backbone.sequence.config.hidden_size
+ attention_heads = self.backbone.sequence.config.num_attention_heads
+ neck.update(
+ {
+ "num_discrete": num_discrete,
+ "num_continuous": num_continuous,
+ "embed_dim": embed_dim,
+ "attention_heads": attention_heads,
+ "max_length": max_length,
+ "truncation": truncation,
+ }
+ )
+ self.neck = NeckRegistry.build(**neck)
+ out_channels = self.neck.out_channels
+ else:
+ self.neck = None
+
+ # Heads
+ for head in heads.values():
+ if "hidden_size" not in head or head["hidden_size"] is None:
+ head["hidden_size"] = out_channels
+ self.heads = nn.ModuleDict({name: HeadRegistry.build(backbone, head) for name, head in heads.items()})
+ if any(getattr(h, "requires_attention", False) for h in self.heads.values()):
+ self.backbone.sequence.config.output_attentions = True
+
+ def forward(
+ self,
+ sequence: NestedTensor | Tensor,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ dataset: str | None = None,
+ **labels: NestedTensor | Tensor,
+ ) -> FlatDict:
+ ret = FlatDict()
+ output, _ = self.backbone(sequence, discrete, continuous)
+ if self.neck is not None:
+ output = self.neck(**output)
+ for task, label in labels.items():
+ ret[task] = self.heads[task](output, input_ids=sequence, labels=label)
+ return ret
diff --git a/multimolecule/module/necks/__init__.py b/multimolecule/module/necks/__init__.py
new file mode 100644
index 00000000..e8f1f7e2
--- /dev/null
+++ b/multimolecule/module/necks/__init__.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .bert import BERTNeck
+from .cat import CatNeck
+from .registry import NeckRegistry
+
+__all__ = ["NeckRegistry", "CatNeck", "BERTNeck"]
diff --git a/multimolecule/module/necks/bert.py b/multimolecule/module/necks/bert.py
new file mode 100644
index 00000000..1360f0dd
--- /dev/null
+++ b/multimolecule/module/necks/bert.py
@@ -0,0 +1,102 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from danling.modules import TransformerEncoder, TransformerEncoderLayer
+from torch import Tensor, nn
+
+from .registry import NeckRegistry
+
+MAX_LENGTH = 1024
+
+
+@NeckRegistry.register("bert")
+class BERTNeck(nn.Module):
+ def __init__( # pylint: disable=keyword-arg-before-vararg
+ self,
+ num_discrete: int,
+ num_continuous: int,
+ embed_dim: int,
+ attention_heads: int,
+ num_layers: int = 6,
+ max_length: int | None = None,
+ truncation: bool = False,
+ dropout: float = 0.1,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.cls_token_dis = nn.Parameter(torch.zeros(embed_dim))
+ self.cls_token_con = nn.Parameter(torch.zeros(embed_dim))
+ if max_length is None:
+ if truncation:
+ max_length = MAX_LENGTH + 1 + num_discrete + 1 + num_continuous
+ else:
+ max_length = MAX_LENGTH * 4 + 1 + num_discrete + 1 + num_continuous
+ self.max_length = max_length
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.max_length, embed_dim))
+ bert_layer = TransformerEncoderLayer(
+ embed_dim, attention_heads, *args, dropout=dropout, attn_dropout=dropout, ffn_dropout=dropout, **kwargs
+ )
+ self.bert = TransformerEncoder(bert_layer, num_layers)
+ self.out_channels = embed_dim
+ nn.init.normal_(self.pos_embed, std=0.02)
+ nn.init.trunc_normal_(self.cls_token_dis, std=0.2)
+ nn.init.trunc_normal_(self.cls_token_con, std=0.2)
+
+ def forward(
+ self,
+ cls_token: Tensor | None = None,
+ all_tokens: Tensor | None = None,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ ) -> FlatDict:
+ ret = FlatDict()
+ if cls_token is not None:
+ ret["cls_token"] = self._forward(cls_token, discrete, continuous)
+ if all_tokens is not None:
+ ret["all_tokens"] = self._forward(all_tokens, discrete, continuous)
+ return ret
+
+ def _forward(
+ self,
+ sequence: Tensor,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ ) -> Tensor:
+ if sequence is None:
+ raise ValueError("sequence should not be None.")
+ if sequence.dim() == 2:
+ sequence = sequence[:, None]
+ batch_size, seq_len, _ = sequence.shape
+ output = sequence
+ if discrete is not None:
+ cls_token_dis = self.cls_token_dis.expand(batch_size, 1, -1)
+ output = torch.cat((output, cls_token_dis, discrete), dim=1)
+ if continuous is not None:
+ cls_token_con = self.cls_token_con.expand(batch_size, -1)[:, None]
+ output = torch.cat((output, cls_token_con, continuous), dim=1)
+ all_len = output.shape[1]
+ if all_len > self.pos_embed.shape[1]:
+ raise ValueError("sequence length is out of range.")
+ output = output + self.pos_embed[:, 0:all_len, :]
+ output = self.bert(output)[0][:, 0:seq_len, :]
+ if seq_len == 1:
+ output = output.squeeze(1)
+ return output
diff --git a/multimolecule/module/necks/cat.py b/multimolecule/module/necks/cat.py
new file mode 100644
index 00000000..d5165a92
--- /dev/null
+++ b/multimolecule/module/necks/cat.py
@@ -0,0 +1,43 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from torch import Tensor
+
+from .registry import NeckRegistry
+
+
+@NeckRegistry.register("cat")
+class CatNeck: # pylint: disable=too-few-public-methods
+ def __init__(self, embed_dim: int):
+ self.out_channels = embed_dim * 2
+
+ def __call__(
+ self,
+ cls_token: Tensor | None = None,
+ all_tokens: Tensor | None = None,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ ) -> FlatDict:
+ ret = FlatDict()
+ if cls_token is not None:
+ ret.cls_token = torch.cat(tuple(i for i in (cls_token, discrete, continuous) if i is not None), -1)
+ if all_tokens is not None:
+ ret.all_tokens = torch.cat(tuple(i for i in (all_tokens, discrete, continuous) if i is not None), -1)
+ return ret
diff --git a/multimolecule/module/necks/registry.py b/multimolecule/module/necks/registry.py
new file mode 100644
index 00000000..c024227c
--- /dev/null
+++ b/multimolecule/module/necks/registry.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import Registry
+
+NeckRegistry = Registry()
diff --git a/multimolecule/module/registry.py b/multimolecule/module/registry.py
new file mode 100644
index 00000000..b0332463
--- /dev/null
+++ b/multimolecule/module/registry.py
@@ -0,0 +1,35 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import Registry as Registry_
+from torch import nn
+
+from .backbones import BackboneRegistry
+from .backbones.sequences import SequenceRegistry
+from .heads import HeadRegistry
+from .necks import NeckRegistry
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ def build(self, *args, **kwargs) -> nn.Module:
+ return super().build(*args, **kwargs)
+
+
+ModelRegistry = Registry()
+
+__all__ = ["ModelRegistry", "BackboneRegistry", "SequenceRegistry", "NeckRegistry", "HeadRegistry"]
diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md
new file mode 100644
index 00000000..bb1000ad
--- /dev/null
+++ b/multimolecule/runners/README.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+`runners` provide an easy-to-use interface for running experiments.
diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py
new file mode 100644
index 00000000..70fa4076
--- /dev/null
+++ b/multimolecule/runners/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .config import MultiMoleculeConfig
+from .runner import MultiMoleculeRunner
+
+__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"]
diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py
new file mode 100644
index 00000000..1e02d96e
--- /dev/null
+++ b/multimolecule/runners/base_runner.py
@@ -0,0 +1,286 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import math
+import os
+from functools import cached_property, partial
+from typing import Any, Tuple
+from warnings import warn
+
+import danling as dl
+import torch
+from art import text2art
+from chanfig import NestedDict
+from danling import MultiTaskMetrics
+from datasets import disable_progress_bars, get_dataset_split_names
+from torch import nn, optim
+from torch.nn import functional as F
+from torch.utils import data
+from transformers import AutoTokenizer
+
+from multimolecule import defaults
+from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
+from multimolecule.module import HeadConfig, ModelRegistry
+
+from .config import MultiMoleculeConfig
+from .metrics import MetricRegistry
+
+disable_progress_bars()
+
+
+class BaseRunner(dl.BaseRunner):
+
+ all_datasets: NestedDict
+
+ def __init__(self, config: MultiMoleculeConfig):
+ if config.art:
+ print(text2art("MultiMolecule", "rand-large"))
+ super().__init__(config)
+ self.name = config.name
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
+ self.build_datasets()
+ self.build_dataloaders()
+ self.model = ModelRegistry.build(**self.network)
+ if self.config.get("checkpoint"):
+ ckpt = dl.load(self.config.checkpoint)
+ model = ckpt.get("model", ckpt)
+ parameters = self.model.load_state_dict(model, strict=False)
+ if parameters.missing_keys:
+ raise ValueError(f"Missing keys in model: {parameters.missing_keys}")
+ if parameters.unexpected_keys:
+ warn(f"Unexpected keys in model: {parameters.unexpected_keys}")
+ self.model = self.model.to(self.device)
+ if self.distributed:
+ self.model = nn.parallel.DistributedDataParallel(
+ self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
+ )
+ if self.config.training:
+ if self.config.optim and not (self.config.platform == "deepspeed" and self.config.deepspeed.optimizer):
+ self.optimizer = getattr(optim, self.config.optim.pop("name"))(
+ params=self.model.parameters(), **self.config.optim
+ )
+ if self.config.sched and not (self.config.platform == "deepspeed" and self.config.deepspeed.scheduler):
+ self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
+ self.metrics = self.build_metrics()
+
+ def __post_init__(self):
+ super().__post_init__()
+ self.yaml(os.path.join(self.dir, "trainer.yaml"))
+ print(self)
+ print(self.get_dataset_lengths())
+
+ def train_step(self, data) -> Tuple[Any, torch.Tensor]:
+ with self.autocast(), self.accumulate():
+ pred = self.model(**data)
+ loss = self.loss_fn(pred, data)
+ self.advance(loss)
+ self.metric_fn(pred, data)
+ return pred, loss
+
+ def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
+ pred = self.model(**data)
+ loss = self.loss_fn(pred, data)
+ self.metric_fn(pred, data)
+ return pred, loss
+
+ def loss_fn(self, pred, data):
+ if self.balance == "rlw":
+ loss = torch.stack([p["loss"] for p in pred.values()])
+ weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1)
+ return loss.T @ weight
+ if self.balance == "gls":
+ return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred))
+ if self.balance != "ew":
+ warn(f"Unknown balance method {self.balance}, using equal weighting.")
+ return sum(p["loss"] for p in pred.values()) / len(pred)
+
+ def metric_fn(self, pred, data):
+ metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
+ metric.update({t: (p["logits"], data[t]) for t, p in pred.items()})
+
+ @cached_property
+ def tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+ if "train" in self.datasets:
+ return self.datasets.train.tasks
+ return next(iter(self.datasets.values())).tasks
+
+ @cached_property
+ def dataset_tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+ dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values()))
+ tasks = self.tasks
+ dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks
+ for dataset in self.datasets.values():
+ if isinstance(dataset, MultiTaskDataset):
+ for dataset_name, tasks_ in dataset.dataset_tasks.items():
+ for task_name, task_ in tasks_.items():
+ if task_name not in tasks:
+ raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined")
+ task = tasks[task_name]
+ if task != task_:
+ warn(
+ f"Task {task_name} of dataset {dataset_name} has different configurations "
+ "compared to training data, using training configuration.\n"
+ "This may lead to unexpected behavior.",
+ )
+ if dataset_name not in dataset_tasks:
+ dataset_tasks[dataset_name] = NestedDict()
+ if task_name not in dataset_tasks[dataset_name]:
+ dataset_tasks[dataset_name][task_name] = task
+ else:
+ for task_name, task_ in dataset.tasks.items():
+ if task_name not in tasks:
+ raise ValueError(f"Task {task_name} is not defined")
+ task = tasks[task_name]
+ if task != task_:
+ warn(
+ f"Task {task_name} has different configurations "
+ "compared to training data, using training configuration.\n"
+ "This may lead to unexpected behavior.",
+ )
+ if task_name not in dataset_tasks:
+ dataset_tasks[task_name] = task
+ return dataset_tasks
+
+ @cached_property
+ def network(self):
+ heads = {
+ name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level)
+ for name, task in self.tasks.items()
+ }
+ if "heads" not in self.config.network:
+ self.config.network.heads = NestedDict(heads)
+ else:
+ self.config.network.heads.merge(heads, overwrite=False)
+ return self.config.network
+
+ def build_datasets(self):
+ if "data" in self.config:
+ self.datasets = self.all_datasets = self._build_dataset(self.config.data)
+ return
+ if "datas" in self.config:
+ self.all_datasets = NestedDict(
+ {name: self._build_dataset(config, name) for name, config in self.config.datas.items()}
+ )
+ datasets = {
+ subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict}
+ for subkey in {k for v in self.all_datasets.values() for k in v}
+ }
+ self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+ return
+ raise ValueError("No data configuration found")
+
+ def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
+ name = name or config.root
+ print(f"Building dataset {name}")
+ dataset = NestedDict()
+ ignored_key = defaults.DATASET_SPLITS + ["root"]
+ dataset_factory = partial(
+ Dataset,
+ tokenizer=self.tokenizer,
+ **{k: v for k, v in config.items() if k not in ignored_key},
+ )
+ if os.path.isdir(config.root):
+ if "train" in config:
+ dataset.train = dataset_factory(os.path.join(config.root, config.train), split="train")
+ if "validation" in config:
+ dataset.validation = dataset_factory(os.path.join(config.root, config.validation), split="validation")
+ if "test" in config:
+ dataset.test = dataset_factory(os.path.join(config.root, config.test), split="test")
+ if "evaluation" in config:
+ dataset.evaluation = dataset_factory(os.path.join(config.root, config.evaluation), split="evaluation")
+ if "inference" in config:
+ dataset.inference = dataset_factory(os.path.join(config.root, config.inference), split="inference")
+ else:
+ splits = get_dataset_split_names(config.root)
+ existing_splits = {k for k in defaults.DATASET_SPLITS if config.get(k) is not None}
+ if not existing_splits:
+ if "train" in splits:
+ config.train = "train"
+ if "validation" in splits:
+ config.validation = "validation"
+ if "test" in splits:
+ config.test = "test"
+ if config.get("train") is not None:
+ dataset.train = dataset_factory(config.root, split="train")
+ if config.get("validation") is not None:
+ dataset.validation = dataset_factory(config.root, split="validation")
+ if config.get("test") is not None:
+ dataset.test = dataset_factory(config.root, split="test")
+ if config.get("evaluation") is not None:
+ dataset.evaluation = dataset_factory(config.root, split=config.get("evaluation"))
+ if config.get("inference") is not None:
+ dataset.inference = dataset_factory(config.root, split=config.get("inference"))
+ if not dataset:
+ raise ValueError("No datasets built. This is likely due to missing data paths in Config.")
+ return dataset
+
+ def build_dataloaders(self):
+ datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+ default_kwargs = self.config.get("dataloader", NestedDict())
+ dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+ for k, d in datasets.items():
+ dataloader_kwargs.setdefault(k, NestedDict())
+ dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+ batch_size = dataloader_kwargs[k].pop("batch_size")
+ shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+ drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True))
+ if isinstance(d, MultiTaskDataset):
+ batch_sampler = (
+ DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ if self.distributed
+ else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ )
+ else:
+ sampler = (
+ data.distributed.DistributedSampler(d, shuffle=shuffle)
+ if self.distributed
+ else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+ )
+ batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last)
+ self.dataloaders[k] = data.DataLoader(
+ d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+ )
+
+ def build_metrics(self) -> MultiTaskMetrics:
+ return MultiTaskMetrics(
+ {
+ name: MetricRegistry.build(type=task.type, num_labels=task.num_labels)
+ for name, task in self.dataset_tasks.all_items()
+ }
+ )
+
+ def collate_fn(self, batch):
+ return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()}
+
+ def get_dataset_lengths(self) -> str:
+ repr = "dataset lengths:\n"
+ longest_name = max(len(name) for name in self.all_datasets.keys())
+ for name, dataset in self.all_datasets.items():
+ if isinstance(dataset, NestedDict):
+ repr += f"{name}:"
+ if len(name) < longest_name:
+ repr += " " * (longest_name - len(name))
+ repr += "\t\t"
+ for split, d in dataset.items():
+ repr += f" {split}: {len(d)}\t"
+ else:
+ repr += f"{name}: {len(dataset)}\t"
+ repr += "\n"
+ return repr
diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py
new file mode 100644
index 00000000..0edf83a5
--- /dev/null
+++ b/multimolecule/runners/config.py
@@ -0,0 +1,83 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import List
+
+from chanfig import Config
+from transformers import PretrainedConfig
+
+
+class DataConfig(Config):
+ root: str = "."
+ train: str | None
+ validation: str | None
+ test: str | None
+ feature_cols: List | None = None
+ label_cols: List | None = None
+ truncation: bool = True
+
+
+class OptimConfig(Config):
+ name: str = "AdamW"
+ lr: float = 1e-3
+ weight_decay: float = 1e-2
+
+
+class MultiMoleculeConfig(Config):
+
+ balance: str = "ew"
+ platform: str = "torch"
+ training: bool = True
+
+ pretrained: str
+ use_pretrained: bool = True
+ transformers: PretrainedConfig
+ epoch_end: int = 20
+
+ data: DataConfig
+
+ tensorboard: bool = True
+ save_interval: int = 10
+ seed: int = 1016
+ art: bool = True
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.datas = Config(default_factory=DataConfig)
+ self.dataloader.batch_size = 32
+ self.optim = OptimConfig()
+ self.sched.final_lr = 0
+
+ def post(self):
+ if "data" in self:
+ if self.datas:
+ raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+ del self.datas
+ self["network.backbone.sequence.name"] = self.pretrained
+ self["network.backbone.sequence.use_pretrained"] = self.use_pretrained
+ pretrained = self.pretrained
+ if os.path.exists(self.pretrained):
+ path = Path(pretrained)
+ if os.path.isfile(pretrained):
+ pretrained = str(path.relative_to(path.parents[1]).with_suffix(""))
+ else:
+ pretrained = path.stem
+
+ self.name = f"{self.pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}"
diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py
new file mode 100644
index 00000000..da584cbc
--- /dev/null
+++ b/multimolecule/runners/metrics.py
@@ -0,0 +1,37 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from chanfig import Registry as Registry_
+from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics
+
+
+class Registry(Registry_):
+
+ def build(self, type, num_labels: int | None = None, **kwargs):
+ if type == "multilabel":
+ return self.init(self.lookup(type), num_labels=num_labels, **kwargs)
+ if type == "multiclass":
+ return self.init(self.lookup(type), num_classes=num_labels, **kwargs)
+ if type == "regression":
+ return self.init(self.lookup(type), num_outputs=num_labels, **kwargs)
+ return self.init(self.lookup(type), **kwargs)
+
+
+MetricRegistry = Registry(key="type")
+MetricRegistry.register(binary_metrics, "binary")
+MetricRegistry.register(multiclass_metrics, "multiclass")
+MetricRegistry.register(multilabel_metrics, "multilabel")
+MetricRegistry.register(regression_metrics, "regression")
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
new file mode 100644
index 00000000..97cab4e7
--- /dev/null
+++ b/multimolecule/runners/runner.py
@@ -0,0 +1,42 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import danling as dl
+
+from .base_runner import BaseRunner
+
+
+class MultiMoleculeRunner(type):
+ def __new__(cls, config):
+ if config.get("platform", "torch") == "torch":
+ return TorchRunner(config)
+ if config.platform == "deepspeed":
+ return DeepSpeedRunner(config)
+ if config.platform == "accelerate":
+ return AccelerateRunner(config)
+ raise ValueError(f"Unsupported platform: {config.platform}")
+
+
+class TorchRunner(BaseRunner, dl.TorchRunner):
+ pass
+
+
+class DeepSpeedRunner(BaseRunner, dl.DeepSpeedRunner):
+ pass
+
+
+class AccelerateRunner(BaseRunner, dl.AccelerateRunner):
+ pass
diff --git a/multimolecule/tasks/task.py b/multimolecule/tasks/task.py
index e2473ab0..5d435f83 100644
--- a/multimolecule/tasks/task.py
+++ b/multimolecule/tasks/task.py
@@ -34,9 +34,8 @@ class TaskType(StrEnum):
class TaskLevel(StrEnum):
Sequence = auto()
- Nucleotide = auto()
+ Token = auto()
Contact = auto()
- # Token = auto()
@dataclass
diff --git a/multimolecule/train.py b/multimolecule/train.py
new file mode 100644
index 00000000..f146bc97
--- /dev/null
+++ b/multimolecule/train.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .apis import train
+
+if __name__ == "__main__":
+ train()
diff --git a/pyproject.toml b/pyproject.toml
index 972b1d8a..ee4c525d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,11 @@ dynamic = [
]
dependencies = [
"accelerate",
+ "art",
"chanfig>=0.0.105",
- "danling>=0.3.11",
+ "danling>=0.4.0b1",
"datasets",
+ 'StrEnum; python_version < "3.11"',
"torch",
"transformers",
]
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 5ddfb3c4..318e289e 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -126,7 +126,7 @@ def test_spliceai(self, preprocess: bool):
feature_cols=feature_cols,
label_cols=label_cols,
)
- task = Task(type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1)
+ task = Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1)
elem = dataset[0]
assert isinstance(elem["sequence"], torch.LongTensor)
assert isinstance(elem["splice_ai"], torch.LongTensor)
@@ -175,20 +175,18 @@ def test_rna_task_recognition_json(self):
assert dataset.tasks["sequence_regression"] == Task(
type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=1
)
- assert dataset.tasks["nucleotide_binary"] == Task(
- type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1
- )
+ assert dataset.tasks["nucleotide_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1)
assert dataset.tasks["nucleotide_multiclass"] == Task(
- type=TaskType.MultiClass, level=TaskLevel.Nucleotide, num_labels=5
+ type=TaskType.MultiClass, level=TaskLevel.Token, num_labels=5
)
assert dataset.tasks["nucleotide_multilabel"] == Task(
- type=TaskType.MultiLabel, level=TaskLevel.Nucleotide, num_labels=5
+ type=TaskType.MultiLabel, level=TaskLevel.Token, num_labels=5
)
assert dataset.tasks["nucleotide_multireg"] == Task(
- type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=5
+ type=TaskType.Regression, level=TaskLevel.Token, num_labels=5
)
assert dataset.tasks["nucleotide_regression"] == Task(
- type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=1
+ type=TaskType.Regression, level=TaskLevel.Token, num_labels=1
)
assert dataset.tasks["contact_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
assert dataset.tasks["contact_multiclass"] == Task(