Skip to content

Commit

Permalink
[BugFix] Make KL-controllers independent of the model (#1903)
Browse files: browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 12, 2024
1 parent 2f9e1ae commit 899af07
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 22 deletions.
3 changes: 3 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,9 @@ efficient sampling.
TokenizedDatasetLoader
create_infinite_iterator
get_dataloader
ConstantKLController
AdaptiveKLController


Utils
-----
Expand Down
4 changes: 1 addition & 3 deletions examples/rlhf/train_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def main(cfg):
# using a Gym-like API (querying steps etc) introduces some
# extra code that we can spare.
#
kl_scheduler = AdaptiveKLController(
model, init_kl_coef=0.1, target=6, horizon=10000
)
kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)
rollout_from_model = RolloutFromModel(
model,
ref_model,
Expand Down
59 changes: 40 additions & 19 deletions torchrl/data/rlhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import abc
import collections
import importlib
from typing import Sequence, Tuple
from typing import List, Tuple

import numpy as np
import torch

from tensordict import TensorDict
from torch import Tensor
from torch import nn, Tensor
from torch.nn import functional as F

from torchrl.data.rlhf.prompt import PromptData
Expand All @@ -30,8 +30,8 @@ class KLControllerBase(abc.ABC):
"""

@abc.abstractmethod
def update(self, kl_values: float):
pass
def update(self, kl_values: List[float]) -> float:
...


class ConstantKLController(KLControllerBase):
Expand All @@ -40,30 +40,39 @@ class ConstantKLController(KLControllerBase):
This controller maintains a fixed coefficient no matter what values it is updated
with.
Arguments:
model: wrapped model that needs to be controlled. Must have attribute 'kl_coef'
Keyword Arguments:
kl_coef (float): The coefficient to multiply KL with when calculating the
reward.
model (nn.Module, optional): wrapped model that needs to be controlled.
Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
be updated in-place.
"""

def __init__(self, model, kl_coef):
def __init__(
self,
*,
kl_coef: float = None,
model: nn.Module | None = None,
):
self.model = model
if not hasattr(model, "kl_coef"):
if model is not None and not hasattr(model, "kl_coef"):
raise AttributeError(
"Model input to ConstantKLController doesn't have attribute 'kl_coef'"
)
self.coef = kl_coef
self.model.kl_coef = self.coef
if model is not None:
self.model.kl_coef = self.coef

def update(self, kl_values: Sequence[float] = None):
self.model.kl_coef = self.coef
def update(self, kl_values: List[float] = None) -> float:
if self.model is not None:
self.model.kl_coef = self.coef
return self.coef


class AdaptiveKLController(KLControllerBase):
"""Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".
Arguments:
model: wrapped model that needs to be controlled. Must have attribute 'kl_coef'
Keyword Arguments:
init_kl_coef (float): The starting value of the coefficient.
target (float): The target KL value. When the observed KL is smaller, the
coefficient is decreased, thereby relaxing the KL penalty in the training
Expand All @@ -72,19 +81,30 @@ class AdaptiveKLController(KLControllerBase):
increased, thereby pulling the model back towards the reference model.
horizon (int): Scaling factor to control how aggressively we update the
coefficient.
model (nn.Module, optional): wrapped model that needs to be controlled.
Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
be updated in-place.
Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
"""

def __init__(self, model, init_kl_coef: float, target: float, horizon: int):
def __init__(
self,
*,
init_kl_coef: float,
target: float,
horizon: int,
model: nn.Module | None = None,
):
self.model = model
self.coef = init_kl_coef
self.target = target
self.horizon = horizon
self.model.kl_coef = self.coef
if model is not None:
self.model.kl_coef = self.coef

def update(self, kl_values: Sequence[float]):
def update(self, kl_values: List[float]):
"""Update ``self.coef`` adaptively.
Arguments:
Expand All @@ -104,6 +124,9 @@ def update(self, kl_values: Sequence[float]):
proportional_error = np.clip(kl_value / self.target - 1, -0.2, 0.2) # ϵₜ
mult = 1 + proportional_error * n_steps / self.horizon
self.coef *= mult # βₜ₊₁
if self.model is not None:
self.model.kl_coef = self.coef
return self.coef


class RolloutFromModel:
Expand Down Expand Up @@ -233,8 +256,6 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio):
log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens
according to the generative model and the reference model. Can be
obtained by calling the ``generate`` method.
kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting
from the reward. Defaults to 0.1.
Returns:
A :class:`~tensordict.TensorDict` with the following keys:
Expand Down Expand Up @@ -514,7 +535,7 @@ def generate(self, batch: PromptData, generation_config=None):

def step_scheduler(self):
# recover true kl
self.kl_scheduler.update(self._kl_queue)
self.kl_coef = self.kl_scheduler.update(self._kl_queue)
if isinstance(self._kl_queue, (list, collections.deque)):
# remove all values
while len(self._kl_queue):
Expand Down

0 comments on commit 899af07

Please sign in to comment.