FIRE sampling added. #58

Open · wants to merge 2 commits into base: main
1 change: 1 addition & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -51,6 +51,7 @@ actor_rollout_ref:
temperature: 1.0
top_k: -1 # 0 for hf rollout, -1 for vllm rollout
top_p: 1
use_fire_sampling: False # https://arxiv.org/abs/2410.21236
prompt_length: ${data.max_prompt_length} # not used for opensource
response_length: ${data.max_response_length}
# for vllm rollout
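The new flag defaults to False, so existing configs keep their current behavior; setting it to True enables FIRE sampling (https://arxiv.org/abs/2410.21236), which re-samples only the first response token at a very high temperature, as the rollout changes below show. A minimal sketch of toggling the flag, assuming OmegaConf (which backs the DictConfig the rollout worker receives); the dotlist-override style is an illustration, not part of this PR:

from omegaconf import OmegaConf

# Illustrative slice of ppo_trainer.yaml; only use_fire_sampling is new.
base = OmegaConf.create({
    "actor_rollout_ref": {
        "rollout": {"temperature": 1.0, "top_k": -1, "top_p": 1,
                    "use_fire_sampling": False}
    }
})

# Dotlist override, similar in spirit to a command-line config override.
override = OmegaConf.from_dotlist(["actor_rollout_ref.rollout.use_fire_sampling=true"])
cfg = OmegaConf.merge(base, override)

rollout_cfg = cfg.actor_rollout_ref.rollout
print(rollout_cfg.get("use_fire_sampling", False))  # True; absent key would fall back to False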
63 changes: 53 additions & 10 deletions verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -118,6 +118,17 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
                kwargs[k] = config.get(k)

        print(f"kwargs: {kwargs}")

        self.use_fire_sampling = config.get('use_fire_sampling', False)
        if self.use_fire_sampling:
            kwargs_0 = kwargs.copy()
            kwargs_0['temperature'] = 30
            kwargs_0['max_tokens'] = 1
            if 'top_k' not in kwargs_0 or kwargs_0['top_k'] <= 0:
                kwargs_0['top_k'] = 16
            kwargs['max_tokens'] -= 1
            self.sampling_params_0 = SamplingParams(**kwargs_0)

        self.sampling_params = SamplingParams(**kwargs)

        self.pad_token_id = tokenizer.pad_token_id
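For intuition about the numbers above: temperature 30 flattens the next-token distribution, and the top_k fallback keeps the hot draw among plausible candidates, so the single FIRE token is diverse without being arbitrary. A self-contained torch sketch with made-up logits (a toy 10-token vocabulary and k=5 here; the PR itself uses top_k=16 over the full vocabulary):

import torch

# Toy next-token logits for a 10-token vocabulary (made-up numbers).
logits = torch.tensor([8.0, 6.5, 6.0, 3.0, 2.0, 1.0, 0.5, 0.0, -1.0, -2.0])

def top_k_probs(logits, temperature, k):
    # Keep the k highest logits, mask the rest, then apply temperature scaling.
    kth = torch.topk(logits, k).values[-1]
    masked = torch.where(logits >= kth, logits, torch.tensor(float('-inf')))
    return torch.softmax(masked / temperature, dim=-1)

print(top_k_probs(logits, temperature=1.0, k=10))  # peaked: the top token dominates
print(top_k_probs(logits, temperature=30.0, k=5))  # near-uniform over the 5 kept tokens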
@@ -132,11 +143,24 @@ def update_sampling_params(self, **kwargs):
                    old_value = getattr(self.sampling_params, key)
                    old_sampling_params_args[key] = old_value
                    setattr(self.sampling_params, key, value)

        if self.use_fire_sampling:
            old_sampling_params_args_0 = {}
            if kwargs:
                for key, value in kwargs.items():
                    if hasattr(self.sampling_params_0, key):
                        old_value = getattr(self.sampling_params_0, key)
                        old_sampling_params_args_0[key] = old_value
                        setattr(self.sampling_params_0, key, value)
        yield
        # roll back to previous sampling params
        # if len(old_sampling_params_args):
        for key, value in old_sampling_params_args.items():
            setattr(self.sampling_params, key, value)

        if self.use_fire_sampling:
            for key, value in old_sampling_params_args_0.items():
                setattr(self.sampling_params_0, key, value)

    @torch.no_grad()
    def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
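With FIRE enabled, both parameter sets now have to be snapshotted and restored around the yield. A self-contained sketch of the same override-and-restore pattern, using SimpleNamespace stand-ins for the two SamplingParams objects (values are illustrative; the try/finally is a small hardening that the diff itself does not add):

from contextlib import contextmanager
from types import SimpleNamespace

sampling_params = SimpleNamespace(temperature=1.0, max_tokens=511)  # regular pass
sampling_params_0 = SimpleNamespace(temperature=30, max_tokens=1)   # FIRE pass

@contextmanager
def update_sampling_params(**kwargs):
    # Save the old value of every attribute about to be overridden, on both objects.
    saved = [(p, k, getattr(p, k))
             for p in (sampling_params, sampling_params_0)
             for k in kwargs if hasattr(p, k)]
    for p, k, _ in saved:
        setattr(p, k, kwargs[k])
    try:
        yield
    finally:
        # Roll back to the previous values when the with-block exits.
        for p, k, v in saved:
            setattr(p, k, v)

with update_sampling_params(temperature=0):  # e.g. a greedy validation rollout
    assert sampling_params.temperature == 0
    assert sampling_params_0.temperature == 0
assert sampling_params.temperature == 1.0 and sampling_params_0.temperature == 30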
@@ -169,16 +193,35 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
                'temperature': 0,
            }

        # users can customize different sampling_params at different run
        with self.update_sampling_params(**kwargs):
            output = self.inference_engine.generate(
                prompts=None,  # because we have already converted it to prompt token ids
                sampling_params=self.sampling_params,
                prompt_token_ids=idx_list,
                use_tqdm=False)

        response = output[0].to(idx.device)  # (bs, response_length)
        log_probs = output[1].to(idx.device)  # (bs, response_length)
        if not self.use_fire_sampling:
            # users can customize different sampling_params at different run
            with self.update_sampling_params(**kwargs):
                output = self.inference_engine.generate(
                    prompts=None,  # because we have already converted it to prompt token ids
                    sampling_params=self.sampling_params,
                    prompt_token_ids=idx_list,
                    use_tqdm=False)

            response = output[0].to(idx.device)  # (bs, response_length)
            log_probs = output[1].to(idx.device)  # (bs, response_length)
        else:
            with self.update_sampling_params(**kwargs):
                output_0 = self.inference_engine.generate(
                    prompts=None,  # because we have already converted it to prompt token ids
                    sampling_params=self.sampling_params_0,
                    prompt_token_ids=idx_list,
                    use_tqdm=False)
                new_idx_list = []
                for i in range(batch_size):
                    new_idx_list.append(idx_list[i] + output_0[0][i].tolist())
                output = self.inference_engine.generate(
                    prompts=None,  # because we have already converted it to prompt token ids
                    sampling_params=self.sampling_params,
                    prompt_token_ids=new_idx_list,
                    use_tqdm=False)

            response = torch.cat([output_0[0], output[0]], dim=1).to(idx.device)  # (bs, response_length)
            log_probs = torch.cat([output_0[1], output[1]], dim=1).to(idx.device)  # (bs, response_length)

        if response.shape[1] < self.config.response_length:
            response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
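Taken together, the FIRE path issues two generate calls: one hot token per prompt, then the regular continuation from each prompt with that token appended, with responses and log-probs concatenated so downstream shapes are unchanged. A hedged sketch of the same two-pass flow with stock vLLM (verl's patched engine returns padded tensors, while plain LLM.generate returns RequestOutput objects; the model name, token ids, and budgets are illustrative, and the legacy prompt_token_ids keyword matches the call style in this diff):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")                 # small illustrative model
prompt_token_ids = [[2, 100, 101], [2, 200, 201]]    # toy batch of two tokenized prompts

hot = SamplingParams(temperature=30, top_k=16, max_tokens=1)  # FIRE pass: a single token
regular = SamplingParams(temperature=1.0, max_tokens=15)      # remaining response budget

# Pass 1: sample one flaming-hot token for every prompt.
out_0 = llm.generate(prompts=None, sampling_params=hot,
                     prompt_token_ids=prompt_token_ids, use_tqdm=False)
first_tokens = [list(o.outputs[0].token_ids) for o in out_0]

# Pass 2: continue each prompt with its hot token appended, under regular settings.
new_prompt_token_ids = [p + t for p, t in zip(prompt_token_ids, first_tokens)]
out_1 = llm.generate(prompts=None, sampling_params=regular,
                     prompt_token_ids=new_prompt_token_ids, use_tqdm=False)

# The full response is the hot token followed by the regular continuation,
# mirroring the torch.cat along the sequence dimension in the diff above.
responses = [t + list(o.outputs[0].token_ids) for t, o in zip(first_tokens, out_1)]
print([len(r) for r in responses])  # each at most 1 + 15 = 16 tokens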