diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md
new file mode 100644
index 00000000..054c6205
--- /dev/null
+++ b/docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
+---
+authors:
+ - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiTask
+
+::: multimolecule.data.multitask
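+
+## Usage
+
+A minimal sketch of the intended workflow. The `Dataset` calls mirror how the training
+runner in this PR builds datasets (a data path plus `tokenizer` and `split`); the paths
+and the checkpoint name below are placeholders:
+
+```python
+from torch.utils import data
+from transformers import AutoTokenizer
+
+from multimolecule.data import Dataset, MultiTaskDataset, MultiTaskSampler, no_collate
+
+tokenizer = AutoTokenizer.from_pretrained("multimolecule/rnafm")
+datasets = {
+    "task-a": Dataset("path/to/task-a/train.csv", tokenizer=tokenizer, split="train"),
+    "task-b": Dataset("path/to/task-b/train.csv", tokenizer=tokenizer, split="train"),
+}
+dataset = MultiTaskDataset(datasets)
+sampler = MultiTaskSampler(dataset, batch_size=32, shuffle=True)
+# Batches returned by MultiTaskDataset are already collated, so pass them through unchanged.
+loader = data.DataLoader(dataset, batch_sampler=sampler, collate_fn=no_collate)
+```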
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 53875f32..036d8c96 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -12,6 +12,7 @@ nav:
- data:
- data/index.md
- Dataset: data/dataset.md
+ - MultiTask: data/multitask.md
- datasets:
- datasets/index.md
- DNA:
diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py
index 62196c10..f2366d77 100644
--- a/multimolecule/data/__init__.py
+++ b/multimolecule/data/__init__.py
@@ -15,6 +15,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
from .utils import no_collate
-__all__ = ["Dataset", "no_collate"]
+__all__ = [
+ "Dataset",
+ "PandasDataset",
+ "MultiTaskDataset",
+ "MultiTaskSampler",
+ "DistributedMultiTaskSampler",
+ "no_collate",
+]
diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py
new file mode 100644
index 00000000..70a93313
--- /dev/null
+++ b/multimolecule/data/multitask.py
@@ -0,0 +1,191 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from bisect import bisect_right
+from collections.abc import Mapping, Sequence
+from copy import deepcopy
+from random import choices
+
+from chanfig import NestedDict
+from torch import distributed as dist
+from torch.utils import data
+
+from .dataset import Dataset
+
+
+class MultiTaskDataset(data.ConcatDataset):
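+    r"""
+    Dataset that wraps a mapping of named `Dataset` objects and concatenates them,
+    following `torch.utils.data.ConcatDataset` indexing.
+
+    Batches fetched through `__getitems__` additionally carry a ``dataset`` key that
+    names the dataset the batch was drawn from.
+
+    Arguments:
+        datasets (Mapping): Mapping from dataset name to `Dataset`.
+    """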
+
+ datasets: Mapping
+ dataset_keys: Sequence[str]
+ dataset_values: Sequence[Dataset]
+
+    def __init__(self, datasets: Mapping) -> None:
+        if not datasets:
+            raise ValueError("MultiTaskDataset should contain at least one dataset")
+        for key, dataset in datasets.items():
+            if not isinstance(dataset, Dataset):
+                raise TypeError(f"Dataset {key} should be an instance of Dataset")
+        self.datasets = datasets
+        self.dataset_keys, self.dataset_values = zip(*self.datasets.items())
+        self.cumulative_sizes = self.cumsum(self.dataset_values)
+
+ def __getitems__(self, key: Sequence[int]) -> Mapping:
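+        # The batch sampler guarantees that every index in `key` belongs to the same
+        # dataset, so the first index alone identifies that dataset.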
+ dataset_idx = bisect_right(self.cumulative_sizes, key[0])
+ if dataset_idx == 0:
+ sample_idx = key
+ else:
+ sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key]
+ batch = self.dataset_values[dataset_idx][sample_idx]
+ batch["dataset"] = self.dataset_keys[dataset_idx]
+ return batch
+
+ @property
+ def tasks(self) -> NestedDict:
+ tasks = self.dataset_values[0].tasks
+ for dataset in self.dataset_values[1:]:
+ for n, t in dataset.tasks.items():
+ if n not in tasks:
+ tasks[n] = t
+ elif tasks[n] != t:
+ raise ValueError(f"Task {n} has different configurations across datasets")
+ return tasks
+
+ @property
+ def dataset_tasks(self) -> NestedDict:
+ return NestedDict({k: v.tasks for k, v in self.datasets.items()})
+
+ def __repr__(self) -> str:
+ return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"
+
+
+class MultiTaskSampler(data.BatchSampler):
+ r"""
+ Ensure all items in a batch comes from the same dataset.
+
+ Arguments:
+ sampler (Sampler): Base sampler.
+ batch_size (int): Size of mini-batch.
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
+ its size would be less than ``batch_size``
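+
+    Examples:
+        A minimal sketch (illustrative only; ``rna`` and ``dna`` stand for pre-built
+        ``Dataset`` instances and are not defined here). ``weights=[3, 1]`` draws
+        batches from ``rna`` roughly three times as often as from ``dna``:
+
+        >>> dataset = MultiTaskDataset({"rna": rna, "dna": dna})  # doctest: +SKIP
+        >>> sampler = MultiTaskSampler(dataset, batch_size=32, weights=[3, 1])  # doctest: +SKIP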
+ """
+
+ datasets: Sequence[Dataset]
+
+ def __init__( # pylint: disable=super-init-not-called
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] | None = None,
+ weights: list[int] | None = None,
+ ) -> None:
+ self.datasets = dataset.dataset_values
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+ self.shuffle = shuffle
+ if sampler_cls is None:
+ sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
+ self.samplers = [sampler_cls(d) for d in self.datasets] # type: ignore
+ self.dataset_sizes = [len(d) for d in self.datasets] # type: ignore
+ self.cumulative_sizes = dataset.cumulative_sizes
+ self.num_datasets = len(self.datasets)
+ self.weights = weights if weights is not None else self.dataset_sizes
+
+ def __iter__(self):
+ sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+ sampler_weights = deepcopy(self.weights)
+ sampler_idx = 0
+ # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
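+        # Each step yields one batch drawn from a single dataset. With shuffling, the
+        # dataset is picked at random in proportion to `weights`; otherwise datasets are
+        # consumed in order. Exhausted samplers are removed until none remain.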
+ if self.drop_last:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ try:
+ batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+ yield batch
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ else:
+ while sampler_iters:
+ if self.shuffle:
+ sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+ sampler_id, sampler_iter = sampler_iters[sampler_idx]
+ cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+ batch = [0] * self.batch_size
+ idx_in_batch = 0
+ try:
+ for _ in range(self.batch_size):
+ batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+ idx_in_batch += 1
+ yield batch
+ idx_in_batch = 0 # noqa: SIM113
+ batch = [0] * self.batch_size
+ except StopIteration:
+ sampler_iters.pop(sampler_idx)
+ sampler_weights.pop(sampler_idx)
+ if idx_in_batch > 0:
+ yield batch[:idx_in_batch]
+
+ def __len__(self):
+ batch_size = self.batch_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+
+class DistributedMultiTaskSampler(MultiTaskSampler): # pylint: disable=too-few-public-methods
+ r"""
+ Distributed version of MultiTaskSampler, which ensures that each batch contains
+ data from only one dataset.
+
+ See Also:
+        [MultiTaskSampler][multimolecule.data.multitask.MultiTaskSampler]
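+
+    Note:
+        As with ``torch.utils.data.DistributedSampler``, call ``set_epoch`` at the
+        beginning of every epoch; otherwise each replica reuses the same shuffling order.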
+ """
+
+ def __init__(
+ self,
+ dataset: MultiTaskDataset,
+ batch_size: int,
+ shuffle: bool = True,
+ drop_last: bool = False,
+ sampler_cls: type[data.Sampler] = data.RandomSampler,
+ weights: list[int] | None = None,
+ ) -> None:
+ super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
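+        # Replace the per-dataset samplers built by the parent class with
+        # DistributedSamplers so each replica only sees its own shard of every dataset.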
+ self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]
+
+ def set_epoch(self, epoch):
+ for s in self.samplers:
+ s.set_epoch(epoch)
+
+ def __len__(self):
+ batch_size = self.batch_size * self.world_size
+ if self.drop_last:
+ return sum(len(d) // batch_size for d in self.datasets)
+ return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+ @property
+ def world_size(self) -> int:
+ r"""Return the number of processes in the current process group."""
+ if dist.is_available() and dist.is_initialized():
+ return dist.get_world_size()
+ return 1
diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py
index 34675789..784c5d86 100644
--- a/multimolecule/runners/config.py
+++ b/multimolecule/runners/config.py
@@ -30,6 +30,7 @@ class DataConfig(Config):
test: str | None
feature_cols: List | None = None
label_cols: List | None = None
+ truncation: bool = True
def post(self):
if "train" in self:
@@ -40,6 +41,12 @@ def post(self):
self.test = os.path.join(self.root, self.test)
+class OptimConfig(Config):
+ name: str = "AdamW"
+ lr: float = 1e-3
+ weight_decay: float = 1e-2
+
+
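+# Example (illustrative only): a multi-dataset run would populate `datas` with one
+# DataConfig per dataset, e.g. from YAML:
+#
+#     datas:
+#       rna:
+#         root: data/rna
+#         train: train.parquet
+#       dna:
+#         root: data/dna
+#         train: train.parquet
+#
+# A single-dataset run keeps using the `data` key; `post` rejects specifying both.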
class MultiMoleculeConfig(Config):
pretrained: str
@@ -50,16 +57,19 @@ class MultiMoleculeConfig(Config):
save_interval: int = 10
seed: int = 1013
+ data: DataConfig
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.data = DataConfig()
+ self.datas = Config(default_factory=DataConfig)
self.dataloader.batch_size = 32
- self.optim.name = "AdamW"
- self.optim.lr = 1e-3
- self.optim.weight_decay = 1e-2
+ self.optim = OptimConfig()
self.sched.final_lr = 0
def post(self):
+ if "data" in self:
+ if self.datas:
+ raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+ del self.datas
self.network.backbone.sequence.name = self.pretrained
self.name = f"{self.pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}"
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
index dcceb77f..a192c5b5 100644
--- a/multimolecule/runners/runner.py
+++ b/multimolecule/runners/runner.py
@@ -21,10 +21,11 @@
import torch
from chanfig import NestedDict
from danling import MultiTaskMetrics, TorchRunner
-from torch import optim
+from torch import nn, optim
+from torch.utils import data
from transformers import AutoTokenizer
-from multimolecule.data import Dataset
+from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
from multimolecule.module import HeadConfig, ModelRegistry
from .config import MultiMoleculeConfig
@@ -33,24 +34,30 @@
class MultiMoleculeRunner(TorchRunner):
+ all_datasets: NestedDict
+
def __init__(self, config: MultiMoleculeConfig):
super().__init__(config)
self.name = config.name
self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
- self.datasets = self.build_datasets()
- self.model = ModelRegistry.build(**self.network)
+ self.build_datasets()
+ self.build_dataloaders()
+ self.model = ModelRegistry.build(**self.network).to(self.device)
+ if self.distributed:
+ self.model = nn.parallel.DistributedDataParallel(
+ self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
+ )
self.optimizer = getattr(optim, self.config.optim.pop("name"))(
params=self.model.parameters(), **self.config.optim
)
+ self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
self.metrics = self.build_metrics()
def __post_init__(self):
- if self.datasets:
- self.build_dataloaders()
- self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
super().__post_init__()
self.yaml(os.path.join(self.dir, "trainer.yaml"))
print(self)
+ print(self.get_dataset_lengths())
def train_step(self, data) -> torch.Tensor:
with self.autocast(), self.accumulate():
@@ -70,7 +77,8 @@ def loss_fn(self, pred, data):
return sum(p["loss"] for p in pred.values())
def metric_fn(self, pred, data):
- self.metrics.update({task: (pred["logits"], data[task]) for task, pred in pred.items()})
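+        # Multi-task batches carry a "dataset" key (set by MultiTaskDataset.__getitems__);
+        # use it to route the update to that dataset's metrics.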
+ metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
+ metric.update({task: (pred["logits"], data[task]) for task, pred in pred.items()})
@cached_property
def tasks(self):
@@ -89,30 +97,100 @@ def network(self):
self.config.network.heads.merge(heads, overwrite=False)
return self.config.network
- def build_datasets(self) -> NestedDict:
- datasets = NestedDict()
+ def build_datasets(self):
+ if "data" in self.config:
+ self.datasets = self.all_datasets = self._build_dataset(self.config.data)
+ return
+ if "datas" in self.config:
+ self.all_datasets = NestedDict(
+ {name: self._build_dataset(config, name) for name, config in self.config.datas.items()}
+ )
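+            # Regroup {dataset: {split: Dataset}} into {split: {dataset: Dataset}} so
+            # that each split can be wrapped in a MultiTaskDataset.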
+ datasets = {
+ subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict}
+ for subkey in {k for v in self.all_datasets.values() for k in v}
+ }
+ self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+ return
+ raise ValueError("No data configuration found")
+
+ def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
+ name = name or config.root
+ print(f"Building dataset {name}")
+ dataset = NestedDict()
dataset_factory = partial(
Dataset,
tokenizer=self.tokenizer,
- **{k: v for k, v in self.config.data.items() if k not in ("train", "val", "test", "root")},
+ **{k: v for k, v in config.items() if k not in ("train", "val", "test", "root")},
)
- if self.config.data.train:
- datasets.train = dataset_factory(self.config.data.train, split="train")
- if self.config.data.val:
- datasets.val = dataset_factory(self.config.data.val, split="val")
- if self.config.data.test:
- datasets.test = dataset_factory(self.config.data.test, split="test")
- if not datasets:
+ if "train" in config:
+ dataset.train = dataset_factory(config.train, split="train")
+ if "val" in config:
+ dataset.val = dataset_factory(config.val, split="val")
+ if "test" in config:
+ dataset.test = dataset_factory(config.test, split="test")
+ if not dataset:
raise ValueError("No datasets built. This is likely due to missing data paths in Config.")
- return datasets
+ return dataset
+
+ def build_dataloaders(self):
+ datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+ default_kwargs = self.config.get("dataloader", NestedDict())
+ dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+ for k, d in datasets.items():
+ dataloader_kwargs.setdefault(k, NestedDict())
+ dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+ batch_size = dataloader_kwargs[k].pop("batch_size")
+ shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+ drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True))
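+            # A MultiTaskDataset needs a batch sampler that keeps every batch within a
+            # single dataset; plain datasets fall back to standard (distributed) samplers.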
+ if isinstance(d, MultiTaskDataset):
+ batch_sampler = (
+ DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ if self.distributed
+ else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+ )
+ else:
+ sampler = (
+ data.distributed.DistributedSampler(d, shuffle=shuffle)
+ if self.distributed
+ else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+ )
+ batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last)
+ self.dataloaders[k] = data.DataLoader(
+ d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+ )
def build_metrics(self) -> MultiTaskMetrics:
return MultiTaskMetrics(
{
name: MetricRegistry.build(type=task.type, num_labels=task.num_labels)
- for name, task in self.tasks.items()
+ for name, task in self.dataset_tasks.all_items()
}
)
def collate_fn(self, batch):
- return {k: v.to(self.device) for k, v in batch.items()}
+ return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()}
+
+ @cached_property
+ def dataset_tasks(self):
+ if not self.datasets:
+ raise ValueError("No datasets found")
+        dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values()))
+ if isinstance(dataset, MultiTaskDataset):
+ return dataset.dataset_tasks
+ return dataset.tasks
+
+    def get_dataset_lengths(self) -> str:
+        text = "dataset lengths:\n"
+        longest_name = max(len(name) for name in self.all_datasets.keys())
+        for name, dataset in self.all_datasets.items():
+            if isinstance(dataset, NestedDict):
+                text += f"{name}:"
+                if len(name) < longest_name:
+                    text += " " * (longest_name - len(name))
+                text += "\t\t"
+                for split, d in dataset.items():
+                    text += f" {split}: {len(d)}\t"
+            else:
+                text += f"{name}: {len(dataset)}\t"
+            text += "\n"
+        return text