diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt
index d67fd62d..aeca612a 100644
--- a/.codespell-whitelist.txt
+++ b/.codespell-whitelist.txt
@@ -1,3 +1,4 @@
+datas
ser
marz
manuel
diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md
new file mode 100644
index 00000000..054c6205
--- /dev/null
+++ b/docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiTask
+
+::: multimolecule.data.multitask
diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md
new file mode 100644
index 00000000..8e188199
--- /dev/null
+++ b/docs/docs/runners/config.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeConfig
+
+::: multimolecule.runners.MultiMoleculeConfig
diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md
new file mode 100644
index 00000000..75a0f528
--- /dev/null
+++ b/docs/docs/runners/index.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+--8<-- "multimolecule/runners/README.md:8:"
diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md
new file mode 100644
index 00000000..2c93ce1b
--- /dev/null
+++ b/docs/docs/runners/runner.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeRunner
+
+::: multimolecule.runners.MultiMoleculeRunner
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 53875f32..5722355e 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule
nav:
- index.md
+ - runners:
+ - runners/index.md
+ - MultiMoleculeRunner: runners/runner.md
+ - MultiMoleculeConfig: runners/config.md
- data:
- data/index.md
- Dataset: data/dataset.md
+ - multitask: data/multitask.md
- datasets:
- datasets/index.md
- DNA:
diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py
index c168ce48..31319225 100644
--- a/multimolecule/__init__.py
+++ b/multimolecule/__init__.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from .apis import evaluate, inference, train
from .data import Dataset
from .models import (
AutoModelForContactPrediction,
@@ -130,14 +131,20 @@
TokenKMerHead,
TokenPredictionHead,
)
+from .runners import MultiMoleculeConfig, MultiMoleculeRunner
from .tasks import Task, TaskLevel, TaskType
from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
from .utils import count_parameters
__all__ = [
+ "train",
+ "evaluate",
+ "inference",
"modeling_auto",
"modeling_outputs",
"Dataset",
+ "MultiMoleculeConfig",
+ "MultiMoleculeRunner",
"PreTrainedConfig",
"HeadConfig",
"BaseHeadConfig",
diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py
new file mode 100644
index 00000000..3da34480
--- /dev/null
+++ b/multimolecule/apis/__init__.py
@@ -0,0 +1,19 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .run import evaluate, inference, train
+
+__all__ = ["train", "evaluate", "inference"]
diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py
new file mode 100644
index 00000000..e81da5bf
--- /dev/null
+++ b/multimolecule/apis/run.py
@@ -0,0 +1,110 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import atexit
+import os
+import warnings
+from typing import Type
+
+import danling as dl
+import torch
+
+from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner
+
+try:
+ import nni
+except ImportError:
+ nni = None
+
+
+def train(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if config.get("nni", False):
+ if nni is None:
+ raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
+ config.merge(nni.get_next_parameter())
+ with dl.debug(config.get("debug", False)):
+ runner = runner_cls(config)
+ atexit.register(runner.print_result)
+ atexit.register(runner.save_result)
+ atexit.register(runner.save_checkpoint)
+ result = runner.train()
+ return result
+
+
+def evaluate(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig.empty()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+ raise RuntimeError("Please specify `checkpoint` to run evaluate")
+ for name, data in config.datas.items():
+ if "eval" not in data or not isinstance(data.eval, str):
+ raise RuntimeError(f"Please specify `eval` to run evaluate in datas.{name}")
+ runner = runner_cls(config)
+ result = runner.evaluate_epoch("eval")
+ print(result)
+ return result
+
+
+def inference(
+ config: MultiMoleculeConfig = None, # type: ignore
+ runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+ if config is None:
+ config = MultiMoleculeConfig.empty()
+ config = config.parse(default_config="config", no_default_config_action="warn")
+ config.interpolate(unsafe_eval=True)
+ if config.allow_tf32:
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ if config.reduced_precision_reduction:
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+ torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+ if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+ raise RuntimeError("Please specify `checkpoint` to run infer.")
+ for name, data in config.datas.items():
+ if "inf" not in data or not isinstance(data.inf, str):
+ raise RuntimeError(f"Please specify `inf` to run infer in datas.{name}")
+ if "result_path" not in config or not isinstance(config.result_path, str):
+ config.result_path = os.path.join(os.getcwd(), "result.json")
+ warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
+ runner = runner_cls(config)
+ result = runner.inference()
+ runner.save(result, config.result_path)
+ return result
diff --git a/multimolecule/apis/stat.py b/multimolecule/apis/stat.py
new file mode 100644
index 00000000..5e525d55
--- /dev/null
+++ b/multimolecule/apis/stat.py
@@ -0,0 +1,99 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import os
+import shutil
+from statistics import mean
+from typing import List
+
+import chanfig
+import pandas as pd
+from chanfig import NestedDict
+from tqdm import tqdm
+
+
+class Result(NestedDict):
+ pretrained: str
+ id: str
+ seed: int
+ epoch: int
+ validation: NestedDict
+ test: NestedDict
+
+
+def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]:
+ results = []
+ for root, _, files in tqdm(os.walk(experiment_root)):
+ if "run.log" in files:
+ if "best.json" not in files:
+ if remove_empty:
+ shutil.rmtree(root)
+ continue
+ best = NestedDict.from_json(os.path.join(root, "best.json"))
+ if "index" not in best:
+ if remove_empty:
+ shutil.rmtree(root)
+ continue
+ config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
+ pretrained = config.pretrained.split("/")[-1]
+ seed = config.seed
+ pretrained, seed = "", 1
+ result = Result(id=best.id, pretrained=pretrained, seed=seed)
+ result.validation = NestedDict(
+ {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()}
+ )
+ result.test = NestedDict(
+ {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()}
+ )
+ result.epoch = best.index
+ result.pop("validation.time", None)
+ result.pop("test.time", None)
+ result.pop("validation.loss", None)
+ result.pop("test.loss", None)
+ result.pop("validation.lr", None)
+ result.pop("test.lr", None)
+ results.append(result)
+ # Remove empty directories, perform twice to remove all empty directories
+ if remove_empty:
+ for root, dirs, files in os.walk(experiment_root):
+ if not files and not dirs:
+ os.rmdir(root)
+ for root, dirs, files in os.walk(experiment_root):
+ if not files and not dirs:
+ os.rmdir(root)
+ results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
+ return results
+
+
+def write_result_stat(results: List[Result], path: str):
+ results = [dict(result.all_items()) for result in results] # type: ignore[misc]
+ df = pd.DataFrame.from_dict(results)
+ df.insert(len(df.keys()) - 1, "comment", "")
+ df.fillna("")
+ df.to_csv(path, index=False)
+
+
+class Config(chanfig.Config):
+ experiment_root: str = "experiments"
+ out_path: str = "result.csv"
+
+
+if __name__ == "__main__":
+ config = Config().parse()
+ result_stat = get_result_stat(config.experiment_root)
+ if not len(result_stat) > 0:
+ raise ValueError("No results found")
+ write_result_stat(result_stat, config.out_path)
diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py
index 62196c10..f2366d77 100644
--- a/multimolecule/data/__init__.py
+++ b/multimolecule/data/__init__.py
@@ -15,6 +15,14 @@
# along with this program. If not, see .
from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
from .utils import no_collate
-__all__ = ["Dataset", "no_collate"]
+__all__ = [
+ "Dataset",
+ "PandasDataset",
+ "MultiTaskDataset",
+ "MultiTaskSampler",
+ "DistributedMultiTaskSampler",
+ "no_collate",
+]
diff --git a/multimolecule/data/dataset.py b/multimolecule/data/dataset.py
index 54565349..5023d6bb 100644
--- a/multimolecule/data/dataset.py
+++ b/multimolecule/data/dataset.py
@@ -150,7 +150,7 @@ def __init__(
ignored_cols: List[str] | None = None,
):
if tasks is not None:
- self._tasks = NestedDict(tasks)
+ self.tasks = tasks
if discrete_map is not None:
self._discrete_map = discrete_map
arrow_table = self.build_table(
@@ -297,20 +297,20 @@ def collate(self, col: str, data: Any) -> Tensor | NestedTensor | None:
except ValueError:
return NestedTensor(data)
- def infer_tasks(self, tasks: Mapping | None = None, sequence_col: str | None = None) -> NestedDict:
- self._tasks = tasks or NestedDict()
+ def infer_tasks(self, sequence_col: str | None = None) -> NestedDict:
for col in self.label_cols:
- if col not in self.tasks:
- if col in self.secondary_structure_cols:
- task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
- self._tasks[col] = task # type: ignore[index]
- warn(
- f"Secondary structure columns are assumed to be {task}."
- " Please explicitly specify the task if this is not the case."
- )
- else:
- self._tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index]
- return self._tasks
+ if col in self.tasks:
+ continue
+ if col in self.secondary_structure_cols:
+ task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
+ self.tasks[col] = task # type: ignore[index]
+ warn(
+ f"Secondary structure columns are assumed to be {task}. "
+ "Please explicitly specify the task if this is not the case."
+ )
+ else:
+ self.tasks[col] = self.infer_task(col, sequence_col) # type: ignore[index]
+ return self.tasks
def infer_task(self, label_col: str, sequence_col: str | None = None) -> Task:
if sequence_col is None:
@@ -404,7 +404,7 @@ def rename_columns(self, column_mapping: Mapping[str, str], new_fingerprint: str
self._label_cols = [column_mapping.get(i, i) for i in self.label_cols]
self._sequence_cols = [column_mapping.get(i, i) for i in self.sequence_cols]
self._secondary_structure_cols = [column_mapping.get(i, i) for i in self.secondary_structure_cols]
- self._tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()}
+ self.tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()}
return self
def rename_column(
@@ -418,7 +418,7 @@ def rename_column(
self._secondary_structure_cols = [
new_column_name if i == original_column_name else i for i in self.secondary_structure_cols
]
- self._tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()}
+ self.tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()}
return self
def process_nan(self, data: Table, nan_process: str | None, fill_value: str | int | float = 0) -> Table:
@@ -473,6 +473,14 @@ def tasks(self) -> NestedDict:
return self.infer_tasks()
return self._tasks
+ @tasks.setter
+ def tasks(self, tasks: Mapping):
+ self._tasks = NestedDict()
+ for name, task in tasks.items():
+ if not isinstance(task, Task):
+ task = Task(**task)
+ self._tasks[name] = task
+
@property
def discrete_map(self) -> Mapping:
if not hasattr(self, "_discrete_map"):
diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py
new file mode 100644
index 00000000..0d812e4b
--- /dev/null
+++ b/multimolecule/data/multitask.py
@@ -0,0 +1,191 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from bisect import bisect_right
+from collections.abc import Mapping, Sequence
+from copy import deepcopy
+from random import choices
+
+from chanfig import NestedDict
+from torch import distributed as dist
+from torch.utils import data
+
+from .dataset import Dataset
+
+
+class MultiTaskDataset(data.ConcatDataset):
+
+ datasets: Mapping
+ dataset_keys: Sequence[str]
+ dataset_values: Sequence[Dataset]
+
+ def __init__(self, datasets: Mapping) -> None:
+ for key, dataset in datasets.items():
+ if not isinstance(dataset, Dataset):
+ raise TypeError(f"Dataset {key} should be an instance of Dataset")
+ self.datasets = datasets
+ if not len(self.datasets) > 0:
+ raise ValueError("MultiTaskDataset should contain at least one dataset")
+ self.dataset_keys, self.dataset_values = zip(*self.datasets.items())
+ self.cumulative_sizes = self.cumsum(self.dataset_values)
+
+ def __getitems__(self, key: Sequence[int]) -> Mapping:
+ dataset_idx = bisect_right(self.cumulative_sizes, key[0])
+ if dataset_idx == 0:
+ sample_idx = key
+ else:
+ sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key]
+ batch = self.dataset_values[dataset_idx][sample_idx]
+ batch["dataset"] = self.dataset_keys[dataset_idx]
+ return batch
+
+ @property
+ def tasks(self) -> NestedDict:
+ tasks = NestedDict()
+ for dataset in self.dataset_values:
+ for n, t in dataset.tasks.items():
+ if n not in tasks:
+ tasks[n] = t
+ elif tasks[n] != t:
+ raise ValueError(f"Task {n} has different configurations across datasets")
+ return tasks
+
+ @property
+ def dataset_tasks(self) -> NestedDict:
+ return NestedDict({k: v.tasks for k, v in self.datasets.items()})
+
+ def __repr__(self) -> str:
+ return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"
+
+
+class MultiTaskSampler(data.BatchSampler):
+ r"""
+ Ensure all items in a batch comes from the same dataset.
+
+ Arguments:
+ sampler (Sampler): Base sampler.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
+ """
+
+ datasets: Sequence[Dataset]
+
+ def __init__( # pylint: disable=super-init-not-called
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] | None = None,
+ weights: list[int] | None = None,
+ ) -> None:
+ self.datasets = dataset.dataset_values
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+ if sampler_cls is None:
+ sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
+ self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore
+ self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore
+ self.cumulative_sizes = dataset.cumulative_sizes
+ self.num_datasets = len(self.datasets)
+ self.weights = weights if weights is not None else self.dataset_sizes
+
+ def __iter__(self):
+ sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+ sampler_weights = deepcopy(self.weights)
+ sampler_idx = 0
+ # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
+ if self.drop_last:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ try:
+ batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+ yield batch
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ else:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ batch = [0] * self.batch_size
+ idx_in_batch = 0
+ try:
+ for _ in range(self.batch_size):
+ batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+ idx_in_batch += 1
+ yield batch
+ idx_in_batch = 0 # noqa: SIM113
+ batch = [0] * self.batch_size
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ if idx_in_batch > 0:
+ yield batch[:idx_in_batch]
+
+ def __len__(self):
+ batch_size = self.batch_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+
+class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods
+ r"""
+ Distributed version of MultiTaskSampler, which ensures that each batch contains
+ data from only one dataset.
+
+ See Also:
+ [MultiTaskSampler][MultiTaskSampler]
+ """
+
+ def __init__(
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] = data.RandomSampler,
+ weights: list[int] | None = None,
+ ) -> None:
+ super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
+ self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]
+
+ def set_epoch(self, epoch):
+ for s in self.samplers:
+ s.set_epoch(epoch)
+
+ def __len__(self):
+ batch_size = self.batch_size * self.world_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+ @property
+ def world_size(self) -> int:
+ r"""Return the number of processes in the current process group."""
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size()
+ return 1
diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py
index 66147616..29d99436 100644
--- a/multimolecule/models/__init__.py
+++ b/multimolecule/models/__init__.py
@@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from multimolecule.module import HeadConfig
from multimolecule.tokenisers import DnaTokenizer, ProteinTokenizer, RnaTokenizer
from .calm import (
@@ -127,6 +128,7 @@
__all__ = [
"PreTrainedConfig",
+ "HeadConfig",
"DnaTokenizer",
"RnaTokenizer",
"ProteinTokenizer",
diff --git a/multimolecule/models/calm/configuration_calm.py b/multimolecule/models/calm/configuration_calm.py
index 66110176..a9540346 100644
--- a/multimolecule/models/calm/configuration_calm.py
+++ b/multimolecule/models/calm/configuration_calm.py
@@ -130,5 +130,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/calm/modeling_calm.py b/multimolecule/models/calm/modeling_calm.py
index 2494a3b5..7db38f19 100644
--- a/multimolecule/models/calm/modeling_calm.py
+++ b/multimolecule/models/calm/modeling_calm.py
@@ -268,9 +268,9 @@ class CaLmForContactPrediction(CaLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
@@ -332,11 +332,11 @@ class CaLmForNucleotidePrediction(CaLmPreTrainedModel):
>>> model = CaLmForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
@@ -398,9 +398,9 @@ class CaLmForSequencePrediction(CaLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
@@ -462,9 +462,9 @@ class CaLmForTokenPrediction(CaLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: CaLmConfig):
diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py
index 2047d671..ce6f10ea 100644
--- a/multimolecule/models/configuration_utils.py
+++ b/multimolecule/models/configuration_utils.py
@@ -30,7 +30,8 @@ class PreTrainedConfig(PretrainedConfig):
Base class for all model configuration classes.
"""
- head: HeadConfig
+ head: HeadConfig | None
+ num_labels: int = 1
hidden_size: int
@@ -42,7 +43,15 @@ class PreTrainedConfig(PretrainedConfig):
null_token_id: int = 5
def __init__(
- self, pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3, mask_token_id=4, null_token_id=5, **kwargs
+ self,
+ pad_token_id: int = 0,
+ bos_token_id: int = 1,
+ eos_token_id: int = 2,
+ unk_token_id: int = 3,
+ mask_token_id: int = 4,
+ null_token_id: int = 5,
+ num_labels: int = 1,
+ **kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
@@ -51,6 +60,7 @@ def __init__(
unk_token_id=unk_token_id,
mask_token_id=mask_token_id,
null_token_id=null_token_id,
+ num_labels=num_labels,
**kwargs,
)
diff --git a/multimolecule/models/ernierna/configuration_ernierna.py b/multimolecule/models/ernierna/configuration_ernierna.py
index c2b25364..5bda885b 100644
--- a/multimolecule/models/ernierna/configuration_ernierna.py
+++ b/multimolecule/models/ernierna/configuration_ernierna.py
@@ -113,5 +113,5 @@ def __init__(
self.pairwise_alpha = pairwise_alpha
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/ernierna/modeling_ernierna.py b/multimolecule/models/ernierna/modeling_ernierna.py
index 34828e2d..61162632 100644
--- a/multimolecule/models/ernierna/modeling_ernierna.py
+++ b/multimolecule/models/ernierna/modeling_ernierna.py
@@ -314,9 +314,9 @@ class ErnieRnaForContactPrediction(ErnieRnaPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: ErnieRnaConfig):
@@ -378,11 +378,11 @@ class ErnieRnaForNucleotidePrediction(ErnieRnaPreTrainedModel):
>>> model = ErnieRnaForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer()
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: ErnieRnaConfig):
@@ -446,7 +446,7 @@ class ErnieRnaForSequencePrediction(ErnieRnaPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input)
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
"""
def __init__(self, config: ErnieRnaConfig):
@@ -510,9 +510,9 @@ class ErnieRnaForTokenPrediction(ErnieRnaPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: ErnieRnaConfig):
@@ -1231,7 +1231,7 @@ class ErnieRnaContactClassificationHead(nn.Module):
def __init__(self, config: ErnieRnaConfig, head_config: HeadConfig | None = None):
super().__init__()
if head_config is None:
- head_config = config.head
+ head_config = config.head or HeadConfig()
self.config = head_config
self.bos_token_id = config.bos_token_id
self.eos_token_id = config.eos_token_id
diff --git a/multimolecule/models/rinalmo/configuration_rinalmo.py b/multimolecule/models/rinalmo/configuration_rinalmo.py
index fddfff85..486cd9b9 100644
--- a/multimolecule/models/rinalmo/configuration_rinalmo.py
+++ b/multimolecule/models/rinalmo/configuration_rinalmo.py
@@ -128,6 +128,6 @@ def __init__(
self.use_cache = use_cache
self.learnable_beta = learnable_beta
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
self.emb_layer_norm_before = emb_layer_norm_before
diff --git a/multimolecule/models/rinalmo/modeling_rinalmo.py b/multimolecule/models/rinalmo/modeling_rinalmo.py
index c262a0cc..947cdbce 100644
--- a/multimolecule/models/rinalmo/modeling_rinalmo.py
+++ b/multimolecule/models/rinalmo/modeling_rinalmo.py
@@ -267,9 +267,9 @@ class RiNALMoForContactPrediction(RiNALMoPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
@@ -331,11 +331,11 @@ class RiNALMoForNucleotidePrediction(RiNALMoPreTrainedModel):
>>> model = RiNALMoForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
@@ -397,9 +397,9 @@ class RiNALMoForSequencePrediction(RiNALMoPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
@@ -461,9 +461,9 @@ class RiNALMoForTokenPrediction(RiNALMoPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RiNALMoConfig):
diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py
index 473c285e..e16013c1 100644
--- a/multimolecule/models/rnabert/configuration_rnabert.py
+++ b/multimolecule/models/rnabert/configuration_rnabert.py
@@ -115,5 +115,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py
index c68b97ac..71bfb673 100644
--- a/multimolecule/models/rnabert/modeling_rnabert.py
+++ b/multimolecule/models/rnabert/modeling_rnabert.py
@@ -39,6 +39,7 @@
from multimolecule.module import (
ContactPredictionHead,
+ HeadConfig,
MaskedLMHead,
NucleotidePredictionHead,
SequencePredictionHead,
@@ -269,9 +270,9 @@ class RnaBertForContactPrediction(RnaBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -333,11 +334,11 @@ class RnaBertForNucleotidePrediction(RnaBertPreTrainedModel):
>>> model = RnaBertForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -399,9 +400,9 @@ class RnaBertForSequencePrediction(RnaBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -463,9 +464,9 @@ class RnaBertForTokenPrediction(RnaBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaBertConfig):
@@ -1121,7 +1122,7 @@ def __init__(self, config: RnaBertConfig):
vocab_size, config.vocab_size = config.vocab_size, config.ss_vocab_size
self.predictions_ss = MaskedLMHead(config)
config.vocab_size = vocab_size
- self.seq_relationship = SequencePredictionHead(config)
+ self.seq_relationship = SequencePredictionHead(config, HeadConfig(num_labels=2))
def forward(
self,
diff --git a/multimolecule/models/rnaernie/configuration_rnaernie.py b/multimolecule/models/rnaernie/configuration_rnaernie.py
index 96069bae..b7076190 100644
--- a/multimolecule/models/rnaernie/configuration_rnaernie.py
+++ b/multimolecule/models/rnaernie/configuration_rnaernie.py
@@ -111,5 +111,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnaernie/modeling_rnaernie.py b/multimolecule/models/rnaernie/modeling_rnaernie.py
index 8e8c4336..a69b9fcf 100644
--- a/multimolecule/models/rnaernie/modeling_rnaernie.py
+++ b/multimolecule/models/rnaernie/modeling_rnaernie.py
@@ -273,9 +273,9 @@ class RnaErnieForContactPrediction(RnaErniePreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaErnieConfig):
@@ -337,11 +337,11 @@ class RnaErnieForNucleotidePrediction(RnaErniePreTrainedModel):
>>> model = RnaErnieForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaErnieConfig):
@@ -403,9 +403,9 @@ class RnaErnieForSequencePrediction(RnaErniePreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config):
@@ -467,9 +467,9 @@ class RnaErnieForTokenPrediction(RnaErniePreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaErnieConfig):
diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py
index cf6cff48..1a92451d 100644
--- a/multimolecule/models/rnafm/configuration_rnafm.py
+++ b/multimolecule/models/rnafm/configuration_rnafm.py
@@ -134,5 +134,5 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py
index ebfb0201..9e0763e4 100644
--- a/multimolecule/models/rnafm/modeling_rnafm.py
+++ b/multimolecule/models/rnafm/modeling_rnafm.py
@@ -270,9 +270,9 @@ class RnaFmForContactPrediction(RnaFmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -334,11 +334,11 @@ class RnaFmForNucleotidePrediction(RnaFmPreTrainedModel):
>>> model = RnaFmForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -400,9 +400,9 @@ class RnaFmForSequencePrediction(RnaFmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -464,9 +464,9 @@ class RnaFmForTokenPrediction(RnaFmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaFmConfig):
@@ -606,7 +606,7 @@ class RnaFmForPreTraining(RnaFmPreTrainedModel):
>>> output["logits"].shape
torch.Size([1, 7, 26])
>>> output["contact_map"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
"""
_tied_weights_keys = [
diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py
index 678e1b5a..4ae62a95 100644
--- a/multimolecule/models/rnamsm/configuration_rnamsm.py
+++ b/multimolecule/models/rnamsm/configuration_rnamsm.py
@@ -119,5 +119,5 @@ def __init__(
self.attention_type = attention_type
self.embed_positions_msa = embed_positions_msa
self.attention_bias = attention_bias
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py
index a1636d9b..b1adcfca 100644
--- a/multimolecule/models/rnamsm/modeling_rnamsm.py
+++ b/multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -182,9 +182,9 @@ class RnaMsmForContactPrediction(RnaMsmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -248,11 +248,11 @@ class RnaMsmForNucleotidePrediction(RnaMsmPreTrainedModel):
>>> model = RnaMsmForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -313,9 +313,9 @@ class RnaMsmForSequencePrediction(RnaMsmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -376,9 +376,9 @@ class RnaMsmForTokenPrediction(RnaMsmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: RnaMsmConfig):
@@ -507,7 +507,7 @@ class RnaMsmForPreTraining(RnaMsmPreTrainedModel):
>>> output["logits"].shape
torch.Size([1, 7, 26])
>>> output["contact_map"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
"""
_tied_weights_keys = [
diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py
index 1f98de4c..9a61df3e 100644
--- a/multimolecule/models/splicebert/configuration_splicebert.py
+++ b/multimolecule/models/splicebert/configuration_splicebert.py
@@ -111,5 +111,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py
index 7a99a2ff..3b5f2ddd 100644
--- a/multimolecule/models/splicebert/modeling_splicebert.py
+++ b/multimolecule/models/splicebert/modeling_splicebert.py
@@ -277,9 +277,9 @@ class SpliceBertForContactPrediction(SpliceBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
@@ -341,11 +341,11 @@ class SpliceBertForNucleotidePrediction(SpliceBertPreTrainedModel):
>>> model = SpliceBertForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
@@ -407,9 +407,9 @@ class SpliceBertForSequencePrediction(SpliceBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
@@ -471,9 +471,9 @@ class SpliceBertForTokenPrediction(SpliceBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: SpliceBertConfig):
diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py
index fa62a579..cfc5dd53 100644
--- a/multimolecule/models/utrbert/configuration_utrbert.py
+++ b/multimolecule/models/utrbert/configuration_utrbert.py
@@ -128,5 +128,5 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.is_decoder = is_decoder
self.use_cache = use_cache
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py
index 6f9b79a3..c28230e7 100644
--- a/multimolecule/models/utrbert/modeling_utrbert.py
+++ b/multimolecule/models/utrbert/modeling_utrbert.py
@@ -267,9 +267,9 @@ class UtrBertForContactPrediction(UtrBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
@@ -331,11 +331,11 @@ class UtrBertForNucleotidePrediction(UtrBertPreTrainedModel):
>>> config = UtrBertConfig(vocab_size=tokenizer.vocab_size, nmers=2)
>>> model = UtrBertForNucleotidePrediction(config)
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
@@ -397,9 +397,9 @@ class UtrBertForSequencePrediction(UtrBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
@@ -461,9 +461,9 @@ class UtrBertForTokenPrediction(UtrBertPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrBertConfig):
diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py
index 0e6babb1..6f7765d4 100644
--- a/multimolecule/models/utrlm/configuration_utrlm.py
+++ b/multimolecule/models/utrlm/configuration_utrlm.py
@@ -130,7 +130,7 @@ def __init__(
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout
- self.head = HeadConfig(**head if head is not None else {})
- self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+ self.head = HeadConfig(**head) if head is not None else None
+ self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
self.ss_head = HeadConfig(**ss_head) if ss_head is not None else None
self.mfe_head = HeadConfig(**mfe_head) if mfe_head is not None else None
diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py
index e6ada5b3..697fed8c 100644
--- a/multimolecule/models/utrlm/modeling_utrlm.py
+++ b/multimolecule/models/utrlm/modeling_utrlm.py
@@ -270,9 +270,9 @@ class UtrLmForContactPrediction(UtrLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
>>> output["logits"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -334,11 +334,11 @@ class UtrLmForNucleotidePrediction(UtrLmPreTrainedModel):
>>> model = UtrLmForNucleotidePrediction(config)
>>> tokenizer = RnaTokenizer.from_pretrained("multimolecule/rna")
>>> input = tokenizer("ACGUN", return_tensors="pt")
- >>> output = model(**input, labels=torch.randn(1, 5, 2))
+ >>> output = model(**input, labels=torch.randn(1, 5))
>>> output["logits"].shape
- torch.Size([1, 5, 2])
+ torch.Size([1, 5, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -400,9 +400,9 @@ class UtrLmForSequencePrediction(UtrLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.tensor([[1]]))
>>> output["logits"].shape
- torch.Size([1, 2])
+ torch.Size([1, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -464,9 +464,9 @@ class UtrLmForTokenPrediction(UtrLmPreTrainedModel):
>>> input = tokenizer("ACGUN", return_tensors="pt")
>>> output = model(**input, labels=torch.randint(2, (1, 7)))
>>> output["logits"].shape
- torch.Size([1, 7, 2])
+ torch.Size([1, 7, 1])
>>> output["loss"] # doctest:+ELLIPSIS
- tensor(..., grad_fn=)
+ tensor(..., grad_fn=)
"""
def __init__(self, config: UtrLmConfig):
@@ -606,7 +606,7 @@ class UtrLmForPreTraining(UtrLmPreTrainedModel):
>>> output["logits"].shape
torch.Size([1, 7, 26])
>>> output["contact_map"].shape
- torch.Size([1, 5, 5, 2])
+ torch.Size([1, 5, 5, 1])
"""
_tied_weights_keys = [
diff --git a/multimolecule/module/__init__.py b/multimolecule/module/__init__.py
index bc657940..b7d54ad5 100644
--- a/multimolecule/module/__init__.py
+++ b/multimolecule/module/__init__.py
@@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from .criterions import Criterion
+from .criterions import Criterion, CriterionRegistry
from .embeddings import PositionEmbeddingRegistry, PositionEmbeddingRegistryHF, RotaryEmbedding, SinusoidalEmbedding
from .heads import (
BaseHeadConfig,
@@ -38,8 +38,13 @@
TokenKMerHead,
TokenPredictionHead,
)
+from .model import MultiMoleculeModel
+from .registry import ModelRegistry
__all__ = [
+ "ModelRegistry",
+ "MultiMoleculeModel",
+ "CriterionRegistry",
"Criterion",
"PositionEmbeddingRegistry",
"PositionEmbeddingRegistryHF",
diff --git a/multimolecule/module/backbones/__init__.py b/multimolecule/module/backbones/__init__.py
new file mode 100644
index 00000000..d69e6292
--- /dev/null
+++ b/multimolecule/module/backbones/__init__.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .registry import BackboneRegistry
+from .sequence import SequenceBackbone
+from .sequences import SequenceRegistry
+
+__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"]
diff --git a/multimolecule/module/backbones/registry.py b/multimolecule/module/backbones/registry.py
new file mode 100644
index 00000000..47be122d
--- /dev/null
+++ b/multimolecule/module/backbones/registry.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import Registry
+
+BackboneRegistry = Registry()
diff --git a/multimolecule/module/backbones/sequence.py b/multimolecule/module/backbones/sequence.py
new file mode 100644
index 00000000..a30cbf83
--- /dev/null
+++ b/multimolecule/module/backbones/sequence.py
@@ -0,0 +1,46 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .registry import BackboneRegistry
+from .sequences import SequenceRegistry
+
+
+@BackboneRegistry.register("sequence", default=True)
+class SequenceBackbone(nn.Module):
+ def __init__(self, sequence) -> None:
+ super().__init__()
+ self.sequence = SequenceRegistry.build(**sequence)
+ self.sequence_dropout = nn.Dropout(sequence.pop("dropout", 0), inplace=True)
+ self.config = self.sequence.config
+ self.out_channels = self.config.hidden_size
+
+ def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, FlatDict]:
+ attentions = None
+ input_ids, attention_mask = sequence.tensor, sequence.mask
+ sequence_output = self.sequence(input_ids.int(), attention_mask)
+ sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"])
+ sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"])
+ if "attentions" in sequence_output:
+ attentions = torch.stack(sequence_output["attentions"], dim=1).detach()
+
+ return sequence_output, attentions
diff --git a/multimolecule/module/backbones/sequences/__init__.py b/multimolecule/module/backbones/sequences/__init__.py
new file mode 100644
index 00000000..e6e5cd08
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .onehot import OneHot
+from .registry import SequenceRegistry
+
+__all__ = ["SequenceRegistry", "OneHot"]
diff --git a/multimolecule/module/backbones/sequences/onehot.py b/multimolecule/module/backbones/sequences/onehot.py
new file mode 100644
index 00000000..bc4c979f
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/onehot.py
@@ -0,0 +1,39 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import torch
+from chanfig import FlatDict
+from torch import nn
+from transformers import AutoConfig
+
+from .registry import SequenceRegistry
+
+
+@SequenceRegistry.register("onehot")
+class OneHot(nn.Module):
+ def __init__(self, pretrained: str) -> None:
+ super().__init__()
+ self.config = AutoConfig.from_pretrained(str(pretrained))
+ self.module = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+
+ def forward(self, input_ids, attn_mask) -> FlatDict:
+ output = FlatDict()
+ output["last_hidden_state"] = self.module(input_ids)
+ valid_length = attn_mask.sum(dim=1)
+ output["pooler_output"] = torch.stack(
+ [t[: valid_length[i]].sum(0) for i, t in enumerate(output["last_hidden_state"])]
+ )
+ return output
diff --git a/multimolecule/module/backbones/sequences/registry.py b/multimolecule/module/backbones/sequences/registry.py
new file mode 100644
index 00000000..c9178231
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/registry.py
@@ -0,0 +1,66 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import danling as dl
+import transformers
+from chanfig import Registry as Registry_
+from torch import nn
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ def build(
+ self,
+ type: str | None = None,
+ name: str | None = None,
+ use_pretrained: bool = True,
+ gradient_checkpoint: bool = False,
+ checkpoint: str | None = None,
+ *args,
+ **kwargs,
+ ) -> nn.Module:
+ if type is not None:
+ if type in self:
+ sequence_cls = self.lookup(type)
+ sequence = self.init(sequence_cls, *args, **kwargs)
+ if checkpoint is not None:
+ sequence.load_state_dict(dl.load(checkpoint))
+ elif hasattr(transformers, type + "Model"):
+ if use_pretrained:
+ sequence_cls: PreTrainedModel = getattr(transformers, type + "Model") # type: ignore[no-redef]
+ sequence = sequence_cls.from_pretrained(name, *args, **kwargs)
+ else:
+ config_cls: PretrainedConfig = getattr(transformers, type + "Config")
+ config, kwargs = config_cls.from_pretrained(name, return_unused_kwargs=True, **kwargs)
+ sequence_cls: PreTrainedModel = getattr(transformers, type + "Model") # type: ignore[no-redef]
+ sequence = sequence_cls.from_config(config, *args, **kwargs)
+ else:
+ raise ValueError(f"Sequence {type} not found in registry or transformers")
+ else:
+ if use_pretrained:
+ sequence = AutoModel.from_pretrained(name, *args, **kwargs)
+ else:
+ config, kwargs = AutoConfig.from_pretrained(name, return_unused_kwargs=True, **kwargs)
+ sequence = AutoModel.from_config(config, *args, **kwargs)
+
+ if gradient_checkpoint:
+ sequence.gradient_checkpointing_enable()
+ return sequence
+
+
+SequenceRegistry = Registry()
diff --git a/multimolecule/module/criterions/__init__.py b/multimolecule/module/criterions/__init__.py
index 104334b5..4b9adf7e 100644
--- a/multimolecule/module/criterions/__init__.py
+++ b/multimolecule/module/criterions/__init__.py
@@ -14,6 +14,18 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
+from .binary import BCEWithLogitsLoss
from .generic import Criterion
+from .multiclass import CrossEntropyLoss
+from .multilabel import MultiLabelSoftMarginLoss
+from .registry import CriterionRegistry
+from .regression import MSELoss
-__all__ = ["Criterion"]
+__all__ = [
+ "CriterionRegistry",
+ "Criterion",
+ "MSELoss",
+ "BCEWithLogitsLoss",
+ "CrossEntropyLoss",
+ "MultiLabelSoftMarginLoss",
+]
diff --git a/multimolecule/module/criterions/binary.py b/multimolecule/module/criterions/binary.py
new file mode 100644
index 00000000..0bf53e59
--- /dev/null
+++ b/multimolecule/module/criterions/binary.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .registry import CriterionRegistry
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+
+@CriterionRegistry.register("binary")
+class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.flatten().storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.flatten().storage())
+ if input.ndim == target.ndim + 1:
+ input = input.squeeze(-1)
+ return super().forward(input, target.float())
diff --git a/multimolecule/module/criterions/generic.py b/multimolecule/module/criterions/generic.py
index b003c81d..a6731933 100644
--- a/multimolecule/module/criterions/generic.py
+++ b/multimolecule/module/criterions/generic.py
@@ -17,8 +17,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING
+from warnings import warn
-import torch
from danling import NestedTensor
from torch import Tensor, nn
from torch.nn import functional as F
@@ -26,10 +26,13 @@
if TYPE_CHECKING:
from ..heads.config import HeadConfig
+from .registry import CriterionRegistry
+
+@CriterionRegistry.register(default=True)
class Criterion(nn.Module):
- problem_types = ["regression", "single_label_classification", "multi_label_classification"]
+ problem_types = ["regression", "binary", "multiclass", "multilabel"]
def __init__(self, config: HeadConfig) -> None:
super().__init__()
@@ -41,21 +44,31 @@ def forward(self, logits: Tensor | NestedTensor, labels: Tensor | NestedTensor)
if labels is None:
return None
if self.problem_type is None:
- if self.num_labels == 1:
+ if labels.is_floating_point():
self.problem_type = "regression"
- elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
- self.problem_type = "single_label_classification"
+ elif self.num_labels == 1:
+ self.problem_type = "binary"
+ elif labels.unique().numel() == 2:
+ self.problem_type = "multilabel"
else:
- self.problem_type = "multi_label_classification"
+ self.problem_type = "multiclass"
+ warn(
+ f"`problem_type` is not set. Assuming {self.problem_type}. \n"
+ "This can lead to unexpected behavior. Please set `problem_type` explicitly."
+ )
self.config.problem_type = self.problem_type
if self.problem_type == "regression":
labels = labels.to(logits.dtype)
if self.num_labels == 1:
return F.mse_loss(logits.squeeze(), labels.squeeze())
logits, labels = logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
- return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels))
- if self.problem_type == "single_label_classification":
+ return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels)) # type: ignore
+ if self.problem_type == "multiclass":
return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
- if self.problem_type == "multi_label_classification":
- return F.binary_cross_entropy_with_logits(logits, labels)
+ if self.problem_type == "binary":
+ if logits.ndim == labels.ndim + 1:
+ logits = logits.squeeze(-1)
+ return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))
+ if self.problem_type == "multilabel":
+ return F.multilabel_soft_margin_loss(logits, labels.to(logits.dtype))
raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}")
diff --git a/multimolecule/module/criterions/multiclass.py b/multimolecule/module/criterions/multiclass.py
new file mode 100644
index 00000000..f7070e94
--- /dev/null
+++ b/multimolecule/module/criterions/multiclass.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("multiclass")
+class CrossEntropyLoss(nn.CrossEntropyLoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.storage())
+ if input.ndim > 2:
+ input, target = input.view(-1, input.size(-1)), target.view(-1)
+ return super().forward(input, target.long())
diff --git a/multimolecule/module/criterions/multilabel.py b/multimolecule/module/criterions/multilabel.py
new file mode 100644
index 00000000..c72bb9f9
--- /dev/null
+++ b/multimolecule/module/criterions/multilabel.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("multilabel")
+class MultiLabelSoftMarginLoss(nn.MultiLabelSoftMarginLoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(target, NestedTensor) and target.ndim > 2:
+ input, target = input.view(-1, input.size(-1)), target.view(-1, target.size(-1))
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.storage())
+ return super().forward(input, target.float())
diff --git a/multimolecule/module/criterions/registry.py b/multimolecule/module/criterions/registry.py
new file mode 100644
index 00000000..856341f7
--- /dev/null
+++ b/multimolecule/module/criterions/registry.py
@@ -0,0 +1,29 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from chanfig import ConfigRegistry as Registry_
+from torch import nn
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ key = "problem_type"
+
+ def build(self, config) -> nn.Module: # type: ignore[override]
+ name = getattr(config, self.getattr("key"))
+ return self.init(self.lookup(name), config) # type: ignore[arg-type]
+
+
+CriterionRegistry = Registry(fallback=True)
diff --git a/multimolecule/module/criterions/regression.py b/multimolecule/module/criterions/regression.py
new file mode 100644
index 00000000..4f39e0eb
--- /dev/null
+++ b/multimolecule/module/criterions/regression.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+ from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("regression")
+class MSELoss(nn.MSELoss):
+ def __init__(self, config: HeadConfig) -> None:
+ super().__init__(**config.get("loss", {}))
+ self.config = config
+
+ def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+ if isinstance(input, NestedTensor):
+ input = torch.cat(input.flatten().storage())
+ if isinstance(target, NestedTensor):
+ target = torch.cat(target.flatten().storage())
+ if input.ndim == target.ndim + 1:
+ target = target.unsqueeze(-1)
+ return super().forward(input, target.to(input.dtype))
diff --git a/multimolecule/module/heads/config.py b/multimolecule/module/heads/config.py
index bb0dbba6..3b9ee64b 100644
--- a/multimolecule/module/heads/config.py
+++ b/multimolecule/module/heads/config.py
@@ -16,15 +16,13 @@
from __future__ import annotations
-from collections import OrderedDict
-from dataclasses import dataclass
+from chanfig import FlatDict
-class BaseHeadConfig(OrderedDict):
+class BaseHeadConfig(FlatDict):
pass
-@dataclass
class HeadConfig(BaseHeadConfig):
r"""
Configuration class for a prediction head.
@@ -35,8 +33,8 @@ class HeadConfig(BaseHeadConfig):
Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`.
problem_type:
- Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`,
- `"single_label_classification"` or `"multi_label_classification"`.
+ Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`,
+ `"multiclass"` or `"multilabel"`.
Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`.
hidden_size:
@@ -55,14 +53,18 @@ class HeadConfig(BaseHeadConfig):
The activation function of the final prediction output.
layer_norm_eps:
The epsilon used by the layer normalization layers.
- output_name (`str`, *optional*):
+ output_name:
The name of the tensor required in model outputs.
If is `None`, will use the default output name of the corresponding head.
+ type:
+ The type of the head in the model.
+
+ This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads.
"""
- num_labels: int = None # type: ignore[assignment]
- problem_type: str = None # type: ignore[assignment]
+ num_labels: int | None = None
+ problem_type: str | None = None
hidden_size: int | None = None
dropout: float = 0.0
transform: str | None = None
@@ -71,9 +73,9 @@ class HeadConfig(BaseHeadConfig):
act: str | None = None
layer_norm_eps: float = 1e-12
output_name: str | None = None
+ type: str | None = None
-@dataclass
class MaskedLMHeadConfig(BaseHeadConfig):
r"""
Configuration class for a Masked Language Modeling head.
@@ -95,7 +97,7 @@ class MaskedLMHeadConfig(BaseHeadConfig):
The activation function of the final prediction output.
layer_norm_eps:
The epsilon used by the layer normalization layers.
- output_name (`str`, *optional*):
+ output_name:
The name of the tensor required in model outputs.
If is `None`, will use the default output name of the corresponding head.
diff --git a/multimolecule/module/heads/contact.py b/multimolecule/module/heads/contact.py
index 50ec4fbb..3a276d8c 100644
--- a/multimolecule/module/heads/contact.py
+++ b/multimolecule/module/heads/contact.py
@@ -50,6 +50,8 @@ class ContactPredictionHead(PredictionHead):
output_name: str = "attentions"
r"""The default output to use for the head."""
+ requires_attention: bool = True
+
def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
super().__init__(config, head_config)
self.bos_token_id = config.bos_token_id
diff --git a/multimolecule/module/heads/generic.py b/multimolecule/module/heads/generic.py
index d97950a2..3423926e 100644
--- a/multimolecule/module/heads/generic.py
+++ b/multimolecule/module/heads/generic.py
@@ -19,12 +19,11 @@
from typing import TYPE_CHECKING
from warnings import warn
-import torch
from danling import NestedTensor
from torch import Tensor, nn
from transformers.activations import ACT2FN
-from ..criterions import Criterion
+from ..criterions import CriterionRegistry
from .config import HeadConfig
from .output import HeadOutput
from .transform import HeadTransformRegistryHF
@@ -44,24 +43,23 @@ class PredictionHead(nn.Module):
"""
num_labels: int
+ requires_attention: bool = False
def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
super().__init__()
if head_config is None:
- head_config = config.head
+ head_config = config.head or HeadConfig(num_labels=config.num_labels)
self.config = head_config
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
- if self.config.num_labels is None:
- self.config.num_labels = config.num_labels
if self.config.problem_type is None:
self.config.problem_type = config.problem_type
- self.num_labels = self.config.num_labels
+ self.num_labels = self.config.num_labels # type: ignore[assignment]
self.dropout = nn.Dropout(self.config.dropout)
self.transform = HeadTransformRegistryHF.build(self.config)
- self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias)
+ self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
- self.criterion = Criterion(self.config)
+ self.criterion = CriterionRegistry.build(self.config)
def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
r"""
@@ -85,6 +83,6 @@ def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOu
if isinstance(labels, NestedTensor):
if isinstance(output, Tensor):
output = labels.nested_like(output, strict=False)
- return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage())))
+ return HeadOutput(output, self.criterion(output.concat, labels.concat))
return HeadOutput(output, self.criterion(output, labels))
return HeadOutput(output)
diff --git a/multimolecule/module/heads/pretrain.py b/multimolecule/module/heads/pretrain.py
index 329d92e1..37a3e5a8 100644
--- a/multimolecule/module/heads/pretrain.py
+++ b/multimolecule/module/heads/pretrain.py
@@ -53,8 +53,8 @@ def __init__(
):
super().__init__()
if head_config is None:
- head_config = config.lm_head if hasattr(config, "lm_head") else config.head # type: ignore[assignment]
- self.config: MaskedLMHeadConfig = head_config # type: ignore[assignment]
+ head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig()
+ self.config: MaskedLMHeadConfig = head_config
if self.config.hidden_size is None:
self.config.hidden_size = config.hidden_size
self.num_labels = config.vocab_size
@@ -95,6 +95,6 @@ def forward(
if isinstance(labels, NestedTensor):
if isinstance(output, Tensor):
output = labels.nested_like(output, strict=False)
- return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage())))
+ return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
return HeadOutput(output)
diff --git a/multimolecule/module/heads/registry.py b/multimolecule/module/heads/registry.py
index e5393e4e..6db3b680 100644
--- a/multimolecule/module/heads/registry.py
+++ b/multimolecule/module/heads/registry.py
@@ -14,6 +14,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from chanfig import Registry
+from chanfig import ConfigRegistry as Registry_
+from torch import nn
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ key = "type"
+
+ def build(self, config, head_config) -> nn.Module: # type: ignore[override]
+ name = getattr(head_config, self.getattr("key"))
+ return self.init(self.lookup(name), config, head_config) # type: ignore[arg-type]
+
HeadRegistry = Registry(default_factory=Registry, fallback=True)
diff --git a/multimolecule/module/model.py b/multimolecule/module/model.py
new file mode 100644
index 00000000..256783be
--- /dev/null
+++ b/multimolecule/module/model.py
@@ -0,0 +1,89 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import FlatDict
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .backbones import BackboneRegistry
+from .heads import HeadRegistry
+from .necks import NeckRegistry
+from .registry import ModelRegistry
+
+
+@ModelRegistry.register(default=True)
+class MultiMoleculeModel(nn.Module):
+ def __init__(
+ self,
+ backbone: dict,
+ heads: dict,
+ neck: dict | None = None,
+ max_length: int = 1024,
+ truncation: bool = False,
+ ):
+ super().__init__()
+
+ # Backbone
+ self.backbone = BackboneRegistry.build(**backbone)
+ backbone = self.backbone.config
+ out_channels = self.backbone.out_channels
+
+ # Neck
+ if neck:
+ num_discrete = self.backbone.num_discrete
+ num_continuous = self.backbone.num_continuous
+ embed_dim = self.backbone.sequence.config.hidden_size
+ attention_heads = self.backbone.sequence.config.num_attention_heads
+ neck.update(
+ {
+ "num_discrete": num_discrete,
+ "num_continuous": num_continuous,
+ "embed_dim": embed_dim,
+ "attention_heads": attention_heads,
+ "max_length": max_length,
+ "truncation": truncation,
+ }
+ )
+ self.neck = NeckRegistry.build(**neck)
+ out_channels = self.neck.out_channels
+ else:
+ self.neck = None
+
+ # Heads
+ for head in heads.values():
+ if "hidden_size" not in head or head["hidden_size"] is None:
+ head["hidden_size"] = out_channels
+ self.heads = nn.ModuleDict({name: HeadRegistry.build(backbone, head) for name, head in heads.items()})
+ if any(getattr(h, "requires_attention", False) for h in self.heads.values()):
+ self.backbone.sequence.config.output_attentions = True
+
+ def forward(
+ self,
+ sequence: NestedTensor | Tensor,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ dataset: str | None = None,
+ **labels: NestedTensor | Tensor,
+ ) -> FlatDict:
+ ret = FlatDict()
+ output, _ = self.backbone(sequence, discrete, continuous)
+ if self.neck is not None:
+ output = self.neck(**output)
+ for task, label in labels.items():
+ ret[task] = self.heads[task](output, input_ids=sequence, labels=label)
+ return ret
diff --git a/multimolecule/module/necks/__init__.py b/multimolecule/module/necks/__init__.py
new file mode 100644
index 00000000..e8f1f7e2
--- /dev/null
+++ b/multimolecule/module/necks/__init__.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .bert import BERTNeck
+from .cat import CatNeck
+from .registry import NeckRegistry
+
+__all__ = ["NeckRegistry", "CatNeck", "BERTNeck"]
diff --git a/multimolecule/module/necks/bert.py b/multimolecule/module/necks/bert.py
new file mode 100644
index 00000000..1360f0dd
--- /dev/null
+++ b/multimolecule/module/necks/bert.py
@@ -0,0 +1,102 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from danling.modules import TransformerEncoder, TransformerEncoderLayer
+from torch import Tensor, nn
+
+from .registry import NeckRegistry
+
+MAX_LENGTH = 1024
+
+
+@NeckRegistry.register("bert")
+class BERTNeck(nn.Module):
+ def __init__( # pylint: disable=keyword-arg-before-vararg
+ self,
+ num_discrete: int,
+ num_continuous: int,
+ embed_dim: int,
+ attention_heads: int,
+ num_layers: int = 6,
+ max_length: int | None = None,
+ truncation: bool = False,
+ dropout: float = 0.1,
+ *args,
+ **kwargs,
+ ):
+ super().__init__()
+ self.cls_token_dis = nn.Parameter(torch.zeros(embed_dim))
+ self.cls_token_con = nn.Parameter(torch.zeros(embed_dim))
+ if max_length is None:
+ if truncation:
+ max_length = MAX_LENGTH + 1 + num_discrete + 1 + num_continuous
+ else:
+ max_length = MAX_LENGTH * 4 + 1 + num_discrete + 1 + num_continuous
+ self.max_length = max_length
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.max_length, embed_dim))
+ bert_layer = TransformerEncoderLayer(
+ embed_dim, attention_heads, *args, dropout=dropout, attn_dropout=dropout, ffn_dropout=dropout, **kwargs
+ )
+ self.bert = TransformerEncoder(bert_layer, num_layers)
+ self.out_channels = embed_dim
+ nn.init.normal_(self.pos_embed, std=0.02)
+ nn.init.trunc_normal_(self.cls_token_dis, std=0.2)
+ nn.init.trunc_normal_(self.cls_token_con, std=0.2)
+
+ def forward(
+ self,
+ cls_token: Tensor | None = None,
+ all_tokens: Tensor | None = None,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ ) -> FlatDict:
+ ret = FlatDict()
+ if cls_token is not None:
+ ret["cls_token"] = self._forward(cls_token, discrete, continuous)
+ if all_tokens is not None:
+ ret["all_tokens"] = self._forward(all_tokens, discrete, continuous)
+ return ret
+
+ def _forward(
+ self,
+ sequence: Tensor,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ ) -> Tensor:
+ if sequence is None:
+ raise ValueError("sequence should not be None.")
+ if sequence.dim() == 2:
+ sequence = sequence[:, None]
+ batch_size, seq_len, _ = sequence.shape
+ output = sequence
+ if discrete is not None:
+ cls_token_dis = self.cls_token_dis.expand(batch_size, 1, -1)
+ output = torch.cat((output, cls_token_dis, discrete), dim=1)
+ if continuous is not None:
+ cls_token_con = self.cls_token_con.expand(batch_size, -1)[:, None]
+ output = torch.cat((output, cls_token_con, continuous), dim=1)
+ all_len = output.shape[1]
+ if all_len > self.pos_embed.shape[1]:
+ raise ValueError("sequence length is out of range.")
+ output = output + self.pos_embed[:, 0:all_len, :]
+ output = self.bert(output)[0][:, 0:seq_len, :]
+ if seq_len == 1:
+ output = output.squeeze(1)
+ return output
diff --git a/multimolecule/module/necks/cat.py b/multimolecule/module/necks/cat.py
new file mode 100644
index 00000000..d5165a92
--- /dev/null
+++ b/multimolecule/module/necks/cat.py
@@ -0,0 +1,43 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from torch import Tensor
+
+from .registry import NeckRegistry
+
+
+@NeckRegistry.register("cat")
+class CatNeck: # pylint: disable=too-few-public-methods
+ def __init__(self, embed_dim: int):
+ self.out_channels = embed_dim * 2
+
+ def __call__(
+ self,
+ cls_token: Tensor | None = None,
+ all_tokens: Tensor | None = None,
+ discrete: Tensor | None = None,
+ continuous: Tensor | None = None,
+ ) -> FlatDict:
+ ret = FlatDict()
+ if cls_token is not None:
+ ret.cls_token = torch.cat(tuple(i for i in (cls_token, discrete, continuous) if i is not None), -1)
+ if all_tokens is not None:
+ ret.all_tokens = torch.cat(tuple(i for i in (all_tokens, discrete, continuous) if i is not None), -1)
+ return ret
diff --git a/multimolecule/module/necks/registry.py b/multimolecule/module/necks/registry.py
new file mode 100644
index 00000000..c024227c
--- /dev/null
+++ b/multimolecule/module/necks/registry.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import Registry
+
+NeckRegistry = Registry()
diff --git a/multimolecule/module/registry.py b/multimolecule/module/registry.py
new file mode 100644
index 00000000..b0332463
--- /dev/null
+++ b/multimolecule/module/registry.py
@@ -0,0 +1,35 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from chanfig import Registry as Registry_
+from torch import nn
+
+from .backbones import BackboneRegistry
+from .backbones.sequences import SequenceRegistry
+from .heads import HeadRegistry
+from .necks import NeckRegistry
+
+
+class Registry(Registry_): # pylint: disable=too-few-public-methods
+ def build(self, *args, **kwargs) -> nn.Module:
+ return super().build(*args, **kwargs)
+
+
+ModelRegistry = Registry()
+
+__all__ = ["ModelRegistry", "BackboneRegistry", "SequenceRegistry", "NeckRegistry", "HeadRegistry"]
diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md
new file mode 100644
index 00000000..bb1000ad
--- /dev/null
+++ b/multimolecule/runners/README.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+`runners` provide an easy-to-use interface for running experiments.
diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py
new file mode 100644
index 00000000..70fa4076
--- /dev/null
+++ b/multimolecule/runners/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .config import MultiMoleculeConfig
+from .runner import MultiMoleculeRunner
+
+__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"]
diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py
new file mode 100644
index 00000000..f4c42592
--- /dev/null
+++ b/multimolecule/runners/config.py
@@ -0,0 +1,70 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from __future__ import annotations
+
+from typing import List
+
+from chanfig import Config
+from transformers import PretrainedConfig
+
+
+class DataConfig(Config):
+ root: str = "."
+ train: str | None
+ validation: str | None
+ test: str | None
+ feature_cols: List | None = None
+ label_cols: List | None = None
+ truncation: bool = True
+
+
+class OptimConfig(Config):
+ name: str = "AdamW"
+ lr: float = 1e-3
+ weight_decay: float = 1e-2
+
+
+class MultiMoleculeConfig(Config):
+
+ art: bool = True
+
+ pretrained: str
+ use_pretrained: bool = True
+ transformers: PretrainedConfig
+ epoch_end: int = 20
+
+ tensorboard: bool = True
+ save_interval: int = 10
+
+ seed: int = 1013
+ data: DataConfig
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.datas = Config(default_factory=DataConfig)
+ self.dataloader.batch_size = 32
+ self.optim = OptimConfig()
+ self.sched.final_lr = 0
+
+ def post(self):
+ if "data" in self:
+ if self.datas:
+ raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+ del self.datas
+ self["network.backbone.sequence.name"] = self.pretrained
+ self["network.backbone.sequence.use_pretrained"] = self.use_pretrained
+ self.name = f"{self.pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}"
diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py
new file mode 100644
index 00000000..da584cbc
--- /dev/null
+++ b/multimolecule/runners/metrics.py
@@ -0,0 +1,37 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from chanfig import Registry as Registry_
+from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics
+
+
+class Registry(Registry_):
+
+ def build(self, type, num_labels: int | None = None, **kwargs):
+ if type == "multilabel":
+ return self.init(self.lookup(type), num_labels=num_labels, **kwargs)
+ if type == "multiclass":
+ return self.init(self.lookup(type), num_classes=num_labels, **kwargs)
+ if type == "regression":
+ return self.init(self.lookup(type), num_outputs=num_labels, **kwargs)
+ return self.init(self.lookup(type), **kwargs)
+
+
+MetricRegistry = Registry(key="type")
+MetricRegistry.register(binary_metrics, "binary")
+MetricRegistry.register(multiclass_metrics, "multiclass")
+MetricRegistry.register(multilabel_metrics, "multilabel")
+MetricRegistry.register(regression_metrics, "regression")
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
new file mode 100644
index 00000000..c15f11ae
--- /dev/null
+++ b/multimolecule/runners/runner.py
@@ -0,0 +1,236 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+import os
+from functools import cached_property, partial
+from typing import Any, Tuple
+from warnings import warn
+
+import danling as dl
+import torch
+from art import text2art
+from chanfig import NestedDict
+from danling import MultiTaskMetrics, TorchRunner
+from datasets import disable_progress_bars, get_dataset_split_names
+from torch import nn, optim
+from torch.utils import data
+from transformers import AutoTokenizer
+
+from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
+from multimolecule.module import HeadConfig, ModelRegistry
+
+from .config import MultiMoleculeConfig
+from .metrics import MetricRegistry
+
+disable_progress_bars()
+
+
+class MultiMoleculeRunner(TorchRunner):
+
+ all_datasets: NestedDict
+
+ def __init__(self, config: MultiMoleculeConfig):
+ if config.art:
+ print(text2art("MultiMolecule", "rand-large"))
+ super().__init__(config)
+ self.name = config.name
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
+ self.build_datasets()
+ self.build_dataloaders()
+ self.model = ModelRegistry.build(**self.network).to(self.device)
+ if self.distributed:
+ self.model = nn.parallel.DistributedDataParallel(
+ self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
+ )
+ self.optimizer = getattr(optim, self.config.optim.pop("name"))(
+ params=self.model.parameters(), **self.config.optim
+ )
+ self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
+ self.metrics = self.build_metrics()
+
+ def __post_init__(self):
+ super().__post_init__()
+ self.yaml(os.path.join(self.dir, "trainer.yaml"))
+ print(self)
+ print(self.get_dataset_lengths())
+
+ def train_step(self, data) -> Tuple[Any, torch.Tensor]:
+ with self.autocast(), self.accumulate():
+ pred = self.model(**data)
+ loss = self.loss_fn(pred, data)
+ self.advance(loss)
+ self.metric_fn(pred, data)
+ return pred, loss
+
+ def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
+ pred = self.model(**data)
+ loss = self.loss_fn(pred, data)
+ self.metric_fn(pred, data)
+ return pred, loss
+
+ def loss_fn(self, pred, data):
+ return sum(p["loss"] for p in pred.values())
+
+ def metric_fn(self, pred, data):
+ metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
+ metric.update({t: (p["logits"], data[t]) for t, p in pred.items()})
+
+ @cached_property
+ def tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+ if self.datasets.train:
+ return self.datasets.train.tasks
+ return next(iter(self.datasets)).tasks
+
+ @cached_property
+ def dataset_tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+ dataset = self.datasets.train if self.datasets.train else next(iter(self.datasets))
+ tasks = self.tasks
+ dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks
+ for dataset in self.datasets.values():
+ dataset_tasks_ = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks
+ for dataset_name, tasks_ in dataset_tasks_.items():
+ for task_name, task_ in tasks_.items():
+ if task_name not in tasks:
+ raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined in training data")
+ task = tasks[task_name]
+ if task != task_:
+ warn(
+ f"Task {task_name} of dataset {dataset_name} has different configurations "
+ "compared to training data, using training configuration.\n"
+ "This may lead to unexpected behavior.",
+ )
+ if dataset_name not in dataset_tasks:
+ dataset_tasks[dataset_name] = NestedDict()
+ if task_name not in dataset_tasks[dataset_name]:
+ dataset_tasks[dataset_name][task_name] = task
+ return dataset_tasks
+
+ @cached_property
+ def network(self):
+ heads = {
+ name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level)
+ for name, task in self.tasks.items()
+ }
+ self.config.network.heads.merge(heads, overwrite=False)
+ return self.config.network
+
+ def build_datasets(self):
+ if "data" in self.config:
+ self.datasets = self.all_datasets = self._build_dataset(self.config.data)
+ return
+ if "datas" in self.config:
+ self.all_datasets = NestedDict(
+ {name: self._build_dataset(config, name) for name, config in self.config.datas.items()}
+ )
+ datasets = {
+ subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict}
+ for subkey in {k for v in self.all_datasets.values() for k in v}
+ }
+ self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+ return
+ raise ValueError("No data configuration found")
+
+ def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
+ name = name or config.root
+ print(f"Building dataset {name}")
+ dataset = NestedDict()
+ dataset_factory = partial(
+ Dataset,
+ tokenizer=self.tokenizer,
+ **{k: v for k, v in config.items() if k not in ("train", "validation", "test", "root")},
+ )
+ if os.path.isdir(config.root):
+ if "train" in config:
+ dataset.train = dataset_factory(os.path.join(config.root, config.train), split="train")
+ if "validation" in config:
+ dataset.validation = dataset_factory(os.path.join(config.root, config.validation), split="validation")
+ if "test" in config:
+ dataset.test = dataset_factory(os.path.join(config.root, config.test), split="test")
+ else:
+ splits = get_dataset_split_names(config.root)
+ if "train" not in config:
+ config.train = "train" if "train" in splits else None
+ if config.train is not None:
+ dataset.train = dataset_factory(config.root, split="train")
+ if "validation" not in config:
+ config.validation = "validation" if "validation" in splits else None
+ if config.validation is not None:
+ dataset.validation = dataset_factory(config.root, split="validation")
+ if "test" not in config:
+ config.test = "test" if "test" in splits else None
+ if config.test is not None:
+ dataset.test = dataset_factory(config.root, split="test")
+ if not dataset:
+ raise ValueError("No datasets built. This is likely due to missing data paths in Config.")
+ return dataset
+
+ def build_dataloaders(self):
+ datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+ default_kwargs = self.config.get("dataloader", NestedDict())
+ dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+ for k, d in datasets.items():
+ dataloader_kwargs.setdefault(k, NestedDict())
+ dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+ batch_size = dataloader_kwargs[k].pop("batch_size")
+ shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+ drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True))
+ if isinstance(d, MultiTaskDataset):
+ batch_sampler = (
+ DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ if self.distributed
+ else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ )
+ else:
+ sampler = (
+ data.distributed.DistributedSampler(d, shuffle=shuffle)
+ if self.distributed
+ else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+ )
+ batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last)
+ self.dataloaders[k] = data.DataLoader(
+ d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+ )
+
+ def build_metrics(self) -> MultiTaskMetrics:
+ return MultiTaskMetrics(
+ {
+ name: MetricRegistry.build(type=task.type, num_labels=task.num_labels)
+ for name, task in self.dataset_tasks.all_items()
+ }
+ )
+
+ def collate_fn(self, batch):
+ return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()}
+
+ def get_dataset_lengths(self) -> str:
+ repr = "dataset lengths:\n"
+ longest_name = max(len(name) for name in self.all_datasets.keys())
+ for name, dataset in self.all_datasets.items():
+ if isinstance(dataset, NestedDict):
+ repr += f"{name}:"
+ if len(name) < longest_name:
+ repr += " " * (longest_name - len(name))
+ repr += "\t\t"
+ for split, d in dataset.items():
+ repr += f" {split}: {len(d)}\t"
+ else:
+ repr += f"{name}: {len(dataset)}\t"
+ repr += "\n"
+ return repr
diff --git a/multimolecule/train.py b/multimolecule/train.py
new file mode 100644
index 00000000..f146bc97
--- /dev/null
+++ b/multimolecule/train.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+
+from .apis import train
+
+if __name__ == "__main__":
+ train()
diff --git a/pyproject.toml b/pyproject.toml
index 355f13ff..e9fa30f2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,11 @@ dynamic = [
]
dependencies = [
"accelerate",
+ "art",
"chanfig>=0.0.105",
- "danling>=0.3.11",
+ "danling>=0.4.0b1",
"datasets",
+ 'StrEnum; python_version < "3.11"',
"torch",
"transformers",
]