[Distributed] Fix new token's shape (#1254)
kwen2501 authored Oct 2, 2024
1 parent 8fcb3ba · commit 5952bd1
Showing 1 changed file with 14 additions and 10 deletions.
dist_run.py: 24 changes (14 additions, 10 deletions)
@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
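
Note: this branch runs during decode, where each forward pass emits logits for a single new position per prompt. A minimal sketch of the indexing (shapes illustrative; the vocab size is a placeholder, not from the commit):

    import torch

    output = torch.randn(2, 1, 32000)    # (batch_size, seq_len=1, vocab_size) during decode
    next_token_logits = output[:, 0, :]  # (batch_size, vocab_size): position 0 is the only one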
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens
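
Note: the keepdim change is the shape fix the commit title refers to. Without it, argmax collapses the per-prompt token ids into a 1-D tensor; with it, they keep the (batch_size, 1) shape the callers below now expect. A quick sketch (sizes illustrative):

    import torch

    next_token_logits = torch.randn(2, 32000)                    # (batch_size, vocab_size)
    torch.argmax(next_token_logits, dim=-1).shape                # torch.Size([2]): 1-D, old behavior
    torch.argmax(next_token_logits, dim=-1, keepdim=True).shape  # torch.Size([2, 1]): new behavior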


@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
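
Note: with `token` shaped (batch_size, 1), `token.tolist()` yields a nested list. Assuming TiktokenTokenizer.decode expects a flat list of ints while SentencePieceProcessor accepts the nested form (as the new comments state), a sketch of the two paths:

    import torch

    token = torch.tensor([[101], [205]])  # (batch_size, 1); values illustrative
    token.tolist()                        # [[101], [205]]: nested, fine for SentencePieceProcessor
    torch.squeeze(token, dim=1).tolist()  # [101, 205]: flat list for TiktokenTokenizer.decode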
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")
