[Distributed] Fix new token's shape #1254

Merged: 1 commit, Oct 2, 2024
dist_run.py: 24 changes (14 additions & 10 deletions)
@@ -209,6 +209,7 @@ def _batch_decode_next_tokens(
     batch_size, seq_len, vocab_size = output.shape
 
     if step != -1:
+        # `pos` is not provided, so we can use the first token
         next_token_logits = output[:, 0, :]
     else:
         # get the logits for each prompt at the specified positions
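
As a side note, a minimal standalone sketch (shapes invented for illustration) of what that slice does:

import torch

# hypothetical decoder output of shape (batch_size, seq_len, vocab_size)
output = torch.randn(2, 8, 100)
batch_size, seq_len, vocab_size = output.shape

# with no `pos` given, the next-token logits are read from position 0
next_token_logits = output[:, 0, :]
print(next_token_logits.shape)  # torch.Size([2, 100])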
@@ -228,9 +229,9 @@ def _batch_decode_next_tokens(
         ).squeeze(-1)
     else:
         # Argmax (deterministic)
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
 
-    logger.info(f"{color.yellow}Next tokens: {color.blue}{next_tokens}{color.reset}")
+    # Token ids in int tensor form
     return next_tokens
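
The keepdim=True change is the fix named in the PR title: it keeps the batch dimension explicit so every decode step returns a (batch_size, 1) tensor. A minimal sketch (values invented) of the shape difference:

import torch

next_token_logits = torch.randn(2, 5)  # (batch_size, vocab_size)

old_shape = torch.argmax(next_token_logits, dim=-1)                # shape (2,)
new_shape = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape (2, 1)

print(old_shape.shape, new_shape.shape)  # torch.Size([2]) torch.Size([2, 1])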


@@ -247,6 +248,11 @@ def _update_padded_sequence(
 # Decode token id into string and print it
 def _decode_in_flight(token, tokenizer, tp_rank):
     """decode token ids for all prompts in the batch and log them"""
+    # `token` is a tensor of shape (batch_size, 1).
+    # For TiktokenTokenizer, we need to squeeze it to 1D.
+    # For SentencePieceProcessor, we don't.
+    if isinstance(tokenizer, TiktokenTokenizer):
+        token = torch.squeeze(token, dim=1)
     token_str = tokenizer.decode(token.tolist())
     # print the token string on tp rank 0
     if tp_rank == 0:
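
A small standalone sketch of the squeeze step above, assuming tokens arrive in the new (batch_size, 1) shape (the tokenizer behavior described in the diff's comments, with made-up token ids):

import torch

token = torch.tensor([[15], [42]])  # (batch_size, 1): one new token per prompt

# squeezing dim 1 gives the flat 1-D id list that a Tiktoken-style
# decode expects; a SentencePiece-style decode takes the nested list as-is
flat = torch.squeeze(token, dim=1).tolist()   # [15, 42]
nested = token.tolist()                       # [[15], [42]]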
@@ -530,14 +536,12 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
 
     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-
-        res_stacked = torch.stack(res, dim=1)
-        res_list = res_stacked.tolist()
-
-        # Decode the output as comprehension instead of loop
-        responses = [tokenizer.decode(sequence) for sequence in res_list]
-
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
+        res_list = res.tolist()
+        responses = tokenizer.decode(res_list)
         # Show prompts and responses
         for prompt_text, response_text in zip(prompt, responses):
             logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}")