From 5952bd1e3d2666976e93e7d695deb398f0167c6a Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Thu, 3 Oct 2024 05:45:57 +0800
Subject: [PATCH] [Distributed] Fix new token's shape (#1254)

---
 dist_run.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index f09261da4..ceb18bf37 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens
 
 
@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
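
For context, a minimal standalone sketch (not part of the patch) of the shape behavior the fix relies on: with keepdim=True, torch.argmax keeps the reduced vocab dim, so each decode step yields a (batch_size, 1) tensor rather than (batch_size,); those per-step columns then cat along dim 1 into full (batch_size, num_steps) sequences, and a single step can be squeezed back to 1D for a tokenizer that wants a flat list, as in the TiktokenTokenizer branch above. The tensor names and sizes below are illustrative, not from dist_run.py.

    import torch

    # Illustrative sizes; not from dist_run.py
    batch_size, vocab_size, num_steps = 2, 8, 4

    generated = []
    for _ in range(num_steps):
        logits = torch.randn(batch_size, vocab_size)  # stand-in for model output
        # keepdim=True keeps the reduced dim: shape (batch_size, 1), not (batch_size,)
        next_tokens = torch.argmax(logits, dim=-1, keepdim=True)
        assert next_tokens.shape == (batch_size, 1)
        generated.append(next_tokens)

    # Per-step (batch_size, 1) columns concatenate along dim 1 into full sequences
    res = torch.cat(generated, dim=1)
    assert res.shape == (batch_size, num_steps)

    # One step's (batch_size, 1) tokens squeeze to 1D for tokenizers that expect
    # a flat list, mirroring the TiktokenTokenizer branch in the patch
    flat = torch.squeeze(generated[0], dim=1)
    assert flat.shape == (batch_size,)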