Skip to content

Commit

Permalink
[Distributed] Fix new token's shape
Browse files Browse the repository at this point in the history
  • Loading branch information
kwen2501 committed Oct 2, 2024
1 parent edaa15c commit 08b3f09
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions dist_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _batch_decode_next_tokens(
# Take the next token logits for each prompt
next_token_logits = output[:, pos, :]
# Argmax (deterministic) TODO: add temperature
next_token = torch.argmax(next_token_logits, dim=-1)
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
# Token ids in int tensor form
return next_token

Expand Down Expand Up @@ -418,9 +418,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

# Decode token id into string and print it
def decode_in_flight(token):
# Make a 2D tensor with ids on row dimension
unsqueezed = torch.unsqueeze(token, 1)
token_str = tokenizer.decode(unsqueezed.tolist())
token_str = tokenizer.decode(token.tolist())
if tp_rank == 0:
logger.info(
f"{color.green} responses ====>>>> "
Expand Down Expand Up @@ -498,8 +496,10 @@ def decode_in_flight(token):

# output formatted response via last pp group and tp rank 0
if pp_rank == last_pp_rank and tp_rank == 0:
# `res` is a list of tensors, each being a batch of generated token ids
res = torch.stack(res, dim=1)
# `res` is a list of tensors, each being a batch of generated token ids.
# We need to concatenate them to get the full sequence of generated
# token ids. Thus cat'ing along dim 1.
res = torch.cat(res, dim=1)
res_list = res.tolist()
response = tokenizer.decode(res_list)
for i in range(len(response)):
Expand Down

0 comments on commit 08b3f09

Please sign in to comment.