From 94a884755761b3a4e60618dd1b237d3143663bed Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen <chenzhiyuan@pjlab.org.cn>
Date: Sat, 11 May 2024 17:04:38 +0800
Subject: [PATCH] add runner

Signed-off-by: Zhiyuan Chen <chenzhiyuan@pjlab.org.cn>
---
 .codespell-whitelist.txt                      |   1 +
 docs/docs/data/multitask.md                   |   9 +
 docs/docs/runners/config.md                   |   9 +
 docs/docs/runners/index.md                    |   9 +
 docs/docs/runners/runner.md                   |   9 +
 docs/mkdocs.yml                               |   5 +
 multimolecule/__init__.py                     |  16 +-
 multimolecule/apis/__init__.py                |  19 ++
 multimolecule/apis/run.py                     | 115 +++++++
 multimolecule/apis/stat.py                    |  99 ++++++
 multimolecule/data/__init__.py                |  10 +-
 multimolecule/data/dataset.py                 |  52 ++--
 multimolecule/data/multitask.py               | 246 +++++++++++++++
 multimolecule/data/utils.py                   |   6 +-
 multimolecule/defaults.py                     |   1 +
 multimolecule/models/__init__.py              |   2 +
 .../models/calm/configuration_calm.py         |   4 +-
 multimolecule/models/calm/modeling_calm.py    |  12 +-
 multimolecule/models/configuration_utils.py   |  14 +-
 .../models/ernierna/configuration_ernierna.py |   4 +-
 .../models/ernierna/modeling_ernierna.py      |  12 +-
 .../models/rinalmo/configuration_rinalmo.py   |   4 +-
 .../models/rinalmo/modeling_rinalmo.py        |  12 +-
 .../models/rnabert/configuration_rnabert.py   |   4 +-
 .../models/rnabert/modeling_rnabert.py        |  22 +-
 .../models/rnaernie/configuration_rnaernie.py |   4 +-
 .../models/rnaernie/modeling_rnaernie.py      |  12 +-
 .../models/rnafm/configuration_rnafm.py       |   4 +-
 multimolecule/models/rnafm/modeling_rnafm.py  |  14 +-
 .../models/rnamsm/configuration_rnamsm.py     |   4 +-
 .../models/rnamsm/modeling_rnamsm.py          |  14 +-
 .../splicebert/configuration_splicebert.py    |   4 +-
 .../models/splicebert/modeling_splicebert.py  |  12 +-
 .../models/utrbert/configuration_utrbert.py   |   4 +-
 .../models/utrbert/modeling_utrbert.py        |  12 +-
 .../models/utrlm/configuration_utrlm.py       |   4 +-
 multimolecule/models/utrlm/modeling_utrlm.py  |  14 +-
 multimolecule/module/__init__.py              |  14 +-
 multimolecule/module/backbones/__init__.py    |  21 ++
 multimolecule/module/backbones/registry.py    |  21 ++
 multimolecule/module/backbones/sequence.py    |  46 +++
 .../module/backbones/sequences/__init__.py    |  20 ++
 .../module/backbones/sequences/onehot.py      |  39 +++
 .../module/backbones/sequences/registry.py    |  66 ++++
 multimolecule/module/criterions/__init__.py   |  14 +-
 multimolecule/module/criterions/binary.py     |  44 +++
 multimolecule/module/criterions/generic.py    |  33 +-
 multimolecule/module/criterions/multiclass.py |  44 +++
 multimolecule/module/criterions/multilabel.py |  44 +++
 multimolecule/module/criterions/registry.py   |  29 ++
 multimolecule/module/criterions/regression.py |  44 +++
 multimolecule/module/heads/__init__.py        |  12 +-
 multimolecule/module/heads/config.py          |  24 +-
 multimolecule/module/heads/contact.py         | 290 +++++++++++++++++-
 multimolecule/module/heads/generic.py         |  18 +-
 multimolecule/module/heads/pretrain.py        |   6 +-
 multimolecule/module/heads/registry.py        |  12 +-
 multimolecule/module/heads/utils.py           |  29 --
 multimolecule/module/model.py                 |  89 ++++++
 multimolecule/module/necks/__init__.py        |  21 ++
 multimolecule/module/necks/bert.py            | 102 ++++++
 multimolecule/module/necks/cat.py             |  43 +++
 multimolecule/module/necks/registry.py        |  21 ++
 multimolecule/module/registry.py              |  35 +++
 multimolecule/runners/README.md               |   9 +
 multimolecule/runners/__init__.py             |  20 ++
 multimolecule/runners/base_runner.py          | 286 +++++++++++++++++
 multimolecule/runners/config.py               |  83 +++++
 multimolecule/runners/metrics.py              |  37 +++
 multimolecule/runners/runner.py               |  42 +++
 multimolecule/tasks/task.py                   |   3 +-
 multimolecule/train.py                        |  20 ++
 pyproject.toml                                |   4 +-
 tests/data/test_dataset.py                    |  14 +-
 74 files changed, 2278 insertions(+), 219 deletions(-)
 create mode 100644 docs/docs/data/multitask.md
 create mode 100644 docs/docs/runners/config.md
 create mode 100644 docs/docs/runners/index.md
 create mode 100644 docs/docs/runners/runner.md
 create mode 100644 multimolecule/apis/__init__.py
 create mode 100644 multimolecule/apis/run.py
 create mode 100644 multimolecule/apis/stat.py
 create mode 100644 multimolecule/data/multitask.py
 create mode 100644 multimolecule/module/backbones/__init__.py
 create mode 100644 multimolecule/module/backbones/registry.py
 create mode 100644 multimolecule/module/backbones/sequence.py
 create mode 100644 multimolecule/module/backbones/sequences/__init__.py
 create mode 100644 multimolecule/module/backbones/sequences/onehot.py
 create mode 100644 multimolecule/module/backbones/sequences/registry.py
 create mode 100644 multimolecule/module/criterions/binary.py
 create mode 100644 multimolecule/module/criterions/multiclass.py
 create mode 100644 multimolecule/module/criterions/multilabel.py
 create mode 100644 multimolecule/module/criterions/registry.py
 create mode 100644 multimolecule/module/criterions/regression.py
 create mode 100644 multimolecule/module/model.py
 create mode 100644 multimolecule/module/necks/__init__.py
 create mode 100644 multimolecule/module/necks/bert.py
 create mode 100644 multimolecule/module/necks/cat.py
 create mode 100644 multimolecule/module/necks/registry.py
 create mode 100644 multimolecule/module/registry.py
 create mode 100644 multimolecule/runners/README.md
 create mode 100644 multimolecule/runners/__init__.py
 create mode 100644 multimolecule/runners/base_runner.py
 create mode 100644 multimolecule/runners/config.py
 create mode 100644 multimolecule/runners/metrics.py
 create mode 100644 multimolecule/runners/runner.py
 create mode 100644 multimolecule/train.py

diff --git a/.codespell-whitelist.txt b/.codespell-whitelist.txt
index 44c7e9f5..467e5c38 100644
--- a/.codespell-whitelist.txt
+++ b/.codespell-whitelist.txt
@@ -1,3 +1,4 @@
+datas
 ser
 marz
 manuel
diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md
new file mode 100644
index 00000000..054c6205
--- /dev/null
+++ b/docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiTask
+
+::: multimolecule.data.multitask
diff --git a/docs/docs/runners/config.md b/docs/docs/runners/config.md
new file mode 100644
index 00000000..8e188199
--- /dev/null
+++ b/docs/docs/runners/config.md
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeConfig
+
+::: multimolecule.runners.MultiMoleculeConfig
diff --git a/docs/docs/runners/index.md b/docs/docs/runners/index.md
new file mode 100644
index 00000000..75a0f528
--- /dev/null
+++ b/docs/docs/runners/index.md
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+--8<-- "multimolecule/runners/README.md:8:"
diff --git a/docs/docs/runners/runner.md b/docs/docs/runners/runner.md
new file mode 100644
index 00000000..2c93ce1b
--- /dev/null
+++ b/docs/docs/runners/runner.md
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiMoleculeRunner
+
+::: multimolecule.runners.MultiMoleculeRunner
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 93d53a45..9324d9c3 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -9,9 +9,14 @@ repo_url: https://github.com/DLS5-Omics/multimolecule
 
 nav:
   - index.md
+  - runners:
+      - runners/index.md
+      - MultiMoleculeRunner: runners/runner.md
+      - MultiMoleculeConfig: runners/config.md
   - data:
       - data/index.md
       - Dataset: data/dataset.md
+      - multitask: data/multitask.md
   - datasets:
       - datasets/index.md
       - DNA:
diff --git a/multimolecule/__init__.py b/multimolecule/__init__.py
index 240e9fcc..c9ee3a87 100644
--- a/multimolecule/__init__.py
+++ b/multimolecule/__init__.py
@@ -14,6 +14,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from .apis import evaluate, infer, train
 from .data import Dataset
 from .models import (
     AutoModelForContactPrediction,
@@ -111,30 +112,33 @@
     HeadConfig,
     HeadRegistry,
     HeadTransformRegistry,
-    HeadTransformRegistryHF,
     IdentityTransform,
     LinearTransform,
     MaskedLMHead,
     MaskedLMHeadConfig,
     NonLinearTransform,
     PositionEmbeddingRegistry,
-    PositionEmbeddingRegistryHF,
     PredictionHead,
     RotaryEmbedding,
     SequencePredictionHead,
     SinusoidalEmbedding,
-    TokenHeadRegistryHF,
     TokenKMerHead,
     TokenPredictionHead,
 )
