From 3d559911d49bc62ad49d8afae3881baa975c0863 Mon Sep 17 00:00:00 2001
From: vmoens
Date: Mon, 12 Feb 2024 16:05:52 +0000
Subject: [PATCH] init

---
 examples/rlhf/train_rlhf.py |  4 +---
 torchrl/data/rlhf/utils.py  | 12 +++++-------
 2 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/examples/rlhf/train_rlhf.py b/examples/rlhf/train_rlhf.py
index 04eab54b4fa..94d9234db2a 100644
--- a/examples/rlhf/train_rlhf.py
+++ b/examples/rlhf/train_rlhf.py
@@ -100,9 +100,7 @@ def main(cfg):
     # using a Gym-like API (querying steps etc) introduces some
     # extra code that we can spare.
     #
-    kl_scheduler = AdaptiveKLController(
-        model=model, init_kl_coef=0.1, target=6, horizon=10000
-    )
+    kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)
     rollout_from_model = RolloutFromModel(
         model,
         ref_model,
diff --git a/torchrl/data/rlhf/utils.py b/torchrl/data/rlhf/utils.py
index 1daf8d22e79..a4ccbfd8a1b 100644
--- a/torchrl/data/rlhf/utils.py
+++ b/torchrl/data/rlhf/utils.py
@@ -7,7 +7,7 @@
 import abc
 import collections
 import importlib
-from typing import Sequence, Tuple
+from typing import List, Tuple
 
 import numpy as np
 import torch
@@ -30,7 +30,7 @@ class KLControllerBase(abc.ABC):
     """
 
     @abc.abstractmethod
-    def update(self, kl_values: float) -> float:
+    def update(self, kl_values: List[float]) -> float:
         ...
 
 
@@ -63,7 +63,7 @@ def __init__(
         if model is not None:
             self.model.kl_coef = self.coef
 
-    def update(self, kl_values: Sequence[float] = None) -> float:
+    def update(self, kl_values: List[float] = None) -> float:
         if self.model is not None:
             self.model.kl_coef = self.coef
         return self.coef
@@ -104,7 +104,7 @@ def __init__(
         if model is not None:
             self.model.kl_coef = self.coef
 
-    def update(self, kl_values: Sequence[float]):
+    def update(self, kl_values: List[float]):
         """Update ``self.coef`` adaptively.
 
         Arguments:
@@ -256,8 +256,6 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio):
             log_ratio (torch.Tensor): The log ratio of the probabilities of the generated
                 tokens according to the generative model and the reference model. Can
                 be obtained by calling the ``generate`` method.
-            kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting
-                from the reward. Defaults to 0.1.
 
         Returns:
             A :class:`~tensordict.TensorDict` with the following keys:
@@ -537,7 +535,7 @@ def generate(self, batch: PromptData, generation_config=None):
 
     def step_scheduler(self):
         # recover true kl
-        self.kl_scheduler.update(self._kl_queue)
+        self.kl_coef = self.kl_scheduler.update(self._kl_queue)
         if isinstance(self._kl_queue, (list, collections.deque)):
             # remove all values
             while len(self._kl_queue):
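
A minimal usage sketch of the reworked KL-controller API, inferred from the hunks above. It assumes torchrl is installed and that AdaptiveKLController.update() returns the refreshed coefficient, which is what the new assignment in step_scheduler() expects; the kl_estimates values below are hypothetical stand-ins for the queue that RolloutFromModel drains.

    from torchrl.data.rlhf.utils import AdaptiveKLController

    # The controller is now constructed without a model reference; it only
    # tracks the KL coefficient and adapts it towards `target` over `horizon`.
    kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)

    # update() takes a list of per-batch KL estimates and returns the new
    # coefficient, which the caller is now responsible for storing
    # (RolloutFromModel.step_scheduler() assigns it to its own kl_coef).
    kl_estimates = [0.8, 1.2, 0.9]  # hypothetical values
    new_coef = kl_scheduler.update(kl_estimates)
    print(new_coef)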