From 1b175fc6b82281b4f874bd7a95ca86720b025577 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Thu, 2 Jan 2025 18:10:03 +0800
Subject: [PATCH] improve MTL

Signed-off-by: Zhiyuan Chen
---
 multimolecule/module/criterions/__init__.py |  18 ++
 multimolecule/module/criterions/balancer.py | 196 ++++++++++++++++++++
 multimolecule/module/model.py               |  11 +-
 multimolecule/runners/base_runner.py        |  22 +--
 4 files changed, 228 insertions(+), 19 deletions(-)
 create mode 100644 multimolecule/module/criterions/balancer.py

diff --git a/multimolecule/module/criterions/__init__.py b/multimolecule/module/criterions/__init__.py
index ddb8f37e..ca9a8648 100644
--- a/multimolecule/module/criterions/__init__.py
+++ b/multimolecule/module/criterions/__init__.py
@@ -20,6 +20,16 @@
 # .


+from .balancer import (
+    DynamicWeightAverageBalancer,
+    EqualWeightBalancer,
+    GeometricLossBalancer,
+    GradNormBalancer,
+    LossBalancer,
+    LossBalancerRegistry,
+    RandomLossWeightBalancer,
+    UncertaintyWeightBalancer,
+)
 from .binary import BCEWithLogitsLoss
 from .generic import Criterion
 from .multiclass import CrossEntropyLoss
@@ -29,9 +39,17 @@

 __all__ = [
     "CriterionRegistry",
+    "LossBalancerRegistry",
     "Criterion",
     "MSELoss",
     "BCEWithLogitsLoss",
     "CrossEntropyLoss",
     "MultiLabelSoftMarginLoss",
+    "LossBalancer",
+    "EqualWeightBalancer",
+    "RandomLossWeightBalancer",
+    "GeometricLossBalancer",
+    "UncertaintyWeightBalancer",
+    "DynamicWeightAverageBalancer",
+    "GradNormBalancer",
 ]
diff --git a/multimolecule/module/criterions/balancer.py b/multimolecule/module/criterions/balancer.py
new file mode 100644
index 00000000..9540d432
--- /dev/null
+++ b/multimolecule/module/criterions/balancer.py
@@ -0,0 +1,196 @@
+# MultiMolecule
+# Copyright (C) 2024-Present MultiMolecule
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# any later version.
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import math
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from chanfig import Registry
+from torch import Tensor
+
+LossBalancerRegistry = Registry()
+
+
+class LossBalancer(nn.Module):
+    """Base class for loss balancers in multi-task learning.
+
+    This class provides an interface for implementing various strategies
+    to balance the losses of different tasks in a multi-task learning setup.
+    """
+
+    def forward(self, ret: Dict[str, Tensor]) -> Dict[str, Tensor]:
+        """Extract the per-task losses to be balanced.
+
+        Args:
+            ret (Dict[str, Tensor]): A dictionary mapping task names to head outputs, each containing a "loss".
+
+        Returns:
+            Dict[str, Tensor]: The per-task losses; subclasses combine them into a single balanced loss.
+        """
+        return {k: v["loss"] for k, v in ret.items()}
+
+
+@LossBalancerRegistry.register("equal", default=True)
+class EqualWeightBalancer(LossBalancer):
+    """Equal Weighting Balancer.
+
+    This method assigns equal weight to each task's loss, effectively averaging the losses across all tasks.
+    """
+
+    def forward(self, ret: Dict[str, Tensor]) -> Tensor:
+        losses = super().forward(ret)
+        return sum(losses.values()) / len(losses)
+
+
+@LossBalancerRegistry.register("random")
+class RandomLossWeightBalancer(LossBalancer):
+    """Random Loss Weighting Balancer.
+
+    This method assigns each task's loss a random weight, obtained by applying a softmax to standard-normal samples,
+    as described in the paper "Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning"
+    by Lin et al. (https://openreview.net/forum?id=jjtFD8A1Wx).
+    """
+
+    def forward(self, ret: Dict[str, Tensor]) -> Tensor:
+        losses = super().forward(ret)
+        loss = torch.stack(list(losses.values()))
+        weight = F.softmax(torch.randn(len(losses), device=loss.device, dtype=loss.dtype), dim=-1)
+        return loss.T @ weight
+
+
+@LossBalancerRegistry.register("geometric")
+class GeometricLossBalancer(LossBalancer):
+    """Geometric Loss Strategy Balancer.
+
+    This method computes the geometric mean of the task losses, which can be useful for balancing tasks with different
+    scales, as described in the paper "MultiNet++: Multi-Stream Feature Aggregation and Geometric Loss Strategy for
+    Multi-Task Learning" by Chennupati et al. (https://arxiv.org/abs/1904.08492).
+    """
+
+    def forward(self, ret: Dict[str, Tensor]) -> Tensor:
+        return math.prod(super().forward(ret).values()) ** (1 / len(ret))
+
+
+@LossBalancerRegistry.register("uncertainty")
+class UncertaintyWeightBalancer(LossBalancer):
+    """Uncertainty Weighting Balancer.
+
+    This method uses task uncertainty to weight the losses, as described in the paper "Multi-Task Learning Using
+    Uncertainty to Weigh Losses for Scene Geometry and Semantics" by Kendall et al. (https://arxiv.org/abs/1705.07115).
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.log_vars = nn.ParameterDict()
+
+    def forward(self, ret: Dict[str, Tensor]) -> Tensor:
+        losses = super().forward(ret)
+        for task in losses.keys():
+            if task not in self.log_vars:
+                self.log_vars[task] = nn.Parameter(torch.zeros(1, device=losses[task].device))
+
+        weighted_losses = [
+            torch.exp(-self.log_vars[task]) * loss + self.log_vars[task] for task, loss in losses.items()
+        ]
+        return sum(weighted_losses) / len(weighted_losses)
+
+
+@LossBalancerRegistry.register("dynamic")
+class DynamicWeightAverageBalancer(LossBalancer):
+    """Dynamic Weight Average Balancer.
+
+    This method dynamically adjusts the weights of task losses based on their relative changes over time, as described
+    in the paper "End-to-End Multi-Task Learning with Attention" by Liu et al. (https://arxiv.org/abs/1803.10704).
+    """
+
+    def __init__(self, *args, temperature: float = 2.0, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.temperature = temperature
+        self.task_losses_history: List[List[float]] = []
+
+    def forward(self, ret: Dict[str, Tensor]) -> Tensor:
+        losses = super().forward(ret)
+        if len(self.task_losses_history) < 2:
+            self.task_losses_history.append([loss.item() for loss in losses.values()])
+            return sum(losses.values()) / len(losses)
+
+        curr_losses = [loss.item() for loss in losses.values()]
+        prev_losses = self.task_losses_history[-1]
+        loss_ratios = [c / (p + 1e-8) for c, p in zip(curr_losses, prev_losses)]
+
+        exp_weights = torch.exp(torch.tensor(loss_ratios) / self.temperature)
+        weights = len(losses) * exp_weights / exp_weights.sum()
+
+        self.task_losses_history.append(curr_losses)
+        if len(self.task_losses_history) > 2:
+            self.task_losses_history.pop(0)
+
+        return sum(w * l for w, l in zip(weights, losses.values())) / len(losses)
+
+
+@LossBalancerRegistry.register("gradnorm")
+class GradNormBalancer(LossBalancer):
+    """GradNorm Balancer.
+
+    This method balances task losses by normalizing gradients, as described in the paper "GradNorm: Gradient
+    Normalization for Adaptive Loss Balancing in Deep Multitask Networks" by Chen et al.
+    (https://arxiv.org/abs/1711.02257).
+    """
+
+    def __init__(self, *args, alpha: float = 1.5, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.alpha = alpha
+        self.task_weights = nn.ParameterDict()
+        self.initial_losses: Dict[str, Tensor] = {}
+
+    def forward(self, ret: Dict[str, Tensor]) -> Tensor:
+        losses = super().forward(ret)
+
+        for task in losses.keys():
+            if task not in self.task_weights:
+                self.task_weights[task] = nn.Parameter(torch.ones(1, device=losses[task].device))
+                self.initial_losses[task] = losses[task].detach()
+
+        loss_ratios = {task: loss / (self.initial_losses[task] + 1e-8) for task, loss in losses.items()}
+        avg_loss_ratio = sum(loss_ratios.values()) / len(loss_ratios)
+
+        relative_inverse_rates = {
+            task: (ratio / (avg_loss_ratio + 1e-8)) ** self.alpha for task, ratio in loss_ratios.items()
+        }
+
+        weighted_losses = {task: self.task_weights[task] * loss for task, loss in losses.items()}
+        grad_norms = {
+            task: torch.norm(torch.autograd.grad(weighted_loss, self.task_weights[task], retain_graph=True)[0])
+            for task, weighted_loss in weighted_losses.items()
+        }
+        mean_grad_norm = sum(grad_norms.values()) / len(grad_norms)
+
+        for task in losses.keys():
+            target_grad = mean_grad_norm * relative_inverse_rates[task]
+            grad_norm = grad_norms[task]
+            self.task_weights[task].data = torch.clamp(
+                self.task_weights[task] * (target_grad / (grad_norm + 1e-8)), min=0.0
+            )
+        weight_sum = sum(w.item() for w in self.task_weights.values())
+        scale = len(losses) / (weight_sum + 1e-8)
+        for task in losses.keys():
+            self.task_weights[task].data *= scale
+
+        return sum(self.task_weights[task] * loss for task, loss in losses.items())
diff --git a/multimolecule/module/model.py b/multimolecule/module/model.py
index 778eb358..db2c9cd6 100644
--- a/multimolecule/module/model.py
+++ b/multimolecule/module/model.py
@@ -27,6 +27,7 @@
 from torch import Tensor, nn

 from .backbones import BackboneRegistry
+from .criterions.balancer import LossBalancerRegistry
 from .heads import HeadRegistry
 from .necks import NeckRegistry
 from .registry import ModelRegistry
@@ -42,10 +43,12 @@
     def __init__(
         self,
         backbone: dict,
         heads: dict,
+        balancer: dict | None = None,
         neck: dict | None = None,
         max_length: int = 1024,
         truncation: bool = False,
         probing: bool = False,
+        config: dict | None = None,
     ):
         super().__init__()
@@ -87,6 +90,8 @@ def __init__(
             for param in self.backbone.parameters():
                 param.requires_grad = False

+        self.balancer = LossBalancerRegistry.build(balancer)
+
     def forward(
         self,
         sequence: NestedTensor | Tensor,
@@ -99,9 +104,13 @@ def forward(
         output, _ = self.backbone(sequence, discrete, continuous)
         if self.neck is not None:
             output = self.neck(**output)
+        if not labels:
+            return output
         for task, label in labels.items():
             ret[task] = self.heads[task](output, input_ids=sequence, labels=label)
-        return ret
+        if len(ret) == 1:
+            return ret, ret[task]["loss"]
+        return ret, self.balancer(ret)

     def trainable_parameters(
         self,
diff --git a/multimolecule/runners/base_runner.py b/multimolecule/runners/base_runner.py
index 16caf8e8..f49b400d 100644
--- a/multimolecule/runners/base_runner.py
+++ b/multimolecule/runners/base_runner.py
@@ -14,7 +14,6 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.

-import math
 import os
 from functools import cached_property, partial
 from typing import Any, Tuple
@@ -28,7 +27,6 @@
 from datasets import disable_progress_bars, get_dataset_split_names
 from lazy_imports import try_import
 from torch import nn
-from torch.nn import functional as F
 from torch.utils import data
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -79,7 +77,6 @@ def __init__(self, config: MultiMoleculeConfig):
             self.config.optim.pretrained_ratio = pretrained_ratio
         if self.config.sched:
             self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.config.sched)
-        self.balance = self.config.get("balance", "ew")

         self.train_metrics = self.build_train_metrics()
         self.evaluate_metrics = self.build_evaluate_metrics()
@@ -99,16 +96,14 @@ def __post_init__(self):

     def train_step(self, data) -> Tuple[Any, torch.Tensor]:
         with self.autocast(), self.accumulate():
-            pred = self.model(**data)
-            loss = self.loss_fn(pred, data)
+            pred, loss = self.model(**data)
             self.advance(loss)
             self.metric_fn(pred, data)
         return pred, loss

     def evaluate_step(self, data) -> Tuple[Any, torch.Tensor]:
         model = self.ema or self.model
-        pred = model(**data)
-        loss = self.loss_fn(pred, data)
+        pred, loss = model(**data)
         self.metric_fn(pred, data)
         return pred, loss

@@ -143,6 +138,8 @@ def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
         model = self.ema or self.model
         for _, data in tqdm(enumerate(loader), total=len(loader)):  # noqa: F402
             pred = model(**data)
+            if isinstance(pred, tuple):
+                pred, loss = pred
             for task, p in pred.items():
                 preds[task].extend(p["logits"].squeeze(-1).tolist())
                 if task in data:
@@ -162,17 +159,6 @@ def infer(self, split: str = "inf") -> NestedDict | FlatDict | list:
             return next(iter(preds.values()))
         return preds

-    def loss_fn(self, pred, data):
-        if self.balance == "rlw":
-            loss = torch.stack([p["loss"] for p in pred.values()])
-            weight = F.softmax(torch.randn(len(pred), device=loss.device, dtype=loss.dtype), dim=-1)
-            return loss.T @ weight
-        if self.balance == "gls":
-            return math.prod(p["loss"] for p in pred.values()) ** (1 / len(pred))
-        if self.balance != "ew":
-            warn(f"Unknown balance method {self.balance}, using equal weighting.")
-        return sum(p["loss"] for p in pred.values()) / len(pred)
-
     def metric_fn(self, pred, data):
         metric = self.metrics[data["dataset"]] if "dataset" in data else self.metrics
         metric.update({t: (p["logits"], data[t]) for t, p in pred.items()})
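
A quick way to sanity-check the new balancers outside the runner is to call them directly on a hand-built `ret` dictionary. The snippet below is a minimal sketch, not part of the patch: the task names and loss values are made up, and it only assumes the convention used by `MultiMoleculeModel.forward`, namely that the balancer receives a mapping from task name to a head output carrying a "loss" entry. In training, the balancer is instead built from the `balancer` config via `LossBalancerRegistry.build`, with `EqualWeightBalancer` registered as the default.

import torch

from multimolecule.module.criterions import EqualWeightBalancer, UncertaintyWeightBalancer

# Hypothetical per-task head outputs, shaped like the `ret` dict that
# MultiMoleculeModel.forward hands to `self.balancer`.
ret = {
    "task_a": {"loss": torch.tensor(0.7)},
    "task_b": {"loss": torch.tensor(1.3)},
}

# Equal weighting (the registry default) simply averages the per-task losses.
print(EqualWeightBalancer()(ret))  # tensor(1.)

# Uncertainty weighting keeps one learnable log-variance per task; the returned
# tensor is differentiable, so the runner can call backward() on it as usual.
balancer = UncertaintyWeightBalancer()
loss = balancer(ret)
loss.backward()
print({task: param.grad for task, param in balancer.log_vars.items()})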