diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index 22835cca..f859ed89 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -51,6 +51,7 @@ actor_rollout_ref:
     temperature: 1.0
     top_k: -1 # 0 for hf rollout, -1 for vllm rollout
     top_p: 1
+    use_fire_sampling: False # https://arxiv.org/abs/2410.21236
     prompt_length: ${data.max_prompt_length} # not use for opensource
     response_length: ${data.max_response_length}
     # for vllm rollout
diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
index 0dc2de08..e2e958f2 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -118,6 +118,17 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
                 kwargs[k] = config.get(k)

         print(f"kwargs: {kwargs}")
+
+        self.use_fire_sampling = config.get('use_fire_sampling', False)
+        if self.use_fire_sampling:
+            kwargs_0 = kwargs.copy()
+            kwargs_0['temperature'] = 30
+            kwargs_0['max_tokens'] = 1
+            if 'top_k' not in kwargs_0 or kwargs_0['top_k'] <= 0:
+                kwargs_0['top_k'] = 16
+            kwargs['max_tokens'] -= 1
+            self.sampling_params_0 = SamplingParams(**kwargs_0)
+
         self.sampling_params = SamplingParams(**kwargs)

         self.pad_token_id = tokenizer.pad_token_id
@@ -132,11 +143,24 @@ def update_sampling_params(self, **kwargs):
                     old_value = getattr(self.sampling_params, key)
                     old_sampling_params_args[key] = old_value
                     setattr(self.sampling_params, key, value)
+
+        if self.use_fire_sampling:
+            old_sampling_params_args_0 = {}
+            if kwargs:
+                for key, value in kwargs.items():
+                    if hasattr(self.sampling_params_0, key):
+                        old_value = getattr(self.sampling_params_0, key)
+                        old_sampling_params_args_0[key] = old_value
+                        setattr(self.sampling_params_0, key, value)
         yield
         # roll back to previous sampling params
         # if len(old_sampling_params_args):
         for key, value in old_sampling_params_args.items():
             setattr(self.sampling_params, key, value)
+
+        if self.use_fire_sampling:
+            for key, value in old_sampling_params_args_0.items():
+                setattr(self.sampling_params_0, key, value)

     @torch.no_grad()
     def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
@@ -169,16 +193,35 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                 'temperature': 0,
             }

-        # users can customize different sampling_params at different run
-        with self.update_sampling_params(**kwargs):
-            output = self.inference_engine.generate(
-                prompts=None,  # because we have already convert it to prompt token id
-                sampling_params=self.sampling_params,
-                prompt_token_ids=idx_list,
-                use_tqdm=False)
-
-        response = output[0].to(idx.device)  # (bs, response_length)
-        log_probs = output[1].to(idx.device)  # (bs, response_length)
+        if not self.use_fire_sampling:
+            # users can customize different sampling_params at different run
+            with self.update_sampling_params(**kwargs):
+                output = self.inference_engine.generate(
+                    prompts=None,  # because we have already convert it to prompt token id
+                    sampling_params=self.sampling_params,
+                    prompt_token_ids=idx_list,
+                    use_tqdm=False)
+
+            response = output[0].to(idx.device)  # (bs, response_length)
+            log_probs = output[1].to(idx.device)  # (bs, response_length)
+        else:
+            with self.update_sampling_params(**kwargs):
+                output_0 = self.inference_engine.generate(
+                    prompts=None,  # because we have already convert it to prompt token id
+                    sampling_params=self.sampling_params_0,
+                    prompt_token_ids=idx_list,
+                    use_tqdm=False)
+                new_idx_list = []
+                for i in range(batch_size):
+                    new_idx_list.append(idx_list[i] + output_0[0][i].tolist())
+                output = self.inference_engine.generate(
+                    prompts=None,  # because we have already convert it to prompt token id
+                    sampling_params=self.sampling_params,
+                    prompt_token_ids=new_idx_list,
+                    use_tqdm=False)
+
+            response = torch.cat([output_0[0], output[0]], dim=1).to(idx.device)  # (bs, response_length)
+            log_probs = torch.cat([output_0[1], output[1]], dim=1).to(idx.device)  # (bs, response_length)

         if response.shape[1] < self.config.response_length:
             response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
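
For reference, the two-stage scheme this patch enables (FIRE sampling, per the arXiv link above) can be sketched independently of vLLM. The sketch below is illustrative only and not part of the patch: model_logits and sample_token are hypothetical stand-ins for the inference engine, while the constants mirror sampling_params_0 above (temperature 30, top_k 16, max_tokens 1 for the first token). Note the patch also decrements max_tokens on the regular params by 1, so the "hot" first token plus the continuation still fit within response_length.

    # Minimal sketch of FIRE sampling, assuming a hypothetical
    # model_logits(ids) -> torch.Tensor of next-token logits (vocab_size,).
    import torch

    def sample_token(logits: torch.Tensor, temperature: float, top_k: int) -> int:
        """Sample one token id with temperature and an optional top-k cap."""
        if top_k > 0:
            vals, idx = torch.topk(logits, top_k)
            probs = torch.softmax(vals / temperature, dim=-1)
            return int(idx[torch.multinomial(probs, 1)])
        probs = torch.softmax(logits / temperature, dim=-1)
        return int(torch.multinomial(probs, 1))

    def fire_generate(model_logits, prompt_ids, response_length,
                      temperature=1.0, top_k=-1):
        ids = list(prompt_ids)
        # Stage 1: a single token at very high temperature with a top-k cap,
        # mirroring sampling_params_0 (temperature=30, top_k=16, max_tokens=1).
        ids.append(sample_token(model_logits(ids), temperature=30.0, top_k=16))
        # Stage 2: the remaining response_length - 1 tokens with the configured
        # settings, mirroring the second generate() call on new_idx_list.
        for _ in range(response_length - 1):
            ids.append(sample_token(model_logits(ids), temperature=temperature, top_k=top_k))
        return ids[len(prompt_ids):]

The design intent, as reflected in the patch, is to inject extra diversity only at the first response token and then sample the rest of the rollout normally, which is why two SamplingParams objects are kept and updated in lockstep by update_sampling_params.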