Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 12, 2024
1 parent 412bb67 commit 3d55991
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
4 changes: 1 addition & 3 deletions examples/rlhf/train_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def main(cfg):
# using a Gym-like API (querying steps etc) introduces some
# extra code that we can spare.
#
kl_scheduler = AdaptiveKLController(
model=model, init_kl_coef=0.1, target=6, horizon=10000
)
kl_scheduler = AdaptiveKLController(init_kl_coef=0.1, target=6, horizon=10000)
rollout_from_model = RolloutFromModel(
model,
ref_model,
Expand Down
12 changes: 5 additions & 7 deletions torchrl/data/rlhf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import abc
import collections
import importlib
from typing import Sequence, Tuple
from typing import List, Tuple

import numpy as np
import torch
Expand All @@ -30,7 +30,7 @@ class KLControllerBase(abc.ABC):
"""

@abc.abstractmethod
def update(self, kl_values: float) -> float:
def update(self, kl_values: List[float]) -> float:
...


Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
if model is not None:
self.model.kl_coef = self.coef

def update(self, kl_values: Sequence[float] = None) -> float:
def update(self, kl_values: List[float] = None) -> float:
if self.model is not None:
self.model.kl_coef = self.coef
return self.coef
Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
if model is not None:
self.model.kl_coef = self.coef

def update(self, kl_values: Sequence[float]):
def update(self, kl_values: List[float]):
"""Update ``self.coef`` adaptively.
Arguments:
Expand Down Expand Up @@ -256,8 +256,6 @@ def create_rollout_td(self, batch, generated, log_probs, log_ratio):
log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens
according to the generative model and the reference model. Can be
obtained by calling the ``generate`` method.
kl_coef (float, optional): Coefficient with which to multiply the KL term before subtracting
from the reward. Defaults to 0.1.
Returns:
A :class:`~tensordict.TensorDict` with the following keys:
Expand Down Expand Up @@ -537,7 +535,7 @@ def generate(self, batch: PromptData, generation_config=None):

def step_scheduler(self):
# recover true kl
self.kl_scheduler.update(self._kl_queue)
self.kl_coef = self.kl_scheduler.update(self._kl_queue)
if isinstance(self._kl_queue, (list, collections.deque)):
# remove all values
while len(self._kl_queue):
Expand Down

0 comments on commit 3d55991

Please sign in to comment.