diff --git a/docs/docs/data/multitask.md b/docs/docs/data/multitask.md
new file mode 100644
index 00000000..054c6205
--- /dev/null
+++ b/docs/docs/data/multitask.md
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2024-05-04
+---
+
+# MultiTask
+
+::: multimolecule.data.multitask
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 53875f32..036d8c96 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -12,6 +12,7 @@ nav:
   - data:
     - data/index.md
    - Dataset: data/dataset.md
+    - MultiTask: data/multitask.md
   - datasets:
    - datasets/index.md
    - DNA:
diff --git a/multimolecule/data/__init__.py b/multimolecule/data/__init__.py
index 62196c10..f2366d77 100644
--- a/multimolecule/data/__init__.py
+++ b/multimolecule/data/__init__.py
@@ -15,6 +15,13 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
 from .dataset import Dataset
+from .multitask import DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
 from .utils import no_collate
 
-__all__ = ["Dataset", "no_collate"]
+__all__ = [
+    "Dataset",
+    "MultiTaskDataset",
+    "MultiTaskSampler",
+    "DistributedMultiTaskSampler",
+    "no_collate",
+]
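With the re-export above, the new multi-task helpers sit next to the existing ones at the package root. A minimal sketch of the resulting import surface; nothing beyond the names listed in the updated `__all__` is assumed:

    from multimolecule.data import (
        Dataset,
        DistributedMultiTaskSampler,
        MultiTaskDataset,
        MultiTaskSampler,
        no_collate,
    )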
diff --git a/multimolecule/data/multitask.py b/multimolecule/data/multitask.py
new file mode 100644
index 00000000..70a93313
--- /dev/null
+++ b/multimolecule/data/multitask.py
@@ -0,0 +1,191 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+from bisect import bisect_right
+from collections.abc import Mapping, Sequence
+from copy import deepcopy
+from random import choices
+
+from chanfig import NestedDict
+from torch import distributed as dist
+from torch.utils import data
+
+from .dataset import Dataset
+
+
+class MultiTaskDataset(data.ConcatDataset):
+
+    datasets: Mapping
+    dataset_keys: Sequence[str]
+    dataset_values: Sequence[Dataset]
+
+    def __init__(self, datasets: Mapping) -> None:
+        for key, dataset in datasets.items():
+            if not isinstance(dataset, Dataset):
+                raise TypeError(f"Dataset {key} should be an instance of Dataset")
+        self.datasets = datasets
+        if not len(self.datasets) > 0:
+            raise ValueError("MultiTaskDataset should contain at least one dataset")
+        self.dataset_keys, self.dataset_values = zip(*self.datasets.items())
+        self.cumulative_sizes = self.cumsum(self.dataset_values)
+
+    def __getitems__(self, key: Sequence[int]) -> Mapping:
+        dataset_idx = bisect_right(self.cumulative_sizes, key[0])
+        if dataset_idx == 0:
+            sample_idx = key
+        else:
+            sample_idx = [i - self.cumulative_sizes[dataset_idx - 1] for i in key]
+        batch = self.dataset_values[dataset_idx][sample_idx]
+        batch["dataset"] = self.dataset_keys[dataset_idx]
+        return batch
+
+    @property
+    def tasks(self) -> NestedDict:
+        tasks = self.dataset_values[0].tasks
+        for dataset in self.dataset_values[1:]:
+            for n, t in dataset.tasks.items():
+                if n not in tasks:
+                    tasks[n] = t
+                elif tasks[n] != t:
+                    raise ValueError(f"Task {n} has different configurations across datasets")
+        return tasks
+
+    @property
+    def dataset_tasks(self) -> NestedDict:
+        return NestedDict({k: v.tasks for k, v in self.datasets.items()})
+
+    def __repr__(self) -> str:
+        return f"MultiTaskDataset({', '.join([str(d) for d in self.datasets])})"
+
+
+class MultiTaskSampler(data.BatchSampler):
+    r"""
+    Ensure that all items in a batch come from the same dataset.
+
+    Arguments:
+        dataset (MultiTaskDataset): The dataset to draw batches from.
+        batch_size (int): Size of mini-batch.
+        drop_last (bool): If ``True``, the sampler will drop the last batch if
+            its size would be less than ``batch_size``.
+    """
+
+    datasets: Sequence[Dataset]
+
+    def __init__(  # pylint: disable=super-init-not-called
+        self,
+        dataset: MultiTaskDataset,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last: bool = False,
+        sampler_cls: type[data.Sampler] | None = None,
+        weights: list[int] | None = None,
+    ) -> None:
+        self.datasets = dataset.dataset_values
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+        if sampler_cls is None:
+            sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
+        self.samplers = [sampler_cls(d) for d in self.datasets]  # type: ignore
+        self.dataset_sizes = [len(d) for d in self.datasets]  # type: ignore
+        self.cumulative_sizes = dataset.cumulative_sizes
+        self.num_datasets = len(self.datasets)
+        self.weights = weights if weights is not None else self.dataset_sizes
+
+    def __iter__(self):
+        sampler_iters = [(i, iter(s)) for i, s in enumerate(self.samplers)]
+        sampler_weights = deepcopy(self.weights)
+        sampler_idx = 0
+        # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
+        if self.drop_last:
+            while sampler_iters:
+                if self.shuffle:
+                    sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+                sampler_id, sampler_iter = sampler_iters[sampler_idx]
+                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+                try:
+                    batch = [next(sampler_iter) + cumulative_size for _ in range(self.batch_size)]
+                    yield batch
+                except StopIteration:
+                    sampler_iters.pop(sampler_idx)
+                    sampler_weights.pop(sampler_idx)
+        else:
+            while sampler_iters:
+                if self.shuffle:
+                    sampler_idx = choices(range(len(sampler_iters)), weights=sampler_weights)[0]
+                sampler_id, sampler_iter = sampler_iters[sampler_idx]
+                cumulative_size = self.cumulative_sizes[sampler_id - 1] if sampler_id > 0 else 0
+                batch = [0] * self.batch_size
+                idx_in_batch = 0
+                try:
+                    for _ in range(self.batch_size):
+                        batch[idx_in_batch] = next(sampler_iter) + cumulative_size
+                        idx_in_batch += 1
+                    yield batch
+                    idx_in_batch = 0  # noqa: SIM113
+                    batch = [0] * self.batch_size
+                except StopIteration:
+                    sampler_iters.pop(sampler_idx)
+                    sampler_weights.pop(sampler_idx)
+                if idx_in_batch > 0:
+                    yield batch[:idx_in_batch]
+
+    def __len__(self):
+        batch_size = self.batch_size
+        if self.drop_last:
+            return sum(len(d) // batch_size for d in self.datasets)
+        return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+
+class DistributedMultiTaskSampler(MultiTaskSampler):  # pylint: disable=too-few-public-methods
+    r"""
+    Distributed version of MultiTaskSampler that shards each dataset across processes
+    while still ensuring that each batch contains data from only one dataset.
+
+    See Also:
+        [MultiTaskSampler][MultiTaskSampler]
+    """
+
+    def __init__(
+        self,
+        dataset: MultiTaskDataset,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last: bool = False,
+        sampler_cls: type[data.Sampler] = data.RandomSampler,
+        weights: list[int] | None = None,
+    ) -> None:
+        super().__init__(dataset, batch_size, shuffle, drop_last, sampler_cls, weights)
+        self.samplers = [data.DistributedSampler(d, shuffle=shuffle, drop_last=drop_last) for d in self.datasets]
+
+    def set_epoch(self, epoch):
+        for s in self.samplers:
+            s.set_epoch(epoch)
+
+    def __len__(self):
+        batch_size = self.batch_size * self.world_size
+        if self.drop_last:
+            return sum(len(d) // batch_size for d in self.datasets)
+        return sum((len(d) + batch_size - 1) // batch_size for d in self.datasets)
+
+    @property
+    def world_size(self) -> int:
+        r"""Return the number of processes in the current process group."""
+        if dist.is_available() and dist.is_initialized():
+            return dist.get_world_size()
+        return 1
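Taken together, the new module wires up as follows: a MultiTaskDataset wraps a mapping of per-task Dataset objects, and a MultiTaskSampler (or its distributed variant) draws every index of a batch from a single one of those datasets, so __getitems__ can tag the batch with the originating dataset's key. A rough usage sketch under stated assumptions: the file paths and the "multimolecule/rnafm" checkpoint are placeholders, and no_collate is assumed to be the identity collate exported for batches that __getitems__ has already collated:

    from torch.utils import data
    from transformers import AutoTokenizer
    from multimolecule.data import Dataset, MultiTaskDataset, MultiTaskSampler, no_collate

    # hypothetical per-task datasets; substitute real paths and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("multimolecule/rnafm")
    datasets = MultiTaskDataset({
        "task_a": Dataset("data/task_a/train.csv", split="train", tokenizer=tokenizer),
        "task_b": Dataset("data/task_b/train.csv", split="train", tokenizer=tokenizer),
    })

    # MultiTaskSampler is a BatchSampler, so it is passed as batch_sampler=
    sampler = MultiTaskSampler(datasets, batch_size=32, shuffle=True)
    loader = data.DataLoader(datasets, batch_sampler=sampler, collate_fn=no_collate)

    for batch in loader:
        task_name = batch["dataset"]  # key of the dataset this batch was drawn from

This mirrors how build_dataloaders in the runner below constructs its loaders, substituting the distributed variant when the run is distributed.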
diff --git a/multimolecule/runners/config.py b/multimolecule/runners/config.py
index 34675789..784c5d86 100644
--- a/multimolecule/runners/config.py
+++ b/multimolecule/runners/config.py
@@ -30,6 +30,7 @@ class DataConfig(Config):
     test: str | None
     feature_cols: List | None = None
     label_cols: List | None = None
+    truncation: bool = True
 
     def post(self):
         if "train" in self:
@@ -40,6 +41,12 @@ def post(self):
             self.test = os.path.join(self.root, self.test)
 
 
+class OptimConfig(Config):
+    name: str = "AdamW"
+    lr: float = 1e-3
+    weight_decay: float = 1e-2
+
+
 class MultiMoleculeConfig(Config):
     pretrained: str
 
@@ -50,16 +57,19 @@ class MultiMoleculeConfig(Config):
     save_interval: int = 10
     seed: int = 1013
+    data: DataConfig
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.data = DataConfig()
+        self.datas = Config(default_factory=DataConfig)
         self.dataloader.batch_size = 32
-        self.optim.name = "AdamW"
-        self.optim.lr = 1e-3
-        self.optim.weight_decay = 1e-2
+        self.optim = OptimConfig()
         self.sched.final_lr = 0
 
     def post(self):
+        if "data" in self:
+            if self.datas:
+                raise ValueError("Only one of `data` or `datas` can be specified, but not both")
+            del self.datas
         self.network.backbone.sequence.name = self.pretrained
         self.name = f"{self.pretrained}-{self.optim.lr}@{self.optim.name}-{self.seed}"
 
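The config changes above make dataset configuration take one of two mutually exclusive forms: a single `data` entry, or a `datas` mapping with one DataConfig per named dataset, which the runner below turns into MultiTaskDatasets per split. A hedged sketch of the multi-dataset form, assuming chanfig's attribute-style assignment and using made-up names, paths, and checkpoint:

    from multimolecule.runners.config import MultiMoleculeConfig

    config = MultiMoleculeConfig()
    config.pretrained = "multimolecule/rnafm"  # placeholder checkpoint

    # one DataConfig per named dataset under `datas`
    # (a single-dataset run would populate `config.data` instead and leave `datas` empty)
    config.datas.task_a.root = "data/task_a"
    config.datas.task_a.train = "train.csv"
    config.datas.task_b.root = "data/task_b"
    config.datas.task_b.train = "train.csv"

post() enforces the mutual exclusion: it raises if both `data` and `datas` are populated, and drops the unused `datas` container when only `data` is given.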
diff --git a/multimolecule/runners/runner.py b/multimolecule/runners/runner.py
index dcceb77f..a192c5b5 100644
--- a/multimolecule/runners/runner.py
+++ b/multimolecule/runners/runner.py
@@ -21,10 +21,11 @@ import torch
 from chanfig import NestedDict
 from danling import MultiTaskMetrics, TorchRunner
-from torch import optim
+from torch import nn, optim
+from torch.utils import data
 from transformers import AutoTokenizer
 
-from multimolecule.data import Dataset
+from multimolecule.data import Dataset, DistributedMultiTaskSampler, MultiTaskDataset, MultiTaskSampler
 from multimolecule.module import HeadConfig, ModelRegistry
 
 from .config import MultiMoleculeConfig
 
@@ -33,24 +34,30 @@ class MultiMoleculeRunner(TorchRunner):
+    all_datasets: NestedDict
+
     def __init__(self, config: MultiMoleculeConfig):
         super().__init__(config)
         self.name = config.name
         self.tokenizer = AutoTokenizer.from_pretrained(self.config.pretrained)
-        self.datasets = self.build_datasets()
-        self.model = ModelRegistry.build(**self.network)
+        self.build_datasets()
+        self.build_dataloaders()
+        self.model = ModelRegistry.build(**self.network).to(self.device)
+        if self.distributed:
+            self.model = nn.parallel.DistributedDataParallel(
+                self.model, find_unused_parameters=True, bucket_cap_mb=32, gradient_as_bucket_view=True
+            )
         self.optimizer = getattr(optim, self.config.optim.pop("name"))(
             params=self.model.parameters(), **self.config.optim
         )
+        self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
         self.metrics = self.build_metrics()
 
     def __post_init__(self):
-        if self.datasets:
-            self.build_dataloaders()
-        self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
         super().__post_init__()
         self.yaml(os.path.join(self.dir, "trainer.yaml"))
         print(self)
+        print(self.get_dataset_lengths())
 
     def train_step(self, data) -> torch.Tensor:
         with self.autocast(), self.accumulate():
@@ -70,7 +77,8 @@ def loss_fn(self, pred, data):
         return sum(p["loss"] for p in pred.values())
 
     def metric_fn(self, pred, data):
-        self.metrics.update({task: (pred["logits"], data[task]) for task, pred in pred.items()})
+        metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
+        metric.update({task: (pred["logits"], data[task]) for task, pred in pred.items()})
 
     @cached_property
     def tasks(self):
@@ -89,30 +97,100 @@ def network(self):
         self.config.network.heads.merge(heads, overwrite=False)
         return self.config.network
 
-    def build_datasets(self) -> NestedDict:
-        datasets = NestedDict()
+    def build_datasets(self):
+        if "data" in self.config:
+            self.datasets = self.all_datasets = self._build_dataset(self.config.data)
+            return
+        if "datas" in self.config:
+            self.all_datasets = NestedDict(
+                {name: self._build_dataset(config, name) for name, config in self.config.datas.items()}
+            )
+            datasets = {
+                subkey: {key: subdict[subkey] for key, subdict in self.all_datasets.items() if subkey in subdict}
+                for subkey in {k for v in self.all_datasets.values() for k in v}
+            }
+            self.datasets = NestedDict({split: MultiTaskDataset(datas) for split, datas in datasets.items()})
+            return
+        raise ValueError("No data configuration found")
+
+    def _build_dataset(self, config: NestedDict, name: str | None = None) -> NestedDict:
+        name = name or config.root
+        print(f"Building dataset {name}")
+        dataset = NestedDict()
         dataset_factory = partial(
             Dataset,
             tokenizer=self.tokenizer,
-            **{k: v for k, v in self.config.data.items() if k not in ("train", "val", "test", "root")},
+            **{k: v for k, v in config.items() if k not in ("train", "val", "test", "root")},
         )
-        if self.config.data.train:
-            datasets.train = dataset_factory(self.config.data.train, split="train")
-        if self.config.data.val:
-            datasets.val = dataset_factory(self.config.data.val, split="val")
-        if self.config.data.test:
-            datasets.test = dataset_factory(self.config.data.test, split="test")
-        if not datasets:
+        if "train" in config:
+            dataset.train = dataset_factory(config.train, split="train")
+        if "val" in config:
+            dataset.val = dataset_factory(config.val, split="val")
+        if "test" in config:
+            dataset.test = dataset_factory(config.test, split="test")
+        if not dataset:
             raise ValueError("No datasets built. This is likely due to missing data paths in Config.")
-        return datasets
+        return dataset
+
+    def build_dataloaders(self):
+        datasets = {k: d for k, d in self.datasets.items() if k not in self.dataloaders}
+        default_kwargs = self.config.get("dataloader", NestedDict())
+        dataloader_kwargs = NestedDict({k: default_kwargs.pop(k) for k in self.datasets if k in default_kwargs})
+        for k, d in datasets.items():
+            dataloader_kwargs.setdefault(k, NestedDict())
+            dataloader_kwargs[k].merge(default_kwargs, overwrite=False)
+            batch_size = dataloader_kwargs[k].pop("batch_size")
+            shuffle = dataloader_kwargs[k].pop("shuffle", getattr(d, "train", True))
+            drop_last = dataloader_kwargs[k].pop("drop_last", not getattr(d, "train", True))
+            if isinstance(d, MultiTaskDataset):
+                batch_sampler = (
+                    DistributedMultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+                    if self.distributed
+                    else MultiTaskSampler(d, batch_size, shuffle=shuffle, drop_last=drop_last)
+                )
+            else:
+                sampler = (
+                    data.distributed.DistributedSampler(d, shuffle=shuffle)
+                    if self.distributed
+                    else data.RandomSampler(d) if shuffle else data.SequentialSampler(d)
+                )
+                batch_sampler = data.BatchSampler(sampler, batch_size, drop_last=drop_last)
+            self.dataloaders[k] = data.DataLoader(
+                d, batch_sampler=batch_sampler, collate_fn=self.collate_fn, **dataloader_kwargs[k]
+            )
 
     def build_metrics(self) -> MultiTaskMetrics:
         return MultiTaskMetrics(
             {
                 name: MetricRegistry.build(type=task.type, num_labels=task.num_labels)
-                for name, task in self.tasks.items()
+                for name, task in self.dataset_tasks.all_items()
             }
         )
 
     def collate_fn(self, batch):
-        return {k: v.to(self.device) for k, v in batch.items()}
+        return {k: v.to(self.device) if hasattr(v, "to") else v for k, v in batch.items()}
+
+    @cached_property
+    def dataset_tasks(self):
+        if not self.datasets:
+            raise ValueError("No datasets found")
+        dataset = self.datasets.train if "train" in self.datasets else next(iter(self.datasets.values()))
+        if isinstance(dataset, MultiTaskDataset):
+            return dataset.dataset_tasks
+        return dataset.tasks
+
+    def get_dataset_lengths(self) -> str:
+        repr = "dataset lengths:\n"
+        longest_name = max(len(name) for name in self.all_datasets.keys())
+        for name, dataset in self.all_datasets.items():
+            if isinstance(dataset, NestedDict):
+                repr += f"{name}:"
+                if len(name) < longest_name:
+                    repr += " " * (longest_name - len(name))
+                repr += "\t\t"
+                for split, d in dataset.items():
+                    repr += f" {split}: {len(d)}\t"
+            else:
+                repr += f"{name}: {len(dataset)}\t"
+            repr += "\n"
+        return repr
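For orientation, the runner pieces combine roughly like this at training time: build_dataloaders hands each split a batch sampler that keeps batches homogeneous, loss_fn sums the per-task losses, and metric_fn routes updates to the metric group named by the batch's "dataset" key when one is present. A condensed sketch of a driver loop; this is not the actual danling TorchRunner loop, and `runner` is assumed to be an already-constructed MultiMoleculeRunner:

    from multimolecule.data import DistributedMultiTaskSampler

    num_epochs = 10  # arbitrary
    for epoch in range(num_epochs):
        batch_sampler = runner.dataloaders["train"].batch_sampler
        if isinstance(batch_sampler, DistributedMultiTaskSampler):
            # re-seed the per-dataset DistributedSamplers so shuffling differs across epochs
            batch_sampler.set_epoch(epoch)
        for batch in runner.dataloaders["train"]:
            loss = runner.train_step(batch)

Since collate_fn already moves tensors to self.device, batches arrive on the right device, and get_dataset_lengths prints a per-dataset summary when the runner starts.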