[tokenizer] fix: fall back to the eos token when the pad token is None
eric-haibin-lin committed Dec 3, 2024
1 parent 292b60b commit 73f1984
Showing 11 changed files with 26 additions and 19 deletions.
2 changes: 1 addition & 1 deletion verl/third_party/vllm/vllm_v_0_5_4/tokenizer.py
@@ -70,7 +70,7 @@ def get_lora_tokenizer(self, lora_request: Optional[LoRARequest]) -> "PreTrained
# FIXME(sgm): for simplicity, we assign the special token here
@property
def pad_token_id(self):
- return self.tokenizer.pad_token_id
+ return self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id

@property
def eos_token_id(self):
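The same pad-to-eos fallback is applied inline in each of the files below. As a minimal sketch of the shared pattern, assuming a Hugging Face-style tokenizer (the helper name is illustrative and not part of this commit):

def resolve_pad_token_id(tokenizer):
    # Tokenizers such as Llama's ship without a pad token; fall back to the eos token id.
    if tokenizer.pad_token_id is not None:
        return tokenizer.pad_token_id
    return tokenizer.eos_token_id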
2 changes: 1 addition & 1 deletion verl/trainer/main_generation.py
@@ -110,7 +110,7 @@ def main(config):
skip_special_tokens=False)

# remove the padding
- pad_token = tokenizer.pad_token
+ pad_token = tokenizer.pad_token if tokenizer.pad_token is not None else tokenizer.eos_token
output_text_unpad = []
for text in output_text:
output_text_unpad.append(text.replace(pad_token, ''))
1 change: 1 addition & 0 deletions verl/trainer/ppo/hybrid_engine/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.

from verl.utils.import_utils import is_vllm_available, is_megatron_core_available
+ from .base import BaseShardingManager

AllGatherPPModel = None

2 changes: 1 addition & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -294,7 +294,7 @@ def _validate(self):
test_gen_batch = test_batch.pop(['input_ids', 'attention_mask', 'position_ids'])
test_gen_batch.meta_info = {
'eos_token_id': self.tokenizer.eos_token_id,
- 'pad_token_id': self.tokenizer.pad_token_id,
+ 'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
'recompute_log_prob': False,
'do_sample': False,
'validate': True,
3 changes: 2 additions & 1 deletion verl/trainer/ppo/reward_model/megatron/reward_model.py
@@ -101,7 +101,8 @@ def re_encode_by_rm_tokenizer(self, data: DataProto) -> DataProto:
rm_attention_mask = rm_attention_mask[:ori_seqlen]
else:
# right padding
- rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, self.rm_tokenizer.pad_token_id)
+ pad_token_id = self.rm_tokenizer.pad_token_id if self.rm_tokenizer.pad_token_id is not None else self.rm_tokenizer.eos_token_id
+ rm_input_ids = pad_sequence_to_length(rm_input_ids, ori_seqlen, pad_token_id)
rm_attention_mask = pad_sequence_to_length(rm_attention_mask, ori_seqlen, 0)
rm_position_ids = torch.arange(0, ori_seqlen, device=input_ids.device)
input_ids_for_rm.append(torch.unsqueeze(rm_input_ids, dim=0))
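For reference, a minimal sketch of right padding a 1-D tensor to a fixed length, in the spirit of the pad_sequence_to_length call above (this standalone function is illustrative, not the verl implementation):

import torch

def right_pad_to_length(x: torch.Tensor, target_len: int, pad_value: int) -> torch.Tensor:
    # Append pad_value until x reaches target_len; longer inputs are returned unchanged.
    if x.size(0) >= target_len:
        return x
    pad = torch.full((target_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
    return torch.cat((x, pad), dim=0)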
5 changes: 3 additions & 2 deletions verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py
@@ -48,7 +48,6 @@
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
# remove the left padding in the prompt token_id
- # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][0]
token_ids = prompt_token_ids[non_pad_index:].tolist()
return token_ids
@@ -120,7 +119,9 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
print(f"kwargs: {kwargs}")
self.sampling_params = SamplingParams(**kwargs)

- self.pad_token_id = tokenizer.pad_token_id
+ # Manually set pad token for llama.
+ # See discussions: https://discuss.huggingface.co/t/how-to-set-the-pad-token-for-meta-llama-llama-3-models/103418/7
+ self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

@contextmanager
def update_sampling_params(self, **kwargs):
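A small usage sketch of _pre_process_inputs above; the token values are made up for illustration:

import torch

prompt = torch.tensor([0, 0, 101, 2054, 2003])  # left-padded with pad_token_id = 0
token_ids = _pre_process_inputs(0, prompt)      # strips the leading pads
print(token_ids)                                # [101, 2054, 2003]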
10 changes: 6 additions & 4 deletions verl/trainer/ppo/workers/fsdp_workers.py
@@ -119,7 +119,7 @@ def _build_model_optimizer(self,
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
- 'pad_token_id': self.tokenizer.pad_token_id,
+ 'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
@@ -349,7 +349,8 @@ def generate_sequences(self, prompts: DataProto):
load_grad=self._is_offload_grad)

prompts.batch = prompts.batch.cuda()
- meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+ meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': pad_token_id}
prompts.meta_info.update(meta_info)
with self.sharding_manager:
log_gpu_memory_usage('After entering sharding manager', logger=logger)
@@ -472,7 +473,7 @@ def _build_critic_model_optimizer(self, config):
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
- 'pad_token_id': self.tokenizer.pad_token_id,
+ 'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_config)
if self.rank == 0:
@@ -770,11 +771,12 @@ def _switch_chat_template(self, data: DataProto):
max_length = self.config.get('max_length', src_max_length)
if max_length is None:
max_length = src_max_length
+ target_pad_token_id = target_tokenizer.pad_token_id if target_tokenizer.pad_token_id is not None else target_tokenizer.eos_token_id
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
prompt=prompt_with_chat_template,
tokenizer=target_tokenizer,
max_length=max_length,
- pad_token_id=target_tokenizer.pad_token_id,
+ pad_token_id=target_pad_token_id,
left_pad=False, # right padding
truncation=self.config.get('truncation', 'right')) # truncate from the right

11 changes: 6 additions & 5 deletions verl/trainer/ppo/workers/megatron_workers.py
@@ -139,11 +139,11 @@ def _build_model_optimizer(self,

# Step 2: get the actor_model_config
actor_model_config = AutoConfig.from_pretrained(local_path)

+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
- 'pad_token_id': self.tokenizer.pad_token_id,
+ 'pad_token_id': pad_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(actor_model_config, override_config_kwargs=override_config_kwargs)
@@ -346,7 +346,8 @@ def generate_sequences(self, prompts: DataProto):
assert self._is_rollout

prompts.batch = prompts.batch.cuda()
- meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': self.tokenizer.pad_token_id}
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+ meta_info = {'eos_token_id': self.tokenizer.eos_token_id, 'pad_token_id': pad_token_id}
prompts.meta_info.update(meta_info)
with self.sharding_manager:
log_gpu_memory_usage('After entering sharding manager', logger=logger)
@@ -466,7 +467,7 @@ def _build_critic_model_optimizer(self,
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
- 'pad_token_id': self.tokenizer.pad_token_id,
+ 'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(critic_model_config, override_config_kwargs=override_config_kwargs)
@@ -629,7 +630,7 @@ def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, over
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
- 'pad_token_id': self.tokenizer.pad_token_id,
+ 'pad_token_id': self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
}
override_config_kwargs.update(override_model_config)
update_model_config(rm_model_config, override_config_kwargs=override_config_kwargs)
4 changes: 2 additions & 2 deletions verl/utils/dataset/rl_dataset.py
@@ -124,11 +124,11 @@ def __getitem__(self, item):
chat = row_dict.pop(self.prompt_key)

prompt_with_chat_template = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)

+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
tokenizer=self.tokenizer,
max_length=self.max_prompt_length,
- pad_token_id=self.tokenizer.pad_token_id,
+ pad_token_id=pad_token_id,
left_pad=True,
truncation=self.truncation)

3 changes: 2 additions & 1 deletion verl/utils/dataset/sft_dataset.py
@@ -109,8 +109,9 @@ def __getitem__(self, item):
# padding to max length
sequence_length = input_ids.shape[0]
if sequence_length < self.max_length:
+ pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
padded_input_ids = torch.ones(size=(self.max_length - sequence_length,),
- dtype=input_ids.dtype) * self.tokenizer.pad_token_id
+ dtype=input_ids.dtype) * pad_token_id
padded_attention_mask = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask.dtype)

input_ids = torch.cat((input_ids, padded_input_ids))
2 changes: 1 addition & 1 deletion verl/utils/logger/aggregate_logger.py
@@ -40,4 +40,4 @@ def flush(self):

def log(self, data, step):
if self.print_to_console:
- print(concat_dict_to_str(data, step=step))
+ print(concat_dict_to_str(data, step=step), flush=True)
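Passing flush=True makes print write the metrics line immediately instead of leaving it in the stdout buffer, which helps when stdout is redirected to a file or captured by a driver process.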