+from .runners import MultiMoleculeConfig, MultiMoleculeRunner
 from .tasks import Task, TaskLevel, TaskType
 from .tokenisers import Alphabet, DnaTokenizer, DotBracketTokenizer, ProteinTokenizer, RnaTokenizer, Tokenizer
 from .utils import count_parameters
 
 __all__ = [
+    "train",
+    "evaluate",
+    "infer",
     "modeling_auto",
     "modeling_outputs",
     "Dataset",
+    "MultiMoleculeConfig",
+    "MultiMoleculeRunner",
     "PreTrainedConfig",
     "HeadConfig",
     "BaseHeadConfig",
@@ -233,21 +237,15 @@
     "HeadRegistry",
     "PredictionHead",
     "SequencePredictionHead",
-    "TokenHeadRegistryHF",
     "TokenPredictionHead",
     "TokenKMerHead",
-    "NucleotideHeadRegistryHF",
-    "NucleotidePredictionHead",
-    "NucleotideKMerHead",
     "ContactPredictionHead",
     "MaskedLMHead",
     "HeadTransformRegistry",
-    "HeadTransformRegistryHF",
     "LinearTransform",
     "NonLinearTransform",
     "IdentityTransform",
     "PositionEmbeddingRegistry",
-    "PositionEmbeddingRegistryHF",
     "RotaryEmbedding",
     "SinusoidalEmbedding",
     "Criterion",
diff --git a/multimolecule/apis/__init__.py b/multimolecule/apis/__init__.py
new file mode 100644
index 00000000..8e3e5b3c
--- /dev/null
+++ b/multimolecule/apis/__init__.py
@@ -0,0 +1,19 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from .run import evaluate, infer, train
+
+__all__ = ["train", "evaluate", "infer"]
diff --git a/multimolecule/apis/run.py b/multimolecule/apis/run.py
new file mode 100644
index 00000000..1fdb7666
--- /dev/null
+++ b/multimolecule/apis/run.py
@@ -0,0 +1,115 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+# mypy: disable-error-code="attr-defined"
+
+import atexit
+import os
+import warnings
+from typing import Type
+
+import danling as dl
+import torch
+
+from multimolecule.runners import MultiMoleculeConfig, MultiMoleculeRunner
+
+try:
+    import nni
+except ImportError:
+    nni = None
+
+
+def train(
+    config: MultiMoleculeConfig = None,  # type: ignore
+    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+    if config is None:
+        config = MultiMoleculeConfig()
+    config = config.parse(default_config="config", no_default_config_action="warn")
+    config.interpolate(unsafe_eval=True)
+    config.training = True
+    if config.allow_tf32:
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+    if config.reduced_precision_reduction:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+    if config.get("nni", False):
+        if nni is None:
+            raise ValueError("Unable to retrieve nni parameters, since nni is not installed.")
+        config.merge(nni.get_next_parameter())
+    with dl.debug(config.get("debug", False)):
+        runner = runner_cls(config)
+        atexit.register(runner.print_result)
+        atexit.register(runner.save_result)
+        atexit.register(runner.save_checkpoint)
+        result = runner.train()
+        return result
+
+
+def evaluate(
+    config: MultiMoleculeConfig = None,  # type: ignore
+    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+    if config is None:
+        config = MultiMoleculeConfig.empty()
+    config = config.parse(default_config="config", no_default_config_action="warn")
+    config.interpolate(unsafe_eval=True)
+    config.training = False
+    if config.allow_tf32:
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+    if config.reduced_precision_reduction:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+    if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+        raise RuntimeError("Please specify `checkpoint` to run evaluate")
+    for name, data in config.datas.items():
+        if "evaluation" not in data or not isinstance(data.evaluate, str):
+            raise RuntimeError(f"Please specify `evaluation` to run evaluate in datas.{name}")
+    runner = runner_cls(config)
+    result = runner.evaluate_epoch("evaluation")
+    print(result)
+    return result
+
+
+def infer(
+    config: MultiMoleculeConfig = None,  # type: ignore
+    runner_cls: Type[MultiMoleculeRunner] = MultiMoleculeRunner,
+):
+    if config is None:
+        config = MultiMoleculeConfig.empty()
+    config = config.parse(default_config="config", no_default_config_action="warn")
+    config.interpolate(unsafe_eval=True)
+    config.training = False
+    if config.allow_tf32:
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cuda.matmul.allow_tf32 = True
+    if config.reduced_precision_reduction:
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
+        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
+    if "checkpoint" not in config or not isinstance(config.checkpoint, str):
+        raise RuntimeError("Please specify `checkpoint` to run infer.")
+    for name, data in config.datas.items():
+        if "inference" not in data or not isinstance(data.inference, str):
+            raise RuntimeError(f"Please specify `inference` to run infer in datas.{name}")
+    if "result_path" not in config or not isinstance(config.result_path, str):
+        config.result_path = os.path.join(os.getcwd(), "result.json")
+        warnings.warn("`result_path` is not specified, default to `result.json`.", RuntimeWarning, stacklevel=2)
+    runner = runner_cls(config)
+    result = runner.infer()
+    runner.save(result, config.result_path)
+    return result
diff --git a/multimolecule/apis/stat.py b/multimolecule/apis/stat.py
new file mode 100644
index 00000000..5e525d55
--- /dev/null
+++ b/multimolecule/apis/stat.py
@@ -0,0 +1,99 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import os
+import shutil
+from statistics import mean
+from typing import List
+
+import chanfig
+import pandas as pd
+from chanfig import NestedDict
+from tqdm import tqdm
+
+
+class Result(NestedDict):
+    pretrained: str
+    id: str
+    seed: int
+    epoch: int
+    validation: NestedDict
+    test: NestedDict
+
+
+def get_result_stat(experiment_root: str, remove_empty: bool = True) -> List[Result]:
+    results = []
+    for root, _, files in tqdm(os.walk(experiment_root)):
+        if "run.log" in files:
+            if "best.json" not in files:
+                if remove_empty:
+                    shutil.rmtree(root)
+                continue
+            best = NestedDict.from_json(os.path.join(root, "best.json"))
+            if "index" not in best:
+                if remove_empty:
+                    shutil.rmtree(root)
+                continue
+            config = NestedDict.from_yaml(os.path.join(root, "trainer.yaml"))
+            pretrained = config.pretrained.split("/")[-1]
+            seed = config.seed
+            pretrained, seed = "", 1
+            result = Result(id=best.id, pretrained=pretrained, seed=seed)
+            result.validation = NestedDict(
+                {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.validation.items()}
+            )
+            result.test = NestedDict(
+                {k: format(mean(v) if isinstance(v, list) else v, ".8f") for k, v in best.test.items()}
+            )
+            result.epoch = best.index
+            result.pop("validation.time", None)
+            result.pop("test.time", None)
+            result.pop("validation.loss", None)
+            result.pop("test.loss", None)
+            result.pop("validation.lr", None)
+            result.pop("test.lr", None)
+            results.append(result)
+    # Remove empty directories, perform twice to remove all empty directories
+    if remove_empty:
+        for root, dirs, files in os.walk(experiment_root):
+            if not files and not dirs:
+                os.rmdir(root)
+        for root, dirs, files in os.walk(experiment_root):
+            if not files and not dirs:
+                os.rmdir(root)
+    results.sort(key=lambda x: (x.pretrained, x.seed, x.id))
+    return results
+
+
+def write_result_stat(results: List[Result], path: str):
+    results = [dict(result.all_items()) for result in results]  # type: ignore[misc]
+    df = pd.DataFrame.from_dict(results)
+    df.insert(len(df.keys()) - 1, "comment", "")
+    df.fillna("")
+    df.to_csv(path, index=False)
+
+
+class Config(chanfig.Config):
+    experiment_root: str = "experiments"
+    out_path: str = "result.csv"
+
+
+if __name__ == "__main__":
+    config = Config().parse()
+    result_stat = get_result_stat(config.experiment_root)
+    if not len(result_stat) > 0:
+        raise ValueError("No results found")
+    write_result_stat(result_stat, config.out_path)
diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py
index 62196c10..f2366d77 100644
--- a/multimolecule/data/__init__.py
+++ b/multimolecule/data/__init__.py
@@ -15,6 +15,14 @@
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
 from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
 from .utils import no_collate
 
-__all__ = ["Dataset", "no_collate"]
+__all__ = [
+    "Dataset",
+    "PandasDataset",
+    "MultiTaskDataset",
+    "MultiTaskSampler",
+    "DistributedMultiTaskSampler",
+    "no_collate",
+]
diff --git a/multimolecule/data/dataset.py b/multimolecule/data/dataset.py
index 54565349..2b6fae49 100644
--- a/multimolecule/data/dataset.py
+++ b/multimolecule/data/dataset.py
@@ -149,8 +149,9 @@ def __init__(
         fingerprint: str | None = None,
         ignored_cols: List[str] | None = None,
     ):
+        self._tasks = NestedDict()
         if tasks is not None:
-            self._tasks = NestedDict(tasks)
+            self.tasks = tasks
         if discrete_map is not None:
             self._discrete_map = discrete_map
         arrow_table = self.build_table(
@@ -187,13 +188,13 @@ def build_table(
                 data = dl.load_pandas(data)
                 if isinstance(data, DataFrame):
                     data = data.loc[:, ~data.columns.str.contains("^Unnamed")]
-                    data = pa.Table.from_pandas(data)
+                    data = pa.Table.from_pandas(data, preserve_index=False)
         elif isinstance(data, dict):
             data = pa.Table.from_pydict(data)
         elif isinstance(data, list):
             data = pa.Table.from_pylist(data)
         elif isinstance(data, DataFrame):
-            data = pa.Table.from_pandas(data)
+            data = pa.Table.from_pandas(data, preserve_index=False)
         if feature_cols is not None and label_cols is not None:
             data = data.select(feature_cols + label_cols)
         data = self.process_nan(data, nan_process=nan_process, fill_value=fill_value)
@@ -250,6 +251,7 @@ def post(
         self.column_names_map = column_names_map
         if self.column_names_map:
             self.rename_columns(self.column_names_map)
+        self.infer_tasks()
 
         if self.preprocess:
             self.update(self.map(self.tokenization))
@@ -258,7 +260,7 @@ def post(
             if self.discrete_map:
                 self.update(self.map(self.map_discrete))
             fn_kwargs = {
-                "columns": [name for name, task in self.tasks.items() if task.level in ["nucleotide", "contact"]],
+                "columns": [name for name, task in self.tasks.items() if task.level in ["token", "contact"]],
                 "max_seq_length": self.max_seq_length - self.seq_length_offset,
             }
             if self.truncation and 0 < self.max_seq_length < 2**32:
@@ -297,20 +299,23 @@ def collate(self, col: str, data: Any) -> Tensor | NestedTensor | None:
         except ValueError:
             return NestedTensor(data)
 
-    def infer_tasks(self, tasks: Mapping | None = None, sequence_col: str | None = None) -> NestedDict:
-        self._tasks = tasks or NestedDict()
+    def infer_tasks(self, sequence_col: str | None = None) -> NestedDict:
         for col in self.label_cols:
-            if col not in self.tasks:
-                if col in self.secondary_structure_cols:
-                    task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
-                    self._tasks[col] = task  # type: ignore[index]
-                    warn(
-                        f"Secondary structure columns are assumed to be {task}."
-                        " Please explicitly specify the task if this is not the case."
-                    )
-                else:
-                    self._tasks[col] = self.infer_task(col, sequence_col)  # type: ignore[index]
-        return self._tasks
+            if col in self.tasks:
+                continue
+            if col in self.secondary_structure_cols:
+                task = Task(TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
+                self.tasks[col] = task  # type: ignore[index]
+                warn(
+                    f"Secondary structure columns are assumed to be {task}. "
+                    "Please explicitly specify the task if this is not the case."
+                )
+            else:
+                try:
+                    self.tasks[col] = self.infer_task(col, sequence_col)  # type: ignore[index]
+                except ValueError:
+                    raise ValueError(f"Unable to infer task for column {col}.")
+        return self.tasks
 
     def infer_task(self, label_col: str, sequence_col: str | None = None) -> Task:
         if sequence_col is None:
@@ -404,7 +409,7 @@ def rename_columns(self, column_mapping: Mapping[str, str], new_fingerprint: str
         self._label_cols = [column_mapping.get(i, i) for i in self.label_cols]
         self._sequence_cols = [column_mapping.get(i, i) for i in self.sequence_cols]
         self._secondary_structure_cols = [column_mapping.get(i, i) for i in self.secondary_structure_cols]
-        self._tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()}
+        self.tasks = {column_mapping.get(k, k): v for k, v in self.tasks.items()}
         return self
 
     def rename_column(
@@ -418,7 +423,7 @@ def rename_column(
         self._secondary_structure_cols = [
             new_column_name if i == original_column_name else i for i in self.secondary_structure_cols
         ]
-        self._tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()}
+        self.tasks = {new_column_name if k == original_column_name else k: v for k, v in self.tasks.items()}
         return self
 
     def process_nan(self, data: Table, nan_process: str | None, fill_value: str | int | float = 0) -> Table:
@@ -470,9 +475,18 @@ def secondary_structure_cols(self) -> List:
     @property
     def tasks(self) -> NestedDict:
         if not hasattr(self, "_tasks"):
+            self._tasks = NestedDict()
             return self.infer_tasks()
         return self._tasks
 
+    @tasks.setter
+    def tasks(self, tasks: Mapping):
+        self._tasks = NestedDict()
+        for name, task in tasks.items():
+            if not isinstance(task, Task):
+                task = Task(**task)
+            self._tasks[name] = task
+
     @property
     def discrete_map(self) -> Mapping:
         if not hasattr(self, "_discrete_map"):
diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py
new file mode 100644
index 00000000..7c20e829
--- /dev/null
+++ b/multimolecule/data/multitask.py
@@ -0,0 +1,246 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from bisect import bisect_right
+from collections.abc import Iterator, Mapping, Sequence
+from copy import deepcopy
+from random import choices
+
+import torch
+from chanfig import NestedDict
+from torch import distributed as dist
+from torch.utils import data
+
+from .dataset import Dataset
+
+
+class MultiTaskDataset(data.ConcatDataset):
+
+    datasets: Mapping
+    dataset_keys: Sequence[str]
+    dataset_values: Sequence[Dataset]
+
+    def __init__(self, datasets: Mapping) -> None:
+        for key, dataset in datasets.items():
+            if not isinstance(dataset, Dataset):
+                raise TypeError(f"Dataset {key} should be an instance of Dataset")
+        self.datasets = datasets
+        if not len(self.datasets) > 0:
+            raise ValueError("MultiTaskDataset should contain at least one dataset")
+        self.dataset_keys, self.dataset_values = zip(*self.datasets.items())
+        self.cumulative_sizes = self.cumsum(self.dataset_values)
+
+    def __getitems__(self, key: Sequence[int]) -> Mapping:
+        dataset_idx = bisect_right(self.cumulative_sizes, key[0])
+        if dataset_idx == 0:
+            sample_idx = key
+        else:
+            sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key]
+        batch = self.dataset_values[dataset_idx][sample_idx]
+        batch["dataset"] = self.dataset_keys[dataset_idx]
+        return batch
+
+    @property
+    def tasks(self) -> NestedDict:
+        tasks = NestedDict()
+        for dataset in self.dataset_values:
+            for n, t in dataset.tasks.items():
+                if n not in tasks:
+                    tasks[n] = t
+                elif tasks[n] != t:
+                    raise ValueError(f"Task {n} has different configurations across datasets")
+        return tasks
+
+    @property
+    def dataset_tasks(self) -> NestedDict:
+        return NestedDict({k: v.tasks for k, v in self.datasets.items()})
+
+    def __repr__(self) -> str:
+        return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"
+
+
+class MultiTaskSampler(data.BatchSampler):
+    r"""
+    Ensure all items in a batch comes from the same dataset.
+
+    Arguments:
+        sampler (Sampler): Base sampler.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``
+    """
+
+    datasets: Sequence[Dataset]
+
+    def __init__(  # pylint: disable=super-init-not-called
+        self,
+        dataset: MultiTaskDataset,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last: bool = False,
+        sampler_cls: type[data.Sampler] | None = None,
+        weights: list[int] | None = None,
+    ) -> None:
+        self.datasets = dataset.dataset_values
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+        if sampler_cls is None:
+            sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
+        self.samplers = [sampler_cls(d) for d in self.datasets]  # type: ignore
+        self.dataset_sizes = [len(d) for d in self.datasets]  # type: ignore
+        self.cumulative_sizes = dataset.cumulative_sizes
+        self.num_datasets = len(self.datasets)
+        self.weights = weights if weights is not None else self.dataset_sizes
+
+    def __iter__(self):
+        sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+        sampler_weights = deepcopy(self.weights)
+        sampler_idx = 0
+        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
+        if self.drop_last:
+            while sampler_iters:
+                if self.shuffle:
+                    sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+                sampler_id, sampler_iter = sampler_iters[sampler_idx]
+                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+                try:
+                    batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+                    yield batch
+                except StopIteration:
+                    sampler_iters.pop(sampler_idx)
+                    sampler_weights.pop(sampler_idx)
+        else:
+            while sampler_iters:
+                if self.shuffle:
+                    sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+                sampler_id, sampler_iter = sampler_iters[sampler_idx]
+                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+                batch = [0] * self.batch_size
+                idx_in_batch = 0
+                try:
+                    for _ in range(self.batch_size):
+                        batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+                        idx_in_batch += 1
+                    yield batch
+                    idx_in_batch = 0  # noqa: SIM113
+                    batch = [0] * self.batch_size
+                except StopIteration:
+                    sampler_iters.pop(sampler_idx)
+                    sampler_weights.pop(sampler_idx)
+                    if idx_in_batch > 0:
+                        yield batch[:idx_in_batch]
+
+    def __len__(self):
+        batch_size = self.batch_size
+        if self.drop_last:
+            return sum(len(d) // batch_size for d in self.datasets)
+        return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+
+class DistributedMultiTaskSampler(MultiTaskSampler):  # pylint: disable=too-few-public-methods
+    r"""
+    Distributed version of MultiTaskSampler, which ensures that all GPUs sample data from the
+    same sub-dataset in each step without requiring additional communication.
+    The dataset selection is based on a random seed mechanism that is synchronized across epochs.
+
+    See Also:
+        [MultiTaskSampler][MultiTaskSampler]
+    """
+
+    def __init__(
+        self,
+        dataset: MultiTaskDataset,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last: bool = False,
+        sampler_cls: type[data.Sampler] = data.RandomSampler,
+        weights: list[int] | None = None,
+        seed: int = 0,
+    ) -> None:
+        super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
+        self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]
+        self.seed = seed
+        self.epoch = 0
+
+    def set_epoch(self, epoch: int):
+        """
+        Sets the epoch for deterministic shuffling.
+        """
+        self.epoch = epoch
+        for sampler in self.samplers:
+            sampler.set_epoch(epoch)
+
+    def _get_sampler_idx(self, high: int) -> int:
+        """
+        Determines which sampler (i.e., sub-dataset) to use based on the seed and epoch.
+        """
+        g = torch.Generator()
+        g.manual_seed(self.seed + self.epoch)
+        sampler_idx = torch.randint(low=0, high=high, size=(1,), generator=g).item()
+        return sampler_idx
+
+    def __iter__(self) -> Iterator:
+        sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+        sampler_weights = deepcopy(self.weights)
+
+        if self.drop_last:
+            while sampler_iters:
+                # Sample the same sub-dataset across all GPUs using the seeded index
+                sampler_idx = self._get_sampler_idx(len(sampler_iters))
+                sampler_id, sampler_iter = sampler_iters[sampler_idx]
+                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+                try:
+                    batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+                    yield batch
+                except StopIteration:
+                    sampler_iters.pop(sampler_idx)
+                    sampler_weights.pop(sampler_idx)
+        else:
+            while sampler_iters:
+                # Sample the same sub-dataset across all GPUs using the seeded index
+                sampler_idx = self._get_sampler_idx(len(sampler_iters))
+                sampler_id, sampler_iter = sampler_iters[sampler_idx]
+                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+                batch = [0] * self.batch_size
+                idx_in_batch = 0
+                try:
+                    for _ in range(self.batch_size):
+                        batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+                        idx_in_batch += 1
+                    yield batch
+                    idx_in_batch = 0  # noqa: SIM113
+                    batch = [0] * self.batch_size
+                except StopIteration:
+                    sampler_iters.pop(sampler_idx)
+                    sampler_weights.pop(sampler_idx)
+                    if idx_in_batch > 0:
+                        yield batch[:idx_in_batch]
+
+    def __len__(self) -> int:
+        batch_size = self.batch_size * self.world_size
+        if self.drop_last:
+            return sum(len(d) // batch_size for d in self.datasets)
+        return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+    @property
+    def world_size(self) -> int:
+        r"""Return the number of processes in the current process group."""
+        if dist.is_available() and dist.is_initialized():
+            return dist.get_world_size()
+        return 1
diff --git a/multimolecule/data/utils.py b/multimolecule/data/utils.py
index 85bc4423..1afddd1b 100644
--- a/multimolecule/data/utils.py
+++ b/multimolecule/data/utils.py
@@ -60,7 +60,7 @@ def infer_task(
             level = TaskLevel.Contact
             num_labels = len(flattened) // num_contacts
         elif len(flattened) % num_tokens == 0:
-            level = TaskLevel.Nucleotide
+            level = TaskLevel.Token
             num_labels = len(flattened) // num_tokens
         elif len(flattened) % num_elem == 0:
             level = TaskLevel.Sequence
@@ -86,7 +86,7 @@ def infer_task(
         task_type = TaskType.MultiClass if num_labels > 2 else TaskType.Binary
         num_labels = 1 if task_type == TaskType.Binary else num_labels
         if num_tokens_flattened == num_tokens:
-            return Task(task_type, level=TaskLevel.Nucleotide, num_labels=num_labels)
+            return Task(task_type, level=TaskLevel.Token, num_labels=num_labels)
         if num_contacts_flattened == num_contacts:
             return Task(task_type, level=TaskLevel.Contact, num_labels=num_labels)
         return Task(task_type, level=TaskLevel.Sequence, num_labels=num_labels)
@@ -122,7 +122,7 @@ def map_value(value: Any, mapping: dict[str, int] | None) -> Any:
 
 
 def truncate_value(value: Any, max_seq_length: int, level: int | None = None) -> Any:
-    if level == TaskLevel.Nucleotide:
+    if level == TaskLevel.Token:
         return value[:max_seq_length]
     if level == TaskLevel.Contact:
         return [i[:max_seq_length] for i in value[:max_seq_length]]
diff --git a/multimolecule/defaults.py b/multimolecule/defaults.py
index c299ea1a..0a718a60 100644
--- a/multimolecule/defaults.py
+++ b/multimolecule/defaults.py
@@ -14,6 +14,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+DATASET_SPLITS = ["train", "validation", "test", "evaluation", "inference"]
 ID_COL_NAMES = ["id", "idx", "index"]
 SEQUENCE_COL_NAMES = ["input_ids", "sequence", "seq"]
 SECONDARY_STRUCTURE_COL_NAMES = ["secondary_structure", "ss"]
diff --git a/multimolecule/models/__init__.py b/multimolecule/models/__init__.py
index 66147616..29d99436 100644
--- a/multimolecule/models/__init__.py
+++ b/multimolecule/models/__init__.py
@@ -14,6 +14,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from multimolecule.module import HeadConfig
 from multimolecule.tokenisers import DnaTokenizer, ProteinTokenizer, RnaTokenizer
 
 from .calm import (
@@ -127,6 +128,7 @@
 
 __all__ = [
     "PreTrainedConfig",
+    "HeadConfig",
     "DnaTokenizer",
     "RnaTokenizer",
     "ProteinTokenizer",
diff --git a/multimolecule/models/calm/configuration_calm.py b/multimolecule/models/calm/configuration_calm.py
index c5d73c03..032bda8e 100644
--- a/multimolecule/models/calm/configuration_calm.py
+++ b/multimolecule/models/calm/configuration_calm.py
@@ -127,5 +127,5 @@ def __init__(
         self.use_cache = use_cache
         self.emb_layer_norm_before = emb_layer_norm_before
         self.token_dropout = token_dropout
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/calm/modeling_calm.py b/multimolecule/models/calm/modeling_calm.py
index 25c1eba4..c8abdffe 100644
--- a/multimolecule/models/calm/modeling_calm.py
+++ b/multimolecule/models/calm/modeling_calm.py
@@ -270,9 +270,9 @@ class CaLmForSequencePrediction(CaLmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: CaLmConfig):
@@ -334,9 +334,9 @@ class CaLmForTokenPrediction(CaLmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: CaLmConfig):
@@ -398,9 +398,9 @@ class CaLmForContactPrediction(CaLmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: CaLmConfig):
diff --git a/multimolecule/models/configuration_utils.py b/multimolecule/models/configuration_utils.py
index 2047d671..ce6f10ea 100644
--- a/multimolecule/models/configuration_utils.py
+++ b/multimolecule/models/configuration_utils.py
@@ -30,7 +30,8 @@ class PreTrainedConfig(PretrainedConfig):
     Base class for all model configuration classes.
     """
 
-    head: HeadConfig
+    head: HeadConfig | None
+    num_labels: int = 1
 
     hidden_size: int
 
@@ -42,7 +43,15 @@ class PreTrainedConfig(PretrainedConfig):
     null_token_id: int = 5
 
     def __init__(
-        self, pad_token_id=0, bos_token_id=1, eos_token_id=2, unk_token_id=3, mask_token_id=4, null_token_id=5, **kwargs
+        self,
+        pad_token_id: int = 0,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        unk_token_id: int = 3,
+        mask_token_id: int = 4,
+        null_token_id: int = 5,
+        num_labels: int = 1,
+        **kwargs,
     ):
         super().__init__(
             pad_token_id=pad_token_id,
@@ -51,6 +60,7 @@ def __init__(
             unk_token_id=unk_token_id,
             mask_token_id=mask_token_id,
             null_token_id=null_token_id,
+            num_labels=num_labels,
             **kwargs,
         )
 
diff --git a/multimolecule/models/ernierna/configuration_ernierna.py b/multimolecule/models/ernierna/configuration_ernierna.py
index 0648bb2d..bfd11d51 100644
--- a/multimolecule/models/ernierna/configuration_ernierna.py
+++ b/multimolecule/models/ernierna/configuration_ernierna.py
@@ -110,5 +110,5 @@ def __init__(
         self.pairwise_alpha = pairwise_alpha
         self.is_decoder = is_decoder
         self.use_cache = use_cache
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/ernierna/modeling_ernierna.py b/multimolecule/models/ernierna/modeling_ernierna.py
index 563249e7..3f2c96a0 100644
--- a/multimolecule/models/ernierna/modeling_ernierna.py
+++ b/multimolecule/models/ernierna/modeling_ernierna.py
@@ -321,7 +321,7 @@ class ErnieRnaForSequencePrediction(ErnieRnaPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input)
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
     """
 
     def __init__(self, config: ErnieRnaConfig):
@@ -385,9 +385,9 @@ class ErnieRnaForTokenPrediction(ErnieRnaPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: ErnieRnaConfig):
@@ -452,9 +452,9 @@ class ErnieRnaForContactPrediction(ErnieRnaPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: ErnieRnaConfig):
@@ -1183,7 +1183,7 @@ class ErnieRnaContactClassificationHead(nn.Module):
     def __init__(self, config: ErnieRnaConfig, head_config: HeadConfig | None = None):
         super().__init__()
         if head_config is None:
-            head_config = config.head
+            head_config = config.head or HeadConfig()
         self.config = head_config
         self.bos_token_id = config.bos_token_id
         self.eos_token_id = config.eos_token_id
diff --git a/multimolecule/models/rinalmo/configuration_rinalmo.py b/multimolecule/models/rinalmo/configuration_rinalmo.py
index 5e21725d..1cc963b2 100644
--- a/multimolecule/models/rinalmo/configuration_rinalmo.py
+++ b/multimolecule/models/rinalmo/configuration_rinalmo.py
@@ -125,6 +125,6 @@ def __init__(
         self.use_cache = use_cache
         self.learnable_beta = learnable_beta
         self.token_dropout = token_dropout
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
         self.emb_layer_norm_before = emb_layer_norm_before
diff --git a/multimolecule/models/rinalmo/modeling_rinalmo.py b/multimolecule/models/rinalmo/modeling_rinalmo.py
index b45d2823..d0ac6e8c 100644
--- a/multimolecule/models/rinalmo/modeling_rinalmo.py
+++ b/multimolecule/models/rinalmo/modeling_rinalmo.py
@@ -269,9 +269,9 @@ class RiNALMoForSequencePrediction(RiNALMoPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RiNALMoConfig):
@@ -333,9 +333,9 @@ class RiNALMoForTokenPrediction(RiNALMoPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RiNALMoConfig):
@@ -397,9 +397,9 @@ class RiNALMoForContactPrediction(RiNALMoPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RiNALMoConfig):
diff --git a/multimolecule/models/rnabert/configuration_rnabert.py b/multimolecule/models/rnabert/configuration_rnabert.py
index f044ecc7..97632d2e 100644
--- a/multimolecule/models/rnabert/configuration_rnabert.py
+++ b/multimolecule/models/rnabert/configuration_rnabert.py
@@ -112,5 +112,5 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.is_decoder = is_decoder
         self.use_cache = use_cache
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnabert/modeling_rnabert.py b/multimolecule/models/rnabert/modeling_rnabert.py
index 74b06cf1..32f7bf01 100644
--- a/multimolecule/models/rnabert/modeling_rnabert.py
+++ b/multimolecule/models/rnabert/modeling_rnabert.py
@@ -37,7 +37,13 @@
 from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from transformers.utils import logging
 
-from multimolecule.module import ContactPredictionHead, MaskedLMHead, SequencePredictionHead, TokenPredictionHead
+from multimolecule.module import (
+    ContactPredictionHead,
+    HeadConfig,
+    MaskedLMHead,
+    SequencePredictionHead,
+    TokenPredictionHead,
+)
 
 from ..modeling_outputs import ContactPredictorOutput, SequencePredictorOutput, TokenPredictorOutput
 from .configuration_rnabert import RnaBertConfig
@@ -266,9 +272,9 @@ class RnaBertForSequencePrediction(RnaBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaBertConfig):
@@ -330,9 +336,9 @@ class RnaBertForTokenPrediction(RnaBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaBertConfig):
@@ -394,9 +400,9 @@ class RnaBertForContactPrediction(RnaBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaBertConfig):
@@ -1065,7 +1071,7 @@ def __init__(self, config: RnaBertConfig):
         vocab_size, config.vocab_size = config.vocab_size, config.ss_vocab_size
         self.predictions_ss = MaskedLMHead(config)
         config.vocab_size = vocab_size
-        self.seq_relationship = SequencePredictionHead(config)
+        self.seq_relationship = SequencePredictionHead(config, HeadConfig(num_labels=2))
 
     def forward(
         self,
diff --git a/multimolecule/models/rnaernie/configuration_rnaernie.py b/multimolecule/models/rnaernie/configuration_rnaernie.py
index 2d540c9d..7a788297 100644
--- a/multimolecule/models/rnaernie/configuration_rnaernie.py
+++ b/multimolecule/models/rnaernie/configuration_rnaernie.py
@@ -108,5 +108,5 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.is_decoder = is_decoder
         self.use_cache = use_cache
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnaernie/modeling_rnaernie.py b/multimolecule/models/rnaernie/modeling_rnaernie.py
index 8d79a760..8136f673 100644
--- a/multimolecule/models/rnaernie/modeling_rnaernie.py
+++ b/multimolecule/models/rnaernie/modeling_rnaernie.py
@@ -270,9 +270,9 @@ class RnaErnieForSequencePrediction(RnaErniePreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config):
@@ -334,9 +334,9 @@ class RnaErnieForTokenPrediction(RnaErniePreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaErnieConfig):
@@ -398,9 +398,9 @@ class RnaErnieForContactPrediction(RnaErniePreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaErnieConfig):
diff --git a/multimolecule/models/rnafm/configuration_rnafm.py b/multimolecule/models/rnafm/configuration_rnafm.py
index 8fdb7f49..ef1f0c18 100644
--- a/multimolecule/models/rnafm/configuration_rnafm.py
+++ b/multimolecule/models/rnafm/configuration_rnafm.py
@@ -131,5 +131,5 @@ def __init__(
         self.use_cache = use_cache
         self.emb_layer_norm_before = emb_layer_norm_before
         self.token_dropout = token_dropout
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnafm/modeling_rnafm.py b/multimolecule/models/rnafm/modeling_rnafm.py
index 99f553da..6898da9c 100644
--- a/multimolecule/models/rnafm/modeling_rnafm.py
+++ b/multimolecule/models/rnafm/modeling_rnafm.py
@@ -272,9 +272,9 @@ class RnaFmForSequencePrediction(RnaFmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaFmConfig):
@@ -336,9 +336,9 @@ class RnaFmForTokenPrediction(RnaFmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaFmConfig):
@@ -400,9 +400,9 @@ class RnaFmForContactPrediction(RnaFmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaFmConfig):
@@ -555,7 +555,7 @@ class RnaFmForPreTraining(RnaFmPreTrainedModel):
         >>> output["logits"].shape
         torch.Size([1, 7, 26])
         >>> output["contact_map"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
     """
 
     _tied_weights_keys = [
diff --git a/multimolecule/models/rnamsm/configuration_rnamsm.py b/multimolecule/models/rnamsm/configuration_rnamsm.py
index ae914c82..2e8150ba 100644
--- a/multimolecule/models/rnamsm/configuration_rnamsm.py
+++ b/multimolecule/models/rnamsm/configuration_rnamsm.py
@@ -116,5 +116,5 @@ def __init__(
         self.attention_type = attention_type
         self.embed_positions_msa = embed_positions_msa
         self.attention_bias = attention_bias
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/rnamsm/modeling_rnamsm.py b/multimolecule/models/rnamsm/modeling_rnamsm.py
index 5ed6bf87..0390a129 100644
--- a/multimolecule/models/rnamsm/modeling_rnamsm.py
+++ b/multimolecule/models/rnamsm/modeling_rnamsm.py
@@ -176,9 +176,9 @@ class RnaMsmForSequencePrediction(RnaMsmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaMsmConfig):
@@ -239,9 +239,9 @@ class RnaMsmForTokenPrediction(RnaMsmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaMsmConfig):
@@ -302,9 +302,9 @@ class RnaMsmForContactPrediction(RnaMsmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: RnaMsmConfig):
@@ -449,7 +449,7 @@ class RnaMsmForPreTraining(RnaMsmPreTrainedModel):
         >>> output["logits"].shape
         torch.Size([1, 7, 26])
         >>> output["contact_map"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
     """
 
     _tied_weights_keys = [
diff --git a/multimolecule/models/splicebert/configuration_splicebert.py b/multimolecule/models/splicebert/configuration_splicebert.py
index 66b46a88..f789516d 100644
--- a/multimolecule/models/splicebert/configuration_splicebert.py
+++ b/multimolecule/models/splicebert/configuration_splicebert.py
@@ -108,5 +108,5 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.is_decoder = is_decoder
         self.use_cache = use_cache
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/splicebert/modeling_splicebert.py b/multimolecule/models/splicebert/modeling_splicebert.py
index 1b0fd072..9d129d74 100644
--- a/multimolecule/models/splicebert/modeling_splicebert.py
+++ b/multimolecule/models/splicebert/modeling_splicebert.py
@@ -274,9 +274,9 @@ class SpliceBertForSequencePrediction(SpliceBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: SpliceBertConfig):
@@ -338,9 +338,9 @@ class SpliceBertForTokenPrediction(SpliceBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: SpliceBertConfig):
@@ -402,9 +402,9 @@ class SpliceBertForContactPrediction(SpliceBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: SpliceBertConfig):
diff --git a/multimolecule/models/utrbert/configuration_utrbert.py b/multimolecule/models/utrbert/configuration_utrbert.py
index d032c5ee..5230c04f 100644
--- a/multimolecule/models/utrbert/configuration_utrbert.py
+++ b/multimolecule/models/utrbert/configuration_utrbert.py
@@ -125,5 +125,5 @@ def __init__(
         self.position_embedding_type = position_embedding_type
         self.is_decoder = is_decoder
         self.use_cache = use_cache
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
diff --git a/multimolecule/models/utrbert/modeling_utrbert.py b/multimolecule/models/utrbert/modeling_utrbert.py
index 1a5b47f9..688bedbe 100644
--- a/multimolecule/models/utrbert/modeling_utrbert.py
+++ b/multimolecule/models/utrbert/modeling_utrbert.py
@@ -264,9 +264,9 @@ class UtrBertForSequencePrediction(UtrBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: UtrBertConfig):
@@ -328,9 +328,9 @@ class UtrBertForTokenPrediction(UtrBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: UtrBertConfig):
@@ -393,9 +393,9 @@ class UtrBertForContactPrediction(UtrBertPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: UtrBertConfig):
diff --git a/multimolecule/models/utrlm/configuration_utrlm.py b/multimolecule/models/utrlm/configuration_utrlm.py
index f0f705de..a4f930d7 100644
--- a/multimolecule/models/utrlm/configuration_utrlm.py
+++ b/multimolecule/models/utrlm/configuration_utrlm.py
@@ -127,7 +127,7 @@ def __init__(
         self.use_cache = use_cache
         self.emb_layer_norm_before = emb_layer_norm_before
         self.token_dropout = token_dropout
-        self.head = HeadConfig(**head if head is not None else {})
-        self.lm_head = MaskedLMHeadConfig(**lm_head if lm_head is not None else {})
+        self.head = HeadConfig(**head) if head is not None else None
+        self.lm_head = MaskedLMHeadConfig(**lm_head) if lm_head is not None else None
         self.ss_head = HeadConfig(**ss_head) if ss_head is not None else None
         self.mfe_head = HeadConfig(**mfe_head) if mfe_head is not None else None
diff --git a/multimolecule/models/utrlm/modeling_utrlm.py b/multimolecule/models/utrlm/modeling_utrlm.py
index 535f99f0..aae1b593 100644
--- a/multimolecule/models/utrlm/modeling_utrlm.py
+++ b/multimolecule/models/utrlm/modeling_utrlm.py
@@ -272,9 +272,9 @@ class UtrLmForSequencePrediction(UtrLmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.tensor([[1]]))
         >>> output["logits"].shape
-        torch.Size([1, 2])
+        torch.Size([1, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: UtrLmConfig):
@@ -336,9 +336,9 @@ class UtrLmForTokenPrediction(UtrLmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 2])
+        torch.Size([1, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: UtrLmConfig):
@@ -400,9 +400,9 @@ class UtrLmForContactPrediction(UtrLmPreTrainedModel):
         >>> input = tokenizer("ACGUN", return_tensors="pt")
         >>> output = model(**input, labels=torch.randint(2, (1, 5, 5)))
         >>> output["logits"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
         >>> output["loss"]  # doctest:+ELLIPSIS
-        tensor(..., grad_fn=<NllLossBackward0>)
+        tensor(..., grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)
     """
 
     def __init__(self, config: UtrLmConfig):
@@ -555,7 +555,7 @@ class UtrLmForPreTraining(UtrLmPreTrainedModel):
         >>> output["logits"].shape
         torch.Size([1, 7, 26])
         >>> output["contact_map"].shape
-        torch.Size([1, 5, 5, 2])
+        torch.Size([1, 5, 5, 1])
     """
 
     _tied_weights_keys = [
diff --git a/multimolecule/module/__init__.py b/multimolecule/module/__init__.py
index 0128fe9b..dbba900b 100644
--- a/multimolecule/module/__init__.py
+++ b/multimolecule/module/__init__.py
@@ -14,8 +14,8 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-from .criterions import Criterion
-from .embeddings import PositionEmbeddingRegistry, PositionEmbeddingRegistryHF, RotaryEmbedding, SinusoidalEmbedding
+from .criterions import Criterion, CriterionRegistry
+from .embeddings import PositionEmbeddingRegistry, RotaryEmbedding, SinusoidalEmbedding
 from .heads import (
     BaseHeadConfig,
     ContactPredictionHead,
@@ -23,7 +23,6 @@
     HeadOutput,
     HeadRegistry,
     HeadTransformRegistry,
-    HeadTransformRegistryHF,
     IdentityTransform,
     LinearTransform,
     MaskedLMHead,
@@ -31,15 +30,18 @@
     NonLinearTransform,
     PredictionHead,
     SequencePredictionHead,
-    TokenHeadRegistryHF,
     TokenKMerHead,
     TokenPredictionHead,
 )
+from .model import MultiMoleculeModel
+from .registry import ModelRegistry
 
 __all__ = [
+    "ModelRegistry",
+    "MultiMoleculeModel",
+    "CriterionRegistry",
     "Criterion",
     "PositionEmbeddingRegistry",
-    "PositionEmbeddingRegistryHF",
     "RotaryEmbedding",
     "SinusoidalEmbedding",
     "BaseHeadConfig",
@@ -48,14 +50,12 @@
     "HeadRegistry",
     "PredictionHead",
     "SequencePredictionHead",
-    "TokenHeadRegistryHF",
     "TokenPredictionHead",
     "TokenKMerHead",
     "ContactPredictionHead",
     "MaskedLMHead",
     "HeadOutput",
     "HeadTransformRegistry",
-    "HeadTransformRegistryHF",
     "LinearTransform",
     "NonLinearTransform",
     "IdentityTransform",
diff --git a/multimolecule/module/backbones/__init__.py b/multimolecule/module/backbones/__init__.py
new file mode 100644
index 00000000..d69e6292
--- /dev/null
+++ b/multimolecule/module/backbones/__init__.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from .registry import BackboneRegistry
+from .sequence import SequenceBackbone
+from .sequences import SequenceRegistry
+
+__all__ = ["BackboneRegistry", "SequenceRegistry", "SequenceBackbone"]
diff --git a/multimolecule/module/backbones/registry.py b/multimolecule/module/backbones/registry.py
new file mode 100644
index 00000000..47be122d
--- /dev/null
+++ b/multimolecule/module/backbones/registry.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from chanfig import Registry
+
+BackboneRegistry = Registry()
diff --git a/multimolecule/module/backbones/sequence.py b/multimolecule/module/backbones/sequence.py
new file mode 100644
index 00000000..a30cbf83
--- /dev/null
+++ b/multimolecule/module/backbones/sequence.py
@@ -0,0 +1,46 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .registry import BackboneRegistry
+from .sequences import SequenceRegistry
+
+
+@BackboneRegistry.register("sequence", default=True)
+class SequenceBackbone(nn.Module):
+    def __init__(self, sequence) -> None:
+        super().__init__()
+        self.sequence = SequenceRegistry.build(**sequence)
+        self.sequence_dropout = nn.Dropout(sequence.pop("dropout", 0), inplace=True)
+        self.config = self.sequence.config
+        self.out_channels = self.config.hidden_size
+
+    def forward(self, sequence: NestedTensor | Tensor, *args, **kwargs) -> tuple[FlatDict, FlatDict]:
+        attentions = None
+        input_ids, attention_mask = sequence.tensor, sequence.mask
+        sequence_output = self.sequence(input_ids.int(), attention_mask)
+        sequence_output["pooler_output"] = self.sequence_dropout(sequence_output["pooler_output"])
+        sequence_output["last_hidden_state"] = self.sequence_dropout(sequence_output["last_hidden_state"])
+        if "attentions" in sequence_output:
+            attentions = torch.stack(sequence_output["attentions"], dim=1).detach()
+
+        return sequence_output, attentions
diff --git a/multimolecule/module/backbones/sequences/__init__.py b/multimolecule/module/backbones/sequences/__init__.py
new file mode 100644
index 00000000..e6e5cd08
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from .onehot import OneHot
+from .registry import SequenceRegistry
+
+__all__ = ["SequenceRegistry", "OneHot"]
diff --git a/multimolecule/module/backbones/sequences/onehot.py b/multimolecule/module/backbones/sequences/onehot.py
new file mode 100644
index 00000000..bc4c979f
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/onehot.py
@@ -0,0 +1,39 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import torch
+from chanfig import FlatDict
+from torch import nn
+from transformers import AutoConfig
+
+from .registry import SequenceRegistry
+
+
+@SequenceRegistry.register("onehot")
+class OneHot(nn.Module):
+    def __init__(self, pretrained: str) -> None:
+        super().__init__()
+        self.config = AutoConfig.from_pretrained(str(pretrained))
+        self.module = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
+
+    def forward(self, input_ids, attn_mask) -> FlatDict:
+        output = FlatDict()
+        output["last_hidden_state"] = self.module(input_ids)
+        valid_length = attn_mask.sum(dim=1)
+        output["pooler_output"] = torch.stack(
+            [t[: valid_length[i]].sum(0) for i, t in enumerate(output["last_hidden_state"])]
+        )
+        return output
diff --git a/multimolecule/module/backbones/sequences/registry.py b/multimolecule/module/backbones/sequences/registry.py
new file mode 100644
index 00000000..c9178231
--- /dev/null
+++ b/multimolecule/module/backbones/sequences/registry.py
@@ -0,0 +1,66 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import danling as dl
+import transformers
+from chanfig import Registry as Registry_
+from torch import nn
+from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
+
+
+class Registry(Registry_):  # pylint: disable=too-few-public-methods
+    def build(
+        self,
+        type: str | None = None,
+        name: str | None = None,
+        use_pretrained: bool = True,
+        gradient_checkpoint: bool = False,
+        checkpoint: str | None = None,
+        *args,
+        **kwargs,
+    ) -> nn.Module:
+        if type is not None:
+            if type in self:
+                sequence_cls = self.lookup(type)
+                sequence = self.init(sequence_cls, *args, **kwargs)
+                if checkpoint is not None:
+                    sequence.load_state_dict(dl.load(checkpoint))
+            elif hasattr(transformers, type + "Model"):
+                if use_pretrained:
+                    sequence_cls: PreTrainedModel = getattr(transformers, type + "Model")  # type: ignore[no-redef]
+                    sequence = sequence_cls.from_pretrained(name, *args, **kwargs)
+                else:
+                    config_cls: PretrainedConfig = getattr(transformers, type + "Config")
+                    config, kwargs = config_cls.from_pretrained(name, return_unused_kwargs=True, **kwargs)
+                    sequence_cls: PreTrainedModel = getattr(transformers, type + "Model")  # type: ignore[no-redef]
+                    sequence = sequence_cls.from_config(config, *args, **kwargs)
+            else:
+                raise ValueError(f"Sequence {type} not found in registry or transformers")
+        else:
+            if use_pretrained:
+                sequence = AutoModel.from_pretrained(name, *args, **kwargs)
+            else:
+                config, kwargs = AutoConfig.from_pretrained(name, return_unused_kwargs=True, **kwargs)
+                sequence = AutoModel.from_config(config, *args, **kwargs)
+
+        if gradient_checkpoint:
+            sequence.gradient_checkpointing_enable()
+        return sequence
+
+
+SequenceRegistry = Registry()
diff --git a/multimolecule/module/criterions/__init__.py b/multimolecule/module/criterions/__init__.py
index 104334b5..4b9adf7e 100644
--- a/multimolecule/module/criterions/__init__.py
+++ b/multimolecule/module/criterions/__init__.py
@@ -14,6 +14,18 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
+from .binary import BCEWithLogitsLoss
 from .generic import Criterion
+from .multiclass import CrossEntropyLoss
+from .multilabel import MultiLabelSoftMarginLoss
+from .registry import CriterionRegistry
+from .regression import MSELoss
 
-__all__ = ["Criterion"]
+__all__ = [
+    "CriterionRegistry",
+    "Criterion",
+    "MSELoss",
+    "BCEWithLogitsLoss",
+    "CrossEntropyLoss",
+    "MultiLabelSoftMarginLoss",
+]
diff --git a/multimolecule/module/criterions/binary.py b/multimolecule/module/criterions/binary.py
new file mode 100644
index 00000000..0bf53e59
--- /dev/null
+++ b/multimolecule/module/criterions/binary.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .registry import CriterionRegistry
+
+if TYPE_CHECKING:
+    from ..heads.config import HeadConfig
+
+
+@CriterionRegistry.register("binary")
+class BCEWithLogitsLoss(nn.BCEWithLogitsLoss):
+    def __init__(self, config: HeadConfig) -> None:
+        super().__init__(**config.get("loss", {}))
+        self.config = config
+
+    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+        if isinstance(input, NestedTensor):
+            input = torch.cat(input.flatten().storage())
+        if isinstance(target, NestedTensor):
+            target = torch.cat(target.flatten().storage())
+        if input.ndim == target.ndim + 1:
+            input = input.squeeze(-1)
+        return super().forward(input, target.float())
diff --git a/multimolecule/module/criterions/generic.py b/multimolecule/module/criterions/generic.py
index b003c81d..a6731933 100644
--- a/multimolecule/module/criterions/generic.py
+++ b/multimolecule/module/criterions/generic.py
@@ -17,8 +17,8 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from warnings import warn
 
-import torch
 from danling import NestedTensor
 from torch import Tensor, nn
 from torch.nn import functional as F
@@ -26,10 +26,13 @@
 if TYPE_CHECKING:
     from ..heads.config import HeadConfig
 
+from .registry import CriterionRegistry
 
+
+@CriterionRegistry.register(default=True)
 class Criterion(nn.Module):
 
-    problem_types = ["regression", "single_label_classification", "multi_label_classification"]
+    problem_types = ["regression", "binary", "multiclass", "multilabel"]
 
     def __init__(self, config: HeadConfig) -> None:
         super().__init__()
@@ -41,21 +44,31 @@ def forward(self, logits: Tensor | NestedTensor, labels: Tensor | NestedTensor)
         if labels is None:
             return None
         if self.problem_type is None:
-            if self.num_labels == 1:
+            if labels.is_floating_point():
                 self.problem_type = "regression"
-            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
-                self.problem_type = "single_label_classification"
+            elif self.num_labels == 1:
+                self.problem_type = "binary"
+            elif labels.unique().numel() == 2:
+                self.problem_type = "multilabel"
             else:
-                self.problem_type = "multi_label_classification"
+                self.problem_type = "multiclass"
+            warn(
+                f"`problem_type` is not set. Assuming {self.problem_type}. \n"
+                "This can lead to unexpected behavior. Please set `problem_type` explicitly."
+            )
             self.config.problem_type = self.problem_type
         if self.problem_type == "regression":
             labels = labels.to(logits.dtype)
             if self.num_labels == 1:
                 return F.mse_loss(logits.squeeze(), labels.squeeze())
             logits, labels = logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)
-            return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels))
-        if self.problem_type == "single_label_classification":
+            return sum(F.mse_loss(logits[:, i], labels[:, i]).sqrt() for i in range(self.num_labels))  # type: ignore
+        if self.problem_type == "multiclass":
             return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
-        if self.problem_type == "multi_label_classification":
-            return F.binary_cross_entropy_with_logits(logits, labels)
+        if self.problem_type == "binary":
+            if logits.ndim == labels.ndim + 1:
+                logits = logits.squeeze(-1)
+            return F.binary_cross_entropy_with_logits(logits, labels.to(logits.dtype))
+        if self.problem_type == "multilabel":
+            return F.multilabel_soft_margin_loss(logits, labels.to(logits.dtype))
         raise ValueError(f"problem_type should be one of {self.problem_types}, but got {self.problem_type}")
diff --git a/multimolecule/module/criterions/multiclass.py b/multimolecule/module/criterions/multiclass.py
new file mode 100644
index 00000000..f7070e94
--- /dev/null
+++ b/multimolecule/module/criterions/multiclass.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+    from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("multiclass")
+class CrossEntropyLoss(nn.CrossEntropyLoss):
+    def __init__(self, config: HeadConfig) -> None:
+        super().__init__(**config.get("loss", {}))
+        self.config = config
+
+    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+        if isinstance(input, NestedTensor):
+            input = torch.cat(input.storage())
+        if isinstance(target, NestedTensor):
+            target = torch.cat(target.storage())
+        if input.ndim > 2:
+            input, target = input.view(-1, input.size(-1)), target.view(-1)
+        return super().forward(input, target.long())
diff --git a/multimolecule/module/criterions/multilabel.py b/multimolecule/module/criterions/multilabel.py
new file mode 100644
index 00000000..c72bb9f9
--- /dev/null
+++ b/multimolecule/module/criterions/multilabel.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+    from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("multilabel")
+class MultiLabelSoftMarginLoss(nn.MultiLabelSoftMarginLoss):
+    def __init__(self, config: HeadConfig) -> None:
+        super().__init__(**config.get("loss", {}))
+        self.config = config
+
+    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+        if isinstance(target, NestedTensor) and target.ndim > 2:
+            input, target = input.view(-1, input.size(-1)), target.view(-1, target.size(-1))
+        if isinstance(input, NestedTensor):
+            input = torch.cat(input.storage())
+        if isinstance(target, NestedTensor):
+            target = torch.cat(target.storage())
+        return super().forward(input, target.float())
diff --git a/multimolecule/module/criterions/registry.py b/multimolecule/module/criterions/registry.py
new file mode 100644
index 00000000..856341f7
--- /dev/null
+++ b/multimolecule/module/criterions/registry.py
@@ -0,0 +1,29 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from chanfig import ConfigRegistry as Registry_
+from torch import nn
+
+
+class Registry(Registry_):  # pylint: disable=too-few-public-methods
+    key = "problem_type"
+
+    def build(self, config) -> nn.Module:  # type: ignore[override]
+        name = getattr(config, self.getattr("key"))
+        return self.init(self.lookup(name), config)  # type: ignore[arg-type]
+
+
+CriterionRegistry = Registry(fallback=True)
diff --git a/multimolecule/module/criterions/regression.py b/multimolecule/module/criterions/regression.py
new file mode 100644
index 00000000..4f39e0eb
--- /dev/null
+++ b/multimolecule/module/criterions/regression.py
@@ -0,0 +1,44 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from danling import NestedTensor
+from torch import Tensor, nn
+
+if TYPE_CHECKING:
+    from ..heads.config import HeadConfig
+
+from .registry import CriterionRegistry
+
+
+@CriterionRegistry.register("regression")
+class MSELoss(nn.MSELoss):
+    def __init__(self, config: HeadConfig) -> None:
+        super().__init__(**config.get("loss", {}))
+        self.config = config
+
+    def forward(self, input: NestedTensor | Tensor, target: NestedTensor | Tensor) -> Tensor:
+        if isinstance(input, NestedTensor):
+            input = torch.cat(input.flatten().storage())
+        if isinstance(target, NestedTensor):
+            target = torch.cat(target.flatten().storage())
+        if input.ndim == target.ndim + 1:
+            target = target.unsqueeze(-1)
+        return super().forward(input, target.to(input.dtype))
diff --git a/multimolecule/module/heads/__init__.py b/multimolecule/module/heads/__init__.py
index 8cc91f29..0e857c5e 100644
--- a/multimolecule/module/heads/__init__.py
+++ b/multimolecule/module/heads/__init__.py
@@ -21,14 +21,8 @@
 from .pretrain import MaskedLMHead
 from .registry import HeadRegistry
 from .sequence import SequencePredictionHead
-from .token import TokenHeadRegistryHF, TokenKMerHead, TokenPredictionHead
-from .transform import (
-    HeadTransformRegistry,
-    HeadTransformRegistryHF,
-    IdentityTransform,
-    LinearTransform,
-    NonLinearTransform,
-)
+from .token import TokenKMerHead, TokenPredictionHead
+from .transform import HeadTransformRegistry, IdentityTransform, LinearTransform, NonLinearTransform
 
 __all__ = [
     "BaseHeadConfig",
@@ -37,14 +31,12 @@
     "HeadRegistry",
     "PredictionHead",
     "SequencePredictionHead",
-    "TokenHeadRegistryHF",
     "TokenPredictionHead",
     "TokenKMerHead",
     "ContactPredictionHead",
     "MaskedLMHead",
     "HeadOutput",
     "HeadTransformRegistry",
-    "HeadTransformRegistryHF",
     "LinearTransform",
     "NonLinearTransform",
     "IdentityTransform",
diff --git a/multimolecule/module/heads/config.py b/multimolecule/module/heads/config.py
index bb0dbba6..3b9ee64b 100644
--- a/multimolecule/module/heads/config.py
+++ b/multimolecule/module/heads/config.py
@@ -16,15 +16,13 @@
 
 from __future__ import annotations
 
-from collections import OrderedDict
-from dataclasses import dataclass
+from chanfig import FlatDict
 
 
-class BaseHeadConfig(OrderedDict):
+class BaseHeadConfig(FlatDict):
     pass
 
 
-@dataclass
 class HeadConfig(BaseHeadConfig):
     r"""
     Configuration class for a prediction head.
@@ -35,8 +33,8 @@ class HeadConfig(BaseHeadConfig):
 
             Head should look for [`Config.num_labels`][multimolecule.PreTrainedConfig] if is `None`.
         problem_type:
-            Problem type for `XxxForYyyPrediction` models. Can be one of `"regression"`,
-            `"single_label_classification"` or `"multi_label_classification"`.
+            Problem type for `XxxForYyyPrediction` models. Can be one of `"binary"`, `"regression"`,
+            `"multiclass"` or `"multilabel"`.
 
             Head should look for [`Config.problem_type`][multimolecule.PreTrainedConfig] if is `None`.
         hidden_size:
@@ -55,14 +53,18 @@ class HeadConfig(BaseHeadConfig):
             The activation function of the final prediction output.
         layer_norm_eps:
             The epsilon used by the layer normalization layers.
-        output_name (`str`, *optional*):
+        output_name:
             The name of the tensor required in model outputs.
 
             If is `None`, will use the default output name of the corresponding head.
+        type:
+            The type of the head in the model.
+
+            This is used by [`MultiMoleculeModel`][multimolecule.MultiMoleculeModel] to construct heads.
     """
 
-    num_labels: int = None  # type: ignore[assignment]
-    problem_type: str = None  # type: ignore[assignment]
+    num_labels: int | None = None
+    problem_type: str | None = None
     hidden_size: int | None = None
     dropout: float = 0.0
     transform: str | None = None
@@ -71,9 +73,9 @@ class HeadConfig(BaseHeadConfig):
     act: str | None = None
     layer_norm_eps: float = 1e-12
     output_name: str | None = None
+    type: str | None = None
 
 
-@dataclass
 class MaskedLMHeadConfig(BaseHeadConfig):
     r"""
     Configuration class for a Masked Language Modeling head.
@@ -95,7 +97,7 @@ class MaskedLMHeadConfig(BaseHeadConfig):
             The activation function of the final prediction output.
         layer_norm_eps:
             The epsilon used by the layer normalization layers.
-        output_name (`str`, *optional*):
+        output_name:
             The name of the tensor required in model outputs.
 
             If is `None`, will use the default output name of the corresponding head.
diff --git a/multimolecule/module/heads/contact.py b/multimolecule/module/heads/contact.py
index 50ec4fbb..866e2250 100644
--- a/multimolecule/module/heads/contact.py
+++ b/multimolecule/module/heads/contact.py
@@ -16,10 +16,12 @@
 
 from __future__ import annotations
 
-from typing import Mapping, Tuple
+from typing import Callable, Mapping, Tuple, Type
 
 import torch
 from danling import NestedTensor
+from danling.modules import MLP
+from lazy_imports import try_import
 from torch import Tensor, nn
 from transformers.modeling_outputs import ModelOutput
 from typing_extensions import TYPE_CHECKING
@@ -28,13 +30,86 @@
 from .generic import PredictionHead
 from .output import HeadOutput
 from .registry import HeadRegistry
-from .utils import average_product_correct, symmetrize
+
+with try_import() as tv:
+    from torchvision.models.resnet import BasicBlock, Bottleneck
 
 if TYPE_CHECKING:
     from multimolecule.models import PreTrainedConfig
 
 
-@HeadRegistry.register("contact")
+@HeadRegistry.contact.register("default", default=True)
+class ContactHead(PredictionHead):
+
+    output_name: str = "last_hidden_state"
+
+    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
+        super().__init__(config, head_config)
+        self._bos_token_id = config.bos_token_id
+        self._eos_token_id = config.eos_token_id
+        self.pad_token_id = config.pad_token_id
+        out_channels: int = self.config.hidden_size  # type: ignore[assignment]
+        self.qk_proj = nn.Linear(out_channels, 2 * out_channels)
+        self.ffn = MLP(1, out_channels, residual=False)
+
+    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
+        self,
+        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
+        attention_mask: Tensor | None = None,
+        input_ids: NestedTensor | Tensor | None = None,
+        labels: Tensor | None = None,
+        output_name: str | None = None,
+        **kwargs,
+    ) -> HeadOutput:
+        if attention_mask is None:
+            if isinstance(input_ids, NestedTensor):
+                input_ids, attention_mask = input_ids.tensor, input_ids.mask
+            else:
+                if input_ids is None:
+                    raise ValueError(
+                        f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
+                    )
+                if self.pad_token_id is None:
+                    raise ValueError(
+                        f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
+                    )
+                attention_mask = input_ids.ne(self.pad_token_id)
+
+        if isinstance(outputs, (Mapping, ModelOutput)):
+            output = outputs[output_name or self.output_name]
+        elif isinstance(outputs, tuple):
+            output = outputs[0]
+        else:
+            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
+        output = output * attention_mask.unsqueeze(-1)
+
+        # remove cls token embeddings
+        if self.bos_token_id is not None:
+            output = output[..., 1:, :]
+            # process attention_mask and input_ids to make removal of eos token happy
+            attention_mask = attention_mask[..., 1:]
+            if input_ids is not None:
+                input_ids = input_ids[..., 1:]
+        # remove eos token embeddings
+        if self.eos_token_id is not None:
+            if input_ids is not None:
+                eos_mask = input_ids.ne(self.eos_token_id).to(output)
+            else:
+                last_valid_indices = attention_mask.sum(dim=-1)
+                seq_length = attention_mask.size(-1)
+                eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
+            output = output * eos_mask[:, :, None]
+            output = output[..., :-1, :]
+
+        q, k = self.qk_proj(output).chunk(2, dim=-1)
+        contact_map = (q @ k.transpose(-2, -1)).unsqueeze(-1)
+        contact_map = contact_map + self.ffn(contact_map)
+        if "continuous" in outputs:
+            contact_map = contact_map * (1 + outputs["continuous"].unsqueeze(dim=-1))  # type: ignore[call-overload]
+        return super().forward(contact_map, labels)
+
+
+@HeadRegistry.contact.register("attention")
 class ContactPredictionHead(PredictionHead):
     r"""
     Head for tasks in contact-level.
@@ -50,13 +125,23 @@ class ContactPredictionHead(PredictionHead):
     output_name: str = "attentions"
     r"""The default output to use for the head."""
 
+    requires_attention: bool = True
+
     def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
         super().__init__(config, head_config)
         self.bos_token_id = config.bos_token_id
         self.eos_token_id = config.eos_token_id
         self.pad_token_id = config.pad_token_id
-        self.decoder = nn.Linear(
-            config.num_hidden_layers * config.num_attention_heads, self.num_labels, bias=self.config.bias
+        self.config.hidden_size = config.num_hidden_layers * config.num_attention_heads
+        num_layers = self.config.get("num_layers", 16)
+        num_channels = self.config.get("num_channels", self.config.hidden_size // 10)  # type: ignore[operator]
+        block = self.config.get("block", "auto")
+        self.decoder = ResNet(
+            num_layers=num_layers,
+            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
+            block=block,
+            num_channels=num_channels,
+            num_labels=self.num_labels,
         )
         if head_config is not None and head_config.output_name is not None:
             self.output_name = head_config.output_name
@@ -106,7 +191,7 @@ def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
         # but it does for the contact prediction task, which takes attentions as input,
         # so we have to mimic that here.
         attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
-        attentions *= attention_mask[:, None, None, :, :]
+        attentions = attentions * attention_mask[:, None, None, :, :]
 
         # remove cls token attentions
         if self.bos_token_id is not None:
@@ -124,14 +209,203 @@ def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
                 seq_length = attention_mask.size(-1)
                 eos_mask = torch.arange(seq_length, device=attentions.device).unsqueeze(0) == last_valid_indices
             eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
-            attentions *= eos_mask[:, None, None, :, :]
+            attentions = attentions * eos_mask[:, None, None, :, :]
             attentions = attentions[..., :-1, :-1]
 
         # features: batch x channels x input_ids x input_ids (symmetric)
         batch_size, layers, heads, seqlen, _ = attentions.size()
         attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
-        attentions = attentions.to(self.decoder.weight.device)
+        attentions = attentions.to(self.decoder.proj.weight.device)
         attentions = average_product_correct(symmetrize(attentions))
         attentions = attentions.permute(0, 2, 3, 1).squeeze(3)
 
         return super().forward(attentions, labels, **kwargs)
+
+
+@HeadRegistry.contact.register("logits")
+class ContactLogitsHead(PredictionHead):
+    r"""
+    Head for tasks in contact-level.
+
+    Performs symmetrization, and average product correct.
+
+    Args:
+        config: The configuration object for the model.
+        head_config: The configuration object for the head.
+            If None, will use configuration from the `config`.
+    """
+
+    output_name: str = "last_hidden_state"
+    r"""The default output to use for the head."""
+
+    requires_attention: bool = False
+
+    def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
+        super().__init__(config, head_config)
+        self.bos_token_id = config.bos_token_id
+        self.eos_token_id = config.eos_token_id
+        self.pad_token_id = config.pad_token_id
+        num_layers = self.config.get("num_layers", 16)
+        num_channels = self.config.get("num_channels", self.config.hidden_size // 10)  # type: ignore[operator]
+        block = self.config.get("block", "auto")
+        self.decoder = ResNet(
+            num_layers=num_layers,
+            hidden_size=self.config.hidden_size,  # type: ignore[arg-type]
+            block=block,
+            num_channels=num_channels,
+            num_labels=self.num_labels,
+        )
+        if head_config is not None and head_config.output_name is not None:
+            self.output_name = head_config.output_name
+
+    def forward(  # type: ignore[override]  # pylint: disable=arguments-renamed
+        self,
+        outputs: ModelOutput | Mapping | Tuple[Tensor, ...],
+        attention_mask: Tensor | None = None,
+        input_ids: NestedTensor | Tensor | None = None,
+        labels: Tensor | None = None,
+        output_name: str | None = None,
+        **kwargs,
+    ) -> HeadOutput:
+        r"""
+        Forward pass of the ContactPredictionHead.
+
+        Args:
+            outputs: The outputs of the model.
+            attention_mask: The attention mask for the inputs.
+            input_ids: The input ids for the inputs.
+            labels: The labels for the head.
+            output_name: The name of the output to use.
+                Defaults to `self.output_name`.
+        """
+        if attention_mask is None:
+            if isinstance(input_ids, NestedTensor):
+                input_ids, attention_mask = input_ids.tensor, input_ids.mask
+            else:
+                if input_ids is None:
+                    raise ValueError(
+                        f"Either attention_mask or input_ids must be provided for {self.__class__.__name__} to work."
+                    )
+                if self.pad_token_id is None:
+                    raise ValueError(
+                        f"pad_token_id must be provided when attention_mask is not passed to {self.__class__.__name__}."
+                    )
+                attention_mask = input_ids.ne(self.pad_token_id)
+
+        if isinstance(outputs, (Mapping, ModelOutput)):
+            output = outputs[output_name or self.output_name]
+        elif isinstance(outputs, tuple):
+            output = outputs[0]
+        else:
+            raise ValueError(f"Unsupported type for outputs: {type(outputs)}")
+        output = output * attention_mask.unsqueeze(-1)
+
+        # remove cls token embeddings
+        if self.bos_token_id is not None:
+            output = output[..., 1:, :]
+            # process attention_mask and input_ids to make removal of eos token happy
+            attention_mask = attention_mask[..., 1:]
+            if input_ids is not None:
+                input_ids = input_ids[..., 1:]
+        # remove eos token embeddings
+        if self.eos_token_id is not None:
+            if input_ids is not None:
+                eos_mask = input_ids.ne(self.eos_token_id).to(output)
+            else:
+                last_valid_indices = attention_mask.sum(dim=-1)
+                seq_length = attention_mask.size(-1)
+                eos_mask = torch.arange(seq_length, device=output.device) == last_valid_indices.unsqueeze(1)
+            output = output * eos_mask[:, :, None]
+            output = output[..., :-1, :]
+
+        # make symmetric contact map
+        contact_map = output.unsqueeze(1) * output.unsqueeze(2)
+
+        return super().forward(contact_map, labels, **kwargs)
+
+
+class ResNet(nn.Module):
+    def __init__(
+        self,
+        num_layers: int,
+        hidden_size: int,
+        block: Type[BasicBlock | Bottleneck] | str = "auto",
+        num_channels: int | None = None,
+        num_labels: int = 1,
+        norm_layer: Callable[..., nn.Module] | None = None,
+        zero_init_residual: bool = True,
+    ) -> None:
+        tv.check()
+        super().__init__()
+
+        if block == "auto":
+            block = BasicBlock if num_layers < 50 else Bottleneck
+        elif block in ("basic", "BasicBlock"):
+            block = BasicBlock
+        elif block in ("bottleneck", "Bottleneck"):
+            block = Bottleneck
+        else:
+            raise ValueError(f"Unknown block type: {block}")
+        if num_channels is None:
+            num_channels = hidden_size // 10
+        if norm_layer is None:
+            norm_layer = LayerNorm2D
+
+        self.proj = nn.Conv2d(hidden_size, num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.norm = norm_layer(num_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.layers = nn.Sequential(
+            *[block(num_channels, num_channels, norm_layer=norm_layer) for _ in range(num_layers)]  # type: ignore
+        )
+        self.output = nn.Linear(num_channels, num_labels)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
+                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
+                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.proj(x.transpose(1, 3))
+        x = self.norm(x)
+        x = self.relu(x)
+        x = self.layers(x)
+        x = self.output(x.transpose(1, 3))
+        return x
+
+
+class LayerNorm2D(nn.GroupNorm):
+    def __init__(self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
+        super().__init__(num_channels=num_features, eps=eps, affine=elementwise_affine, num_groups=1)
+        self.num_channels = num_features
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(num_channels={self.num_channels}, eps={self.eps}, affine={self.affine})"
+
+
+def symmetrize(x):
+    "Make layer symmetric in final two dimensions, used for contact prediction."
+    return x + x.transpose(-1, -2)
+
+
+def average_product_correct(x):
+    "Perform average product correct, used for contact prediction."
+    a1 = x.sum(-1, keepdims=True)
+    a2 = x.sum(-2, keepdims=True)
+    a12 = x.sum((-1, -2), keepdims=True)
+
+    avg = a1 * a2
+    avg.div_(a12)  # in-place to reduce memory
+    normalized = x - avg
+    return normalized
diff --git a/multimolecule/module/heads/generic.py b/multimolecule/module/heads/generic.py
index d97950a2..ef0fa520 100644
--- a/multimolecule/module/heads/generic.py
+++ b/multimolecule/module/heads/generic.py
@@ -19,12 +19,11 @@
 from typing import TYPE_CHECKING
 from warnings import warn
 
-import torch
 from danling import NestedTensor
 from torch import Tensor, nn
 from transformers.activations import ACT2FN
 
-from ..criterions import Criterion
+from ..criterions import CriterionRegistry
 from .config import HeadConfig
 from .output import HeadOutput
 from .transform import HeadTransformRegistryHF
@@ -44,24 +43,25 @@ class PredictionHead(nn.Module):
     """
 
     num_labels: int
+    requires_attention: bool = False
 
     def __init__(self, config: PreTrainedConfig, head_config: HeadConfig | None = None):
         super().__init__()
         if head_config is None:
-            head_config = config.head
+            head_config = config.head or HeadConfig(num_labels=config.num_labels)
+        elif head_config.num_labels is None:
+            head_config.num_labels = config.num_labels
         self.config = head_config
         if self.config.hidden_size is None:
             self.config.hidden_size = config.hidden_size
-        if self.config.num_labels is None:
-            self.config.num_labels = config.num_labels
         if self.config.problem_type is None:
             self.config.problem_type = config.problem_type
-        self.num_labels = self.config.num_labels
+        self.num_labels = self.config.num_labels  # type: ignore[assignment]
         self.dropout = nn.Dropout(self.config.dropout)
         self.transform = HeadTransformRegistryHF.build(self.config)
-        self.decoder = nn.Linear(config.hidden_size, self.num_labels, bias=self.config.bias)
+        self.decoder = nn.Linear(self.config.hidden_size, self.num_labels, bias=self.config.bias)
         self.activation = ACT2FN[self.config.act] if self.config.act is not None else None
-        self.criterion = Criterion(self.config)
+        self.criterion = CriterionRegistry.build(self.config)
 
     def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOutput:
         r"""
@@ -85,6 +85,6 @@ def forward(self, embeddings: Tensor, labels: Tensor | None, **kwargs) -> HeadOu
             if isinstance(labels, NestedTensor):
                 if isinstance(output, Tensor):
                     output = labels.nested_like(output, strict=False)
-                return HeadOutput(output, self.criterion(torch.cat(output.storage()), torch.cat(labels.storage())))
+                return HeadOutput(output, self.criterion(output.concat, labels.concat))
             return HeadOutput(output, self.criterion(output, labels))
         return HeadOutput(output)
diff --git a/multimolecule/module/heads/pretrain.py b/multimolecule/module/heads/pretrain.py
index 994cb8ca..c6968c4b 100644
--- a/multimolecule/module/heads/pretrain.py
+++ b/multimolecule/module/heads/pretrain.py
@@ -53,8 +53,8 @@ def __init__(
     ):
         super().__init__()
         if head_config is None:
-            head_config = config.lm_head if hasattr(config, "lm_head") else config.head  # type: ignore[assignment]
-        self.config: MaskedLMHeadConfig = head_config  # type: ignore[assignment]
+            head_config = (config.lm_head if hasattr(config, "lm_head") else config.head) or MaskedLMHeadConfig()
+        self.config: MaskedLMHeadConfig = head_config
         if self.config.hidden_size is None:
             self.config.hidden_size = config.hidden_size
         self.num_labels = config.vocab_size
@@ -97,6 +97,6 @@ def forward(
             if isinstance(labels, NestedTensor):
                 if isinstance(output, Tensor):
                     output = labels.nested_like(output, strict=False)
-                return HeadOutput(output, F.cross_entropy(torch.cat(output.storage()), torch.cat(labels.storage())))
+                return HeadOutput(output, F.cross_entropy(output.concat, labels.concat))
             return HeadOutput(output, F.cross_entropy(output.view(-1, self.num_labels), labels.view(-1)))
         return HeadOutput(output)
diff --git a/multimolecule/module/heads/registry.py b/multimolecule/module/heads/registry.py
index e5393e4e..6db3b680 100644
--- a/multimolecule/module/heads/registry.py
+++ b/multimolecule/module/heads/registry.py
@@ -14,6 +14,16 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
 
-from chanfig import Registry
+from chanfig import ConfigRegistry as Registry_
+from torch import nn
+
+
+class Registry(Registry_):  # pylint: disable=too-few-public-methods
+    key = "type"
+
+    def build(self, config, head_config) -> nn.Module:  # type: ignore[override]
+        name = getattr(head_config, self.getattr("key"))
+        return self.init(self.lookup(name), config, head_config)  # type: ignore[arg-type]
+
 
 HeadRegistry = Registry(default_factory=Registry, fallback=True)
diff --git a/multimolecule/module/heads/utils.py b/multimolecule/module/heads/utils.py
index c5937c6d..cc1f3654 100644
--- a/multimolecule/module/heads/utils.py
+++ b/multimolecule/module/heads/utils.py
@@ -119,32 +119,3 @@ def unfold_kmer_embeddings(
             embedding = torch.cat([embedding, tensor[seq_len - 1][None, :]])
         output[index, : seq_len + nmers - 1] = embedding
     return output
-
-
-def rotate_half(x):
-    x1, x2 = x.chunk(2, dim=-1)
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(x, cos, sin):
-    cos = cos[:, :, : x.shape[-2], :]
-    sin = sin[:, :, : x.shape[-2], :]
-
-    return (x * cos) + (rotate_half(x) * sin)
-
-
-def symmetrize(x):
-    "Make layer symmetric in final two dimensions, used for contact prediction."
-    return x + x.transpose(-1, -2)
-
-
-def average_product_correct(x):
-    "Perform average product correct, used for contact prediction."
-    a1 = x.sum(-1, keepdims=True)
-    a2 = x.sum(-2, keepdims=True)
-    a12 = x.sum((-1, -2), keepdims=True)
-
-    avg = a1 * a2
-    avg.div_(a12)  # in-place to reduce memory
-    normalized = x - avg
-    return normalized
diff --git a/multimolecule/module/model.py b/multimolecule/module/model.py
new file mode 100644
index 00000000..256783be
--- /dev/null
+++ b/multimolecule/module/model.py
@@ -0,0 +1,89 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from chanfig import FlatDict
+from danling import NestedTensor
+from torch import Tensor, nn
+
+from .backbones import BackboneRegistry
+from .heads import HeadRegistry
+from .necks import NeckRegistry
+from .registry import ModelRegistry
+
+
+@ModelRegistry.register(default=True)
+class MultiMoleculeModel(nn.Module):
+    def __init__(
+        self,
+        backbone: dict,
+        heads: dict,
+        neck: dict | None = None,
+        max_length: int = 1024,
+        truncation: bool = False,
+    ):
+        super().__init__()
+
+        # Backbone
+        self.backbone = BackboneRegistry.build(**backbone)
+        backbone = self.backbone.config
+        out_channels = self.backbone.out_channels
+
+        # Neck
+        if neck:
+            num_discrete = self.backbone.num_discrete
+            num_continuous = self.backbone.num_continuous
+            embed_dim = self.backbone.sequence.config.hidden_size
+            attention_heads = self.backbone.sequence.config.num_attention_heads
+            neck.update(
+                {
+                    "num_discrete": num_discrete,
+                    "num_continuous": num_continuous,
+                    "embed_dim": embed_dim,
+                    "attention_heads": attention_heads,
+                    "max_length": max_length,
+                    "truncation": truncation,
+                }
+            )
+            self.neck = NeckRegistry.build(**neck)
+            out_channels = self.neck.out_channels
+        else:
+            self.neck = None
+
+        # Heads
+        for head in heads.values():
+            if "hidden_size" not in head or head["hidden_size"] is None:
+                head["hidden_size"] = out_channels
+        self.heads = nn.ModuleDict({name: HeadRegistry.build(backbone, head) for name, head in heads.items()})
+        if any(getattr(h, "requires_attention", False) for h in self.heads.values()):
+            self.backbone.sequence.config.output_attentions = True
+
+    def forward(
+        self,
+        sequence: NestedTensor | Tensor,
+        discrete: Tensor | None = None,
+        continuous: Tensor | None = None,
+        dataset: str | None = None,
+        **labels: NestedTensor | Tensor,
+    ) -> FlatDict:
+        ret = FlatDict()
+        output, _ = self.backbone(sequence, discrete, continuous)
+        if self.neck is not None:
+            output = self.neck(**output)
+        for task, label in labels.items():
+            ret[task] = self.heads[task](output, input_ids=sequence, labels=label)
+        return ret
diff --git a/multimolecule/module/necks/__init__.py b/multimolecule/module/necks/__init__.py
new file mode 100644
index 00000000..e8f1f7e2
--- /dev/null
+++ b/multimolecule/module/necks/__init__.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from .bert import BERTNeck
+from .cat import CatNeck
+from .registry import NeckRegistry
+
+__all__ = ["NeckRegistry", "CatNeck", "BERTNeck"]
diff --git a/multimolecule/module/necks/bert.py b/multimolecule/module/necks/bert.py
new file mode 100644
index 00000000..1360f0dd
--- /dev/null
+++ b/multimolecule/module/necks/bert.py
@@ -0,0 +1,102 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from danling.modules import TransformerEncoder, TransformerEncoderLayer
+from torch import Tensor, nn
+
+from .registry import NeckRegistry
+
+MAX_LENGTH = 1024
+
+
+@NeckRegistry.register("bert")
+class BERTNeck(nn.Module):
+    def __init__(  # pylint: disable=keyword-arg-before-vararg
+        self,
+        num_discrete: int,
+        num_continuous: int,
+        embed_dim: int,
+        attention_heads: int,
+        num_layers: int = 6,
+        max_length: int | None = None,
+        truncation: bool = False,
+        dropout: float = 0.1,
+        *args,
+        **kwargs,
+    ):
+        super().__init__()
+        self.cls_token_dis = nn.Parameter(torch.zeros(embed_dim))
+        self.cls_token_con = nn.Parameter(torch.zeros(embed_dim))
+        if max_length is None:
+            if truncation:
+                max_length = MAX_LENGTH + 1 + num_discrete + 1 + num_continuous
+            else:
+                max_length = MAX_LENGTH * 4 + 1 + num_discrete + 1 + num_continuous
+        self.max_length = max_length
+        self.pos_embed = nn.Parameter(torch.zeros(1, self.max_length, embed_dim))
+        bert_layer = TransformerEncoderLayer(
+            embed_dim, attention_heads, *args, dropout=dropout, attn_dropout=dropout, ffn_dropout=dropout, **kwargs
+        )
+        self.bert = TransformerEncoder(bert_layer, num_layers)
+        self.out_channels = embed_dim
+        nn.init.normal_(self.pos_embed, std=0.02)
+        nn.init.trunc_normal_(self.cls_token_dis, std=0.2)
+        nn.init.trunc_normal_(self.cls_token_con, std=0.2)
+
+    def forward(
+        self,
+        cls_token: Tensor | None = None,
+        all_tokens: Tensor | None = None,
+        discrete: Tensor | None = None,
+        continuous: Tensor | None = None,
+    ) -> FlatDict:
+        ret = FlatDict()
+        if cls_token is not None:
+            ret["cls_token"] = self._forward(cls_token, discrete, continuous)
+        if all_tokens is not None:
+            ret["all_tokens"] = self._forward(all_tokens, discrete, continuous)
+        return ret
+
+    def _forward(
+        self,
+        sequence: Tensor,
+        discrete: Tensor | None = None,
+        continuous: Tensor | None = None,
+    ) -> Tensor:
+        if sequence is None:
+            raise ValueError("sequence should not be None.")
+        if sequence.dim() == 2:
+            sequence = sequence[:, None]
+        batch_size, seq_len, _ = sequence.shape
+        output = sequence
+        if discrete is not None:
+            cls_token_dis = self.cls_token_dis.expand(batch_size, 1, -1)
+            output = torch.cat((output, cls_token_dis, discrete), dim=1)
+        if continuous is not None:
+            cls_token_con = self.cls_token_con.expand(batch_size, -1)[:, None]
+            output = torch.cat((output, cls_token_con, continuous), dim=1)
+        all_len = output.shape[1]
+        if all_len > self.pos_embed.shape[1]:
+            raise ValueError("sequence length is out of range.")
+        output = output + self.pos_embed[:, 0:all_len, :]
+        output = self.bert(output)[0][:, 0:seq_len, :]
+        if seq_len == 1:
+            output = output.squeeze(1)
+        return output
diff --git a/multimolecule/module/necks/cat.py b/multimolecule/module/necks/cat.py
new file mode 100644
index 00000000..d5165a92
--- /dev/null
+++ b/multimolecule/module/necks/cat.py
@@ -0,0 +1,43 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import torch
+from chanfig import FlatDict
+from torch import Tensor
+
+from .registry import NeckRegistry
+
+
+@NeckRegistry.register("cat")
+class CatNeck:  # pylint: disable=too-few-public-methods
+    def __init__(self, embed_dim: int):
+        self.out_channels = embed_dim * 2
+
+    def __call__(
+        self,
+        cls_token: Tensor | None = None,
+        all_tokens: Tensor | None = None,
+        discrete: Tensor | None = None,
+        continuous: Tensor | None = None,
+    ) -> FlatDict:
+        ret = FlatDict()
+        if cls_token is not None:
+            ret.cls_token = torch.cat(tuple(i for i in (cls_token, discrete, continuous) if i is not None), -1)
+        if all_tokens is not None:
+            ret.all_tokens = torch.cat(tuple(i for i in (all_tokens, discrete, continuous) if i is not None), -1)
+        return ret
diff --git a/multimolecule/module/necks/registry.py b/multimolecule/module/necks/registry.py
new file mode 100644
index 00000000..c024227c
--- /dev/null
+++ b/multimolecule/module/necks/registry.py
@@ -0,0 +1,21 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from chanfig import Registry
+
+NeckRegistry = Registry()
diff --git a/multimolecule/module/registry.py b/multimolecule/module/registry.py
new file mode 100644
index 00000000..b0332463
--- /dev/null
+++ b/multimolecule/module/registry.py
@@ -0,0 +1,35 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from chanfig import Registry as Registry_
+from torch import nn
+
+from .backbones import BackboneRegistry
+from .backbones.sequences import SequenceRegistry
+from .heads import HeadRegistry
+from .necks import NeckRegistry
+
+
+class Registry(Registry_):  # pylint: disable=too-few-public-methods
+    def build(self, *args, **kwargs) -> nn.Module:
+        return super().build(*args, **kwargs)
+
+
+ModelRegistry = Registry()
+
+__all__ = ["ModelRegistry", "BackboneRegistry", "SequenceRegistry", "NeckRegistry", "HeadRegistry"]
diff --git a/multimolecule/runners/README.md b/multimolecule/runners/README.md
new file mode 100644
index 00000000..bb1000ad
--- /dev/null
+++ b/multimolecule/runners/README.md
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# runners
+
+`runners` provide an easy-to-use interface for running experiments.
diff --git a/multimolecule/runners/__init__.py b/multimolecule/runners/__init__.py
new file mode 100644
index 00000000..70fa4076
--- /dev/null
+++ b/multimolecule/runners/__init__.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from .config import MultiMoleculeConfig
+from .runner import MultiMoleculeRunner
+
+__all__ = ["MultiMoleculeConfig", "MultiMoleculeRunner"]
diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py
new file mode 100644
index 00000000..1e02d96e
--- /dev/null
+++ b/multimolecule/runners/base_runner.py
@@ -0,0 +1,286 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import math
+import os
+from functools import cached_property, partial
+from typing import Any, Tuple
+from warnings import warn
+
+import danling as dl
+import torch
+from art import text2art
+from chanfig import NestedDict
+from danling import MultiTaskMetrics
+from datasets import disable_progress_bars, get_dataset_split_names
+from torch import nn, optim
+from torch.nn import functional as F
+from torch.utils import data
+from transformers import AutoTokenizer
+
+from multimolecule import defaults
+from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
+from multimolecule.module import HeadConfig, ModelRegistry
+
+from .config import MultiMoleculeConfig
+from .metrics import MetricRegistry
+
+disable_progress_bars()
+
+
+class BaseRunner(dl.BaseRunner):
+
+    all_datasets: NestedDict
+
+    def __init__(self, config: MultiMoleculeConfig):
+        if config.art:
+            print(text2art("MultiMolecule", "rand-large"))
+        super().__init__(config)
+        self.name = config.name
+        self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
+        self.build_datasets()
+        self.build_dataloaders()
+        self.model = ModelRegistry.build(**self.network)
+        if self.config.get("checkpoint"):
+            ckpt = dl.load(self.config.checkpoint)
+            model = ckpt.get("model", ckpt)
+            parameters = self.model.load_state_dict(model, strict=False)
+            if parameters.missing_keys:
+                raise ValueError(f"Missing keys in model: {parameters.missing_keys}")
+            if parameters.unexpected_keys:
+                warn(f"Unexpected keys in model: {parameters.unexpected_keys}")
+        self.model = self.model.to(self.device)
+        if self.distributed:
+            self.model = nn.parallel.DistributedDataParallel(
+                self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
+            )
+        if self.config.training:
+            if self.config.optim and not (self.config.platform == "deepspeed" and self.config.deepspeed.optimizer):
+                self.optimizer = getattr(optim, self.config.optim.pop("name"))(
+                    params=self.model.parameters(), **self.config.optim
+                )
+            if self.config.sched and not (self.config.platform == "deepspeed" and self.config.deepspeed.scheduler):
+                self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
+        self.metrics = self.build_metrics()
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.yaml(os.path.join(self.dir, "trainer.yaml"))
+        print(self)
+        print(self.get_dataset_lengths())
+
+    def train_step(self, data) -> Tuple[Any, torch.Tensor]:
+        with self.autocast(), self.accumulate():
+            pred = self.model(**data)
+            loss = self.loss_fn(pred, data)
+            self.advance(loss)
+            self.metric_fn(pred, data)
+        return pred, loss
+
+    def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
+        pred = self.model(**data)
+        loss = self.loss_fn(pred, data)
+        self.metric_fn(pred, data)
+        return pred, loss
+
+    def loss_fn(self, pred, data):
+        if self.balance == "rlw":
+            loss = torch.stack([p["loss"] for p in pred.values()])
+            weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1)
+            return loss.T @ weight
+        if self.balance == "gls":
+            return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred))
+        if self.balance != "ew":
+            warn(f"Unknown balance method {self.balance}, using equal weighting.")
+        return sum(p["loss"] for p in pred.values()) / len(pred)
+
+    def metric_fn(self, pred, data):
+        metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
+        metric.update({t: (p["logits"], data[t]) for t, p in pred.items()})
+
+    @cached_property
+    def tasks(self):
+        if not self.datasets:
+            raise ValueError("No datasets found")
+        if "train" in self.datasets:
+            return self.datasets.train.tasks
+        return next(iter(self.datasets.values())).tasks
+
+    @cached_property
+    def dataset_tasks(self):
+        if not self.datasets:
+            raise ValueError("No datasets found")
+        dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values()))
+        tasks = self.tasks
+        dataset_tasks = dataset.dataset_tasks if isinstance(dataset, MultiTaskDataset) else dataset.tasks
+        for dataset in self.datasets.values():
+            if isinstance(dataset, MultiTaskDataset):
+                for dataset_name, tasks_ in dataset.dataset_tasks.items():
+                    for task_name, task_ in tasks_.items():
+                        if task_name not in tasks:
+                            raise ValueError(f"Task {task_name} of dataset {dataset_name} is not defined")
+                        task = tasks[task_name]
+                        if task != task_:
+                            warn(
+                                f"Task {task_name} of dataset {dataset_name} has different configurations "
+                                "compared to training data, using training configuration.\n"
+                                "This may lead to unexpected behavior.",
+                            )
+                        if dataset_name not in dataset_tasks:
+                            dataset_tasks[dataset_name] = NestedDict()
+                        if task_name not in dataset_tasks[dataset_name]:
+                            dataset_tasks[dataset_name][task_name] = task
+            else:
+                for task_name, task_ in dataset.tasks.items():
+                    if task_name not in tasks:
+                        raise ValueError(f"Task {task_name} is not defined")
+                    task = tasks[task_name]
+                    if task != task_:
+                        warn(
+                            f"Task {task_name} has different configurations "
+                            "compared to training data, using training configuration.\n"
+                            "This may lead to unexpected behavior.",
+                        )
+                    if task_name not in dataset_tasks:
+                        dataset_tasks[task_name] = task
+        return dataset_tasks
+
+    @cached_property
+    def network(self):
+        heads = {
+            name: HeadConfig(num_labels=task.num_labels, problem_type=task.type, type=task.level)
+            for name, task in self.tasks.items()
+        }
+        if "heads" not in self.config.network:
+            self.config.network.heads = NestedDict(heads)
+        else:
+            self.config.network.heads.merge(heads, overwrite=False)
+        return self.config.network
+
+    def build_datasets(self):
+        if "data" in self.config:
+            self.datasets = self.all_datasets = self._build_dataset(self.config.data)
+            return
+        if "datas" in self.config:
+            self.all_datasets = NestedDict(
+                {name: self._build_dataset(config, name) for name, config in self.config.datas.items()}
+            )
+            datasets = {
+                subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict}
+                for subkey in {k for v in self.all_datasets.values() for k in v}
+            }
+            self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+            return
+        raise ValueError("No data configuration found")
+
+    def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
+        name = name or config.root
+        print(f"Building dataset {name}")
+        dataset = NestedDict()
+        ignored_key = defaults.DATASET_SPLITS + ["root"]
+        dataset_factory = partial(
+            Dataset,
+            tokenizer=self.tokenizer,
+            **{k: v for k, v in config.items() if k not in ignored_key},
+        )
+        if os.path.isdir(config.root):
+            if "train" in config:
+                dataset.train = dataset_factory(os.path.join(config.root, config.train), split="train")
+            if "validation" in config:
+                dataset.validation = dataset_factory(os.path.join(config.root, config.validation), split="validation")
+            if "test" in config:
+                dataset.test = dataset_factory(os.path.join(config.root, config.test), split="test")
+            if "evaluation" in config:
+                dataset.evaluation = dataset_factory(os.path.join(config.root, config.evaluation), split="evaluation")
+            if "inference" in config:
+                dataset.inference = dataset_factory(os.path.join(config.root, config.inference), split="inference")
+        else:
+            splits = get_dataset_split_names(config.root)
+            existing_splits = {k for k in defaults.DATASET_SPLITS if config.get(k) is not None}
+            if not existing_splits:
+                if "train" in splits:
+                    config.train = "train"
+                if "validation" in splits:
+                    config.validation = "validation"
+                if "test" in splits:
+                    config.test = "test"
+            if config.get("train") is not None:
+                dataset.train = dataset_factory(config.root, split="train")
+            if config.get("validation") is not None:
+                dataset.validation = dataset_factory(config.root, split="validation")
+            if config.get("test") is not None:
+                dataset.test = dataset_factory(config.root, split="test")
+            if config.get("evaluation") is not None:
+                dataset.evaluation = dataset_factory(config.root, split=config.get("evaluation"))
+            if config.get("inference") is not None:
+                dataset.inference = dataset_factory(config.root, split=config.get("inference"))
+        if not dataset:
+            raise ValueError("No datasets built. This is likely due to missing data paths in Config.")
+        return dataset
+
+    def build_dataloaders(self):
+        datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+        default_kwargs = self.config.get("dataloader", NestedDict())
+        dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+        for k, d in datasets.items():
+            dataloader_kwargs.setdefault(k, NestedDict())
+            dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+            batch_size = dataloader_kwargs[k].pop("batch_size")
+            shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+            drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True))
+            if isinstance(d, MultiTaskDataset):
+                batch_sampler = (
+                    DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+                    if self.distributed
+                    else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+                )
+            else:
+                sampler = (
+                    data.distributed.DistributedSampler(d, shuffle=shuffle)
+                    if self.distributed
+                    else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+                )
+                batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last)
+            self.dataloaders[k] = data.DataLoader(
+                d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+            )
+
+    def build_metrics(self) -> MultiTaskMetrics:
+        return MultiTaskMetrics(
+            {
+                name: MetricRegistry.build(type=task.type, num_labels=task.num_labels)
+                for name, task in self.dataset_tasks.all_items()
+            }
+        )
+
+    def collate_fn(self, batch):
+        return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()}
+
+    def get_dataset_lengths(self) -> str:
+        repr = "dataset lengths:\n"
+        longest_name = max(len(name) for name in self.all_datasets.keys())
+        for name, dataset in self.all_datasets.items():
+            if isinstance(dataset, NestedDict):
+                repr += f"{name}:"
+                if len(name) < longest_name:
+                    repr += " " * (longest_name - len(name))
+                repr += "\t\t"
+                for split, d in dataset.items():
+                    repr += f"  {split}: {len(d)}\t"
+            else:
+                repr += f"{name}: {len(dataset)}\t"
+            repr += "\n"
+        return repr
diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py
new file mode 100644
index 00000000..0edf83a5
--- /dev/null
+++ b/multimolecule/runners/config.py
@@ -0,0 +1,83 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import os
+from pathlib import Path
+from typing import List
+
+from chanfig import Config
+from transformers import PretrainedConfig
+
+
+class DataConfig(Config):
+    root: str = "."
+    train: str | None
+    validation: str | None
+    test: str | None
+    feature_cols: List | None = None
+    label_cols: List | None = None
+    truncation: bool = True
+
+
+class OptimConfig(Config):
+    name: str = "AdamW"
+    lr: float = 1e-3
+    weight_decay: float = 1e-2
+
+
+class MultiMoleculeConfig(Config):
+
+    balance: str = "ew"
+    platform: str = "torch"
+    training: bool = True
+
+    pretrained: str
+    use_pretrained: bool = True
+    transformers: PretrainedConfig
+    epoch_end: int = 20
+
+    data: DataConfig
+
+    tensorboard: bool = True
+    save_interval: int = 10
+    seed: int = 1016
+    art: bool = True
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.datas = Config(default_factory=DataConfig)
+        self.dataloader.batch_size = 32
+        self.optim = OptimConfig()
+        self.sched.final_lr = 0
+
+    def post(self):
+        if "data" in self:
+            if self.datas:
+                raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+            del self.datas
+        self["network.backbone.sequence.name"] = self.pretrained
+        self["network.backbone.sequence.use_pretrained"] = self.use_pretrained
+        pretrained = self.pretrained
+        if os.path.exists(self.pretrained):
+            path = Path(pretrained)
+            if os.path.isfile(pretrained):
+                pretrained = str(path.relative_to(path.parents[1]).with_suffix(""))
+            else:
+                pretrained = path.stem
+
+        self.name = f"{self.pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}"
diff --git a/multimolecule/runners/metrics.py b/multimolecule/runners/metrics.py
new file mode 100644
index 00000000..da584cbc
--- /dev/null
+++ b/multimolecule/runners/metrics.py
@@ -0,0 +1,37 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from chanfig import Registry as Registry_
+from danling.metrics import binary_metrics, multiclass_metrics, multilabel_metrics, regression_metrics
+
+
+class Registry(Registry_):
+
+    def build(self, type, num_labels: int | None = None, **kwargs):
+        if type == "multilabel":
+            return self.init(self.lookup(type), num_labels=num_labels, **kwargs)
+        if type == "multiclass":
+            return self.init(self.lookup(type), num_classes=num_labels, **kwargs)
+        if type == "regression":
+            return self.init(self.lookup(type), num_outputs=num_labels, **kwargs)
+        return self.init(self.lookup(type), **kwargs)
+
+
+MetricRegistry = Registry(key="type")
+MetricRegistry.register(binary_metrics, "binary")
+MetricRegistry.register(multiclass_metrics, "multiclass")
+MetricRegistry.register(multilabel_metrics, "multilabel")
+MetricRegistry.register(regression_metrics, "regression")
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
new file mode 100644
index 00000000..97cab4e7
--- /dev/null
+++ b/multimolecule/runners/runner.py
@@ -0,0 +1,42 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+import danling as dl
+
+from .base_runner import BaseRunner
+
+
+class MultiMoleculeRunner(type):
+    def __new__(cls, config):
+        if config.get("platform", "torch") == "torch":
+            return TorchRunner(config)
+        if config.platform == "deepspeed":
+            return DeepSpeedRunner(config)
+        if config.platform == "accelerate":
+            return AccelerateRunner(config)
+        raise ValueError(f"Unsupported platform: {config.platform}")
+
+
+class TorchRunner(BaseRunner, dl.TorchRunner):
+    pass
+
+
+class DeepSpeedRunner(BaseRunner, dl.DeepSpeedRunner):
+    pass
+
+
+class AccelerateRunner(BaseRunner, dl.AccelerateRunner):
+    pass
diff --git a/multimolecule/tasks/task.py b/multimolecule/tasks/task.py
index e2473ab0..5d435f83 100644
--- a/multimolecule/tasks/task.py
+++ b/multimolecule/tasks/task.py
@@ -34,9 +34,8 @@ class TaskType(StrEnum):
 
 class TaskLevel(StrEnum):
     Sequence = auto()
-    Nucleotide = auto()
+    Token = auto()
     Contact = auto()
-    # Token = auto()
 
 
 @dataclass
diff --git a/multimolecule/train.py b/multimolecule/train.py
new file mode 100644
index 00000000..f146bc97
--- /dev/null
+++ b/multimolecule/train.py
@@ -0,0 +1,20 @@
+# MultiMolecule
+# Copyright (C) 2024-Present  MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+
+from .apis import train
+
+if __name__ == "__main__":
+    train()
diff --git a/pyproject.toml b/pyproject.toml
index 972b1d8a..ee4c525d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,9 +45,11 @@ dynamic = [
 ]
 dependencies = [
   "accelerate",
+  "art",
   "chanfig>=0.0.105",
-  "danling>=0.3.11",
+  "danling>=0.4.0b1",
   "datasets",
+  'StrEnum; python_version < "3.11"',
   "torch",
   "transformers",
 ]
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 5ddfb3c4..318e289e 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -126,7 +126,7 @@ def test_spliceai(self, preprocess: bool):
             feature_cols=feature_cols,
             label_cols=label_cols,
         )
-        task = Task(type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1)
+        task = Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1)
         elem = dataset[0]
         assert isinstance(elem["sequence"], torch.LongTensor)
         assert isinstance(elem["splice_ai"], torch.LongTensor)
@@ -175,20 +175,18 @@ def test_rna_task_recognition_json(self):
         assert dataset.tasks["sequence_regression"] == Task(
             type=TaskType.Regression, level=TaskLevel.Sequence, num_labels=1
         )
-        assert dataset.tasks["nucleotide_binary"] == Task(
-            type=TaskType.Binary, level=TaskLevel.Nucleotide, num_labels=1
-        )
+        assert dataset.tasks["nucleotide_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Token, num_labels=1)
         assert dataset.tasks["nucleotide_multiclass"] == Task(
-            type=TaskType.MultiClass, level=TaskLevel.Nucleotide, num_labels=5
+            type=TaskType.MultiClass, level=TaskLevel.Token, num_labels=5
         )
         assert dataset.tasks["nucleotide_multilabel"] == Task(
-            type=TaskType.MultiLabel, level=TaskLevel.Nucleotide, num_labels=5
+            type=TaskType.MultiLabel, level=TaskLevel.Token, num_labels=5
         )
         assert dataset.tasks["nucleotide_multireg"] == Task(
-            type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=5
+            type=TaskType.Regression, level=TaskLevel.Token, num_labels=5
         )
         assert dataset.tasks["nucleotide_regression"] == Task(
-            type=TaskType.Regression, level=TaskLevel.Nucleotide, num_labels=1
+            type=TaskType.Regression, level=TaskLevel.Token, num_labels=1
         )
         assert dataset.tasks["contact_binary"] == Task(type=TaskType.Binary, level=TaskLevel.Contact, num_labels=1)
         assert dataset.tasks["contact_multiclass"] == Task(