implement REINFORCE++ algorithm #228

Merged · 4 commits · Feb 9, 2025
docs/examples/config.rst — 2 additions & 3 deletions
@@ -324,9 +324,8 @@ Algorithm
 - ``gemma``: discount factor
 - ``lam``: Trade-off between bias and variance in the GAE estimator
-- ``adv_estimator``: gae. Currently only supports gae, will support GRPO
-  in the future
-- ``kl_penalty``\ :Support ``kl``, ``abs``, ``mse`` and ``full``.How to
+- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``.
+- ``kl_penalty``: Support ``kl``, ``abs``, ``mse`` and ``full``. How to
   calculate the kl divergence between actor and reference policy. For
   specific options, refer to `core_algos.py <https://github.com/volcengine/verl/blob/main/verl/trainer/ppo/core_algos.py#L192>`_ .

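For readers skimming the option names above: the sketch below illustrates how per-token penalties of this kind are commonly computed from actor and reference log-probabilities. It is an illustration only — the strings 'kl', 'abs' and 'mse' mirror the config values, but verl's actual formulas live in the linked core_algos.py, and 'full' (which needs the complete token distributions rather than sampled log-probs) is left out here.

import torch

def kl_penalty_sketch(logprob: torch.Tensor, ref_logprob: torch.Tensor, kind: str) -> torch.Tensor:
    """Illustrative per-token KL penalties between actor and reference policy.

    logprob / ref_logprob: log-probabilities of the sampled tokens under the actor
    and reference policies, shape (bs, response_length). Not verl's code.
    """
    log_ratio = logprob - ref_logprob           # log(pi_actor / pi_ref) per token
    if kind == 'kl':
        return log_ratio                        # plain log-ratio estimator
    if kind == 'abs':
        return log_ratio.abs()                  # absolute log-ratio
    if kind == 'mse':
        return 0.5 * log_ratio.square()         # squared log-ratio estimator
    raise NotImplementedError(f"'{kind}' requires the full token distributions")
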
verl/trainer/ppo/core_algos.py — 34 additions & 0 deletions
@@ -154,6 +154,40 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
return scores, scores


def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, eos_mask: torch.Tensor,
gamma: torch.Tensor):
"""
Compute advantage for REINFORCE++.
This implementation is based on the paper: https://arxiv.org/abs/2501.03262
    Args:
        token_level_rewards: `(torch.Tensor)`
            shape: (bs, response_length)
        eos_mask: `(torch.Tensor)`
            shape: (bs, response_length); 1 for valid response tokens, 0 for padding
        gamma: `(float)`
            discount factor for the token-level reward-to-go

    Returns:
        advantages: `(torch.Tensor)`
            shape: (bs, response_length)
        returns: `(torch.Tensor)`
            shape: (bs, response_length)
    """

with torch.no_grad():
returns = torch.zeros_like(token_level_rewards)
running_return = 0

for t in reversed(range(token_level_rewards.shape[1])):
running_return = token_level_rewards[:, t] + gamma * running_return
returns[:, t] = running_return
            # Zero the running return at padded positions (eos_mask == 0) so rewards
            # beyond the response do not leak backward across padding
running_return = running_return * eos_mask[:, t]

advantages = verl_F.masked_whiten(returns, eos_mask)
advantages = advantages * eos_mask

return advantages, returns


def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
kl = old_log_prob - ref_log_prob
return token_level_scores - kl * kl_ratio
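
A quick usage sketch for the new compute_reinforce_plus_plus_outcome_advantage above. The toy rewards and mask are made up for illustration, and the snippet assumes verl is importable:

import torch
from verl.trainer.ppo.core_algos import compute_reinforce_plus_plus_outcome_advantage

# Two responses of length 4; the second one has only 3 valid tokens (last position is padding).
token_level_rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                                    [0.0, 0.0, 0.5, 0.0]])
eos_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                         [1.0, 1.0, 1.0, 0.0]])

advantages, returns = compute_reinforce_plus_plus_outcome_advantage(
    token_level_rewards=token_level_rewards, eos_mask=eos_mask, gamma=1.0)

# With gamma=1.0, every valid token receives its sequence's reward-to-go
# ([1, 1, 1, 1] and [0.5, 0.5, 0.5, 0.0]); advantages are those returns
# whitened over the valid positions and zeroed on padding.
print(returns)
print(advantages)
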
verl/trainer/ppo/ray_trainer.py — 12 additions & 0 deletions
@@ -144,6 +144,16 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
index=index)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
elif adv_estimator == 'reinforce_plus_plus':
token_level_rewards = data.batch['token_level_rewards']
responses = data.batch['responses']
response_length = responses.size(-1)
attention_mask = data.batch['attention_mask']
response_mask = attention_mask[:, -response_length:]
advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
else:
raise NotImplementedError
return data
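
The response_mask handed to the estimator as eos_mask is obtained by slicing the last response_length columns of attention_mask, which relies on the batch layout where responses occupy the trailing positions (prompts left-padded, responses right-padded). A toy illustration with made-up values:

import torch

# Hypothetical layout: 3 prompt positions followed by 4 response positions;
# attention_mask marks real tokens with 1 and padding with 0.
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 0],   # 2-token prompt, 3-token response
                               [1, 1, 1, 1, 1, 0, 0]])  # 3-token prompt, 2-token response
response_length = 4

response_mask = attention_mask[:, -response_length:]
print(response_mask)
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]])
# These are exactly the positions where the REINFORCE++ return recursion is allowed to accumulate.
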
@@ -343,6 +353,8 @@ def __init__(self,
self.use_critic = True
elif self.config.algorithm.adv_estimator == 'grpo':
self.use_critic = False
elif self.config.algorithm.adv_estimator == 'reinforce_plus_plus':
self.use_critic = False
else:
raise NotImplementedError

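Taken together, switching to the new estimator is a config choice under the algorithm group documented above, and as the branch above shows it runs without a critic. A hedged sketch of just that group (partial and hypothetical; field names taken from this diff and the docs change, other required trainer fields omitted):

from omegaconf import OmegaConf

# Partial, hypothetical config fragment; a real run needs the full trainer config.
algo_cfg = OmegaConf.create({
    'algorithm': {
        'adv_estimator': 'reinforce_plus_plus',  # instead of 'gae' or 'grpo'
        'gamma': 1.0,                            # discount used by the return recursion
        'kl_penalty': 'kl',                      # per-token KL estimator vs. the reference policy
    }
})
print(algo_cfg.algorithm.adv_estimator)  # -> reinforce_plus_plus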