Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Dec 3, 2024
1 parent 73f1984 commit 369e2ea
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
16 changes: 11 additions & 5 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,17 @@ def _validate(self):

test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
test_gen_batch.meta_info = {
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'recompute_log_prob': False,
'do_sample': False,
'validate': True,
'eos_token_id':
self.tokenizer.eos_token_id,
'pad_token_id':
self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'recompute_log_prob':
False,
'do_sample':
False,
'validate':
True,
}

test_output_gen_batch = self.actor_rollout_wg.generate_sequences(test_gen_batch)
Expand Down
18 changes: 12 additions & 6 deletions verl/trainer/ppo/workers/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,12 @@ def _build_model_optimizer(self,
actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'bos_token_id':
self.tokenizer.bos_token_id,
'eos_token_id':
self.tokenizer.eos_token_id,
'pad_token_id':
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
Expand Down Expand Up @@ -471,9 +474,12 @@ def _build_critic_model_optimizer(self, config):
from omegaconf import OmegaConf
override_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'bos_token_id':
self.tokenizer.bos_token_id,
'eos_token_id':
self.tokenizer.eos_token_id,
'pad_token_id':
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
Expand Down
18 changes: 12 additions & 6 deletions verl/trainer/ppo/workers/megatron_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,12 @@ def _build_critic_model_optimizer(self,
critic_model_config = AutoConfig.from_pretrained(local_path)

override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'bos_token_id':
self.tokenizer.bos_token_id,
'eos_token_id':
self.tokenizer.eos_token_id,
'pad_token_id':
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(critic_model_config, override_config_kwargs=override_config_kwargs)
Expand Down Expand Up @@ -628,9 +631,12 @@ def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, over
rm_model_config = AutoConfig.from_pretrained(local_path)

override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'bos_token_id':
self.tokenizer.bos_token_id,
'eos_token_id':
self.tokenizer.eos_token_id,
'pad_token_id':
self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(rm_model_config, override_config_kwargs=override_config_kwargs)
Expand Down

0 comments on commit 369e2ea

Please sign in to comment.