From 62ca7b846cc69f2207c9f1e7143e90f78e7f6e0e Mon Sep 17 00:00:00 2001
From: 4332001876
Date: Sat, 8 Feb 2025 22:25:58 +0800
Subject: [PATCH 1/4] implement REINFORCE++ algorithm

---
 verl/trainer/ppo/core_algos.py  | 34 ++++++++++++++++++++++++++++++++++
 verl/trainer/ppo/ray_trainer.py | 13 +++++++++++++
 2 files changed, 47 insertions(+)

diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 4869793d..c3c327ff 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -154,6 +154,40 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
     return scores, scores
 
 
+def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor,
+                                                  eos_mask: torch.Tensor,
+                                                  gamma: torch.Tensor):
+    """
+    Compute advantage for REINFORCE++.
+    Args:
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+
+    Returns:
+        advantages: `(torch.Tensor)`
+            shape: (bs, response_length)
+        Returns: `(torch.Tensor)`
+            shape: (bs, response_length)
+    """
+
+    with torch.no_grad():
+        returns = torch.zeros_like(token_level_rewards)
+        running_return = 0
+
+        for t in reversed(range(token_level_rewards.shape[1])):
+            running_return = token_level_rewards[:, t] + gamma * running_return
+            returns[:, t] = running_return
+            # Reset after EOS
+            running_return = running_return * eos_mask[:, t]
+
+        advantages = verl_F.masked_whiten(returns, eos_mask)
+        advantages = advantages * eos_mask
+
+    return advantages, returns
+
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     kl = old_log_prob - ref_log_prob
     return token_level_scores - kl * kl_ratio
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 5145845e..8f9bd029 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -144,6 +144,17 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
                                                                         index=index)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
+    elif adv_estimator == 'reinforce_plus_plus':
+        token_level_rewards = data.batch['token_level_rewards']
+        responses = data.batch['responses']
+        response_length = responses.size(-1)
+        attention_mask = data.batch['attention_mask']
+        response_mask = attention_mask[:, -response_length:]
+        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(token_level_rewards=token_level_rewards,
+                                                                                       eos_mask=response_mask,
+                                                                                       gamma=gamma)
+        data.batch['advantages'] = advantages
+        data.batch['returns'] = returns
     else:
         raise NotImplementedError
     return data
@@ -343,6 +354,8 @@ def __init__(self,
             self.use_critic = True
         elif self.config.algorithm.adv_estimator == 'grpo':
             self.use_critic = False
+        elif self.config.algorithm.adv_estimator == 'reinforce_plus_plus':
+            self.use_critic = False
         else:
             raise NotImplementedError
 

From 20fbf5c825802b5c8fcb577fb6c4f2809f85c165 Mon Sep 17 00:00:00 2001
From: 4332001876
Date: Sun, 9 Feb 2025 15:47:29 +0800
Subject: [PATCH 2/4] add citation for R++; formatting

---
 verl/trainer/ppo/core_algos.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index c3c327ff..85c847d2 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -158,7 +158,8 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
                                                   eos_mask: torch.Tensor,
                                                   gamma: torch.Tensor):
     """
-    Compute advantage for REINFORCE++. 
+    Compute advantage for REINFORCE++.
+    This implementation is based on the paper: https://arxiv.org/abs/2501.03262
     Args:
         token_level_rewards: `(torch.Tensor)`
             shape: (bs, response_length)

From 83b0cd4612f399055e333f1d8fe073c1d89cf10e Mon Sep 17 00:00:00 2001
From: 4332001876
Date: Sun, 9 Feb 2025 16:32:26 +0800
Subject: [PATCH 3/4] formatting

---
 verl/trainer/ppo/core_algos.py  | 9 ++++-----
 verl/trainer/ppo/ray_trainer.py | 5 ++---
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 85c847d2..6f5fee01 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -154,9 +154,8 @@ def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
     return scores, scores
 
 
-def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor,
-                                                  eos_mask: torch.Tensor,
-                                                  gamma: torch.Tensor):
+def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Tensor, eos_mask: torch.Tensor,
+                                                  gamma: torch.Tensor):
     """
     Compute advantage for REINFORCE++.
     This implementation is based on the paper: https://arxiv.org/abs/2501.03262
@@ -176,7 +175,7 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
     with torch.no_grad():
         returns = torch.zeros_like(token_level_rewards)
         running_return = 0
-        
+
         for t in reversed(range(token_level_rewards.shape[1])):
             running_return = token_level_rewards[:, t] + gamma * running_return
             returns[:, t] = running_return
@@ -185,7 +184,7 @@ def compute_reinforce_plus_plus_outcome_advantage(token_level_rewards: torch.Ten
 
         advantages = verl_F.masked_whiten(returns, eos_mask)
         advantages = advantages * eos_mask
-        
+
     return advantages, returns
 
 
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 8f9bd029..e0092c08 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -150,9 +150,8 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
         response_length = responses.size(-1)
         attention_mask = data.batch['attention_mask']
         response_mask = attention_mask[:, -response_length:]
-        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(token_level_rewards=token_level_rewards,
-                                                                                       eos_mask=response_mask,
-                                                                                       gamma=gamma)
+        advantages, returns = core_algos.compute_reinforce_plus_plus_outcome_advantage(
+            token_level_rewards=token_level_rewards, eos_mask=response_mask, gamma=gamma)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
     else:
         raise NotImplementedError

From 099aee0266cec1f919249703ae81e9ad56cb7461 Mon Sep 17 00:00:00 2001
From: 4332001876
Date: Sun, 9 Feb 2025 17:28:27 +0800
Subject: [PATCH 4/4] update related document

---
 docs/examples/config.rst | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/docs/examples/config.rst b/docs/examples/config.rst
index cc503580..947d4252 100644
--- a/docs/examples/config.rst
+++ b/docs/examples/config.rst
@@ -324,9 +324,8 @@ Algorithm
 
 - ``gemma``: discount factor
 - ``lam``: Trade-off between bias and variance in the GAE estimator
-- ``adv_estimator``: gae. Currently only supports gae, will support GRPO
-  in the future
-- ``kl_penalty``\ :Support ``kl``, ``abs``, ``mse`` and ``full``.How to
+- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``.
+- ``kl_penalty``: Support ``kl``, ``abs``, ``mse`` and ``full``. How to
   calculate the kl divergence between actor and reference policy. For
   specific options, refer to `core_algos.py `_ .
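
Reviewer note (illustrative, not part of the patches): the snippet below is a self-contained sketch of what the new estimator computes, written so it can be run outside verl. The masked_whiten helper is a simplified stand-in for verl_F.masked_whiten, and the toy attention mask, rewards, and shapes are invented for illustration; only the loop body mirrors compute_reinforce_plus_plus_outcome_advantage as added in PATCH 1/4.

import torch


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Zero-mean / unit-variance normalization computed over the masked positions only."""
    mean = (values * mask).sum() / mask.sum()
    var = (((values - mean) * mask)**2).sum() / mask.sum()
    return (values - mean) * torch.rsqrt(var + eps)


def reinforce_plus_plus_outcome_advantage(token_level_rewards, eos_mask, gamma=1.0):
    """Mirrors the loop added to core_algos.py in PATCH 1/4, with a local whitening helper."""
    with torch.no_grad():
        returns = torch.zeros_like(token_level_rewards)
        running_return = 0
        # Accumulate the discounted return right-to-left, resetting past the response end.
        for t in reversed(range(token_level_rewards.shape[1])):
            running_return = token_level_rewards[:, t] + gamma * running_return
            returns[:, t] = running_return
            running_return = running_return * eos_mask[:, t]
        # Whiten over all valid tokens in the batch, then zero out padded positions.
        advantages = masked_whiten(returns, eos_mask) * eos_mask
    return advantages, returns


# Toy batch of 2 sequences: prompt length 4, response length 4, and an outcome-style
# reward placed on the last valid response token.
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 1, 1, 0, 0]], dtype=torch.float32)
token_level_rewards = torch.tensor([[0.0, 0.0, 0.0, 1.0],
                                    [0.0, 0.5, 0.0, 0.0]])
response_length = token_level_rewards.size(-1)
# Same slicing the trainer patch uses to build the eos_mask from the attention mask.
response_mask = attention_mask[:, -response_length:]

advantages, returns = reinforce_plus_plus_outcome_advantage(token_level_rewards, response_mask, gamma=1.0)
print(returns)     # [[1.0, 1.0, 1.0, 1.0], [0.5, 0.5, 0.0, 0.0]]
print(advantages)  # whitened over the six valid tokens, zeros at padded positions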
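Usage follows the existing GRPO path: with the series applied, setting algorithm.adv_estimator=reinforce_plus_plus selects the new estimator inside compute_advantage (PATCH 1/4, ray_trainer.py) and makes RayPPOTrainer run with use_critic = False, so no value model is trained; the batch-level whitening of the discounted returns takes the place of a learned baseline. The updated option list in docs/examples/config.rst (PATCH 4/4) records the new value.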