From 899af07fc10538af528e30e2caa8a67c18bf8164 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 12 Feb 2024 20:14:24 +0000 Subject: [PATCH] [BugFix] Make KL-controllers independent of the model (#1903) --- docs/source/reference/data.rst | 3 ++ examples/rlhf/train_rlhf.py | 4 +-- torchrl/data/rlhf/utils.py | 59 +++++++++++++++++++++++----------- 3 files changed, 44 insertions(+), 22 deletions(-) diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index d426a112b72..47ffb64753b 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -702,6 +702,9 @@ efficient sampling. TokenizedDatasetLoader create_infinite_iterator get_dataloader + ConstantKLController + AdaptiveKLController + Utils ----- diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py index a921e58bad6..94d9234db2a 100644 --- a/examples/rlhf/train_rlhf.py +++ b/examples/rlhf/train_rlhf.py @@ -100,9 +100,7 @@ def main(cfg): # using a Gym-like API (querying steps etc) introduces some # extra code that we can spare. # - kl_scheduler = AdaptiveKLController( - model, init_kl_coef=0.1, target=6, horizon=10000 - ) + kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000) rollout_from_model = RolloutFromModel( model, ref_model, diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py index 311b2584aa5..a4ccbfd8a1b 100644 --- a/torchrl/data/rlhf/utils.py +++ b/torchrl/data/rlhf/utils.py @@ -7,13 +7,13 @@ import abc import collections import importlib -from typing import Sequence, Tuple +from typing import List, Tuple import numpy as np import torch from tensordict import TensorDict -from torch import Tensor +from torch import nn, Tensor from torch.nn import functional as F from torchrl.data.rlhf.prompt import PromptData @@ -30,8 +30,8 @@ class KLControllerBase(abc.ABC): """ @abc.abstractmethod - def update(self, kl_values: float): - pass + def update(self, kl_values: List[float]) -> float: + ... class ConstantKLController(KLControllerBase): @@ -40,30 +40,39 @@ class ConstantKLController(KLControllerBase): This controller maintains a fixed coefficient no matter what values it is updated with. - Arguments: - model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + Keyword Arguments: kl_coef (float): The coefficient to multiply KL with when calculating the reward. + model (nn.Module, optional): wrapped model that needs to be controlled. + Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will + be updated in-place. """ - def __init__(self, model, kl_coef): + def __init__( + self, + *, + kl_coef: float = None, + model: nn.Module | None = None, + ): self.model = model - if not hasattr(model, "kl_coef"): + if model is not None and not hasattr(model, "kl_coef"): raise AttributeError( "Model input to ConstantKLController doesn't have attribute 'kl_coef'" ) self.coef = kl_coef - self.model.kl_coef = self.coef + if model is not None: + self.model.kl_coef = self.coef - def update(self, kl_values: Sequence[float] = None): - self.model.kl_coef = self.coef + def update(self, kl_values: List[float] = None) -> float: + if self.model is not None: + self.model.kl_coef = self.coef + return self.coef class AdaptiveKLController(KLControllerBase): """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences". - Arguments: - model: wrapped model that needs to be controlled. Must have attribute 'kl_coef' + Keyword Arguments: init_kl_coef (float): The starting value of the coefficient. target (float): The target KL value. When the observed KL is smaller, the coefficient is decreased, thereby relaxing the KL penalty in the training @@ -72,19 +81,30 @@ class AdaptiveKLController(KLControllerBase): increased, thereby pulling the model back towards the reference model. horizon (int): Scaling factor to control how aggressively we update the coefficient. + model (nn.Module, optional): wrapped model that needs to be controlled. + Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will + be updated in-place. Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2 Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py """ - def __init__(self, model, init_kl_coef: float, target: float, horizon: int): + def __init__( + self, + *, + init_kl_coef: float, + target: float, + horizon: int, + model: nn.Module | None = None, + ): self.model = model self.coef = init_kl_coef self.target = target self.horizon = horizon - self.model.kl_coef = self.coef + if model is not None: + self.model.kl_coef = self.coef - def update(self, kl_values: Sequence[float]): + def update(self, kl_values: List[float]): """Update ``self.coef`` adaptively. Arguments: @@ -104,6 +124,9 @@ def update(self, kl_values: Sequence[float]): proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ mult = 1 + proportional_error * n_steps / self.horizon self.coef *= mult # βₜ₊₁ + if self.model is not None: + self.model.kl_coef = self.coef + return self.coef class RolloutFromModel: @@ -233,8 +256,6 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio): log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens according to the generative model and the reference model. Can be obtained by calling the ``generate`` method. - kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting - from the reward. Defaults to 0.1. Returns: A :class:`~tensordict.TensorDict` with the following keys: @@ -514,7 +535,7 @@ def generate(self, batch: PromptData, generation_config=None): def step_scheduler(self): # recover true kl - self.kl_scheduler.update(self._kl_queue) + self.kl_coef = self.kl_scheduler.update(self._kl_queue) if isinstance(self._kl_queue, (list, collections.deque)): # remove all values while len(self._kl_queue):