Skip to content

Commit 08b3f09

Browse files
committed
[Distributed] Fix new token's shape
1 parent edaa15c commit 08b3f09

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

dist_run.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,7 @@ def _batch_decode_next_tokens(
203203
# Take the next token logits for each prompt
204204
next_token_logits = output[:, pos, :]
205205
# Argmax (deterministic) TODO: add temperature
206-
next_token = torch.argmax(next_token_logits, dim=-1)
206+
next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
207207
# Token ids in int tensor form
208208
return next_token
209209

@@ -418,9 +418,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
418418

419419
# Decode token id into string and print it
420420
def decode_in_flight(token):
421-
# Make a 2D tensor with ids on row dimension
422-
unsqueezed = torch.unsqueeze(token, 1)
423-
token_str = tokenizer.decode(unsqueezed.tolist())
421+
token_str = tokenizer.decode(token.tolist())
424422
if tp_rank == 0:
425423
logger.info(
426424
f"{color.green} responses ====>>>> "
@@ -498,8 +496,10 @@ def decode_in_flight(token):
498496

499497
# output formatted response via last pp group and tp rank 0
500498
if pp_rank == last_pp_rank and tp_rank == 0:
501-
# `res` is a list of tensors, each being a batch of generated token ids
502-
res = torch.stack(res, dim=1)
499+
# `res` is a list of tensors, each being a batch of generated token ids.
500+
# We need to concatenate them to get the full sequence of generated
501+
# token ids. Thus cat'ing along dim 1.
502+
res = torch.cat(res, dim=1)
503503
res_list = res.tolist()
504504
response = tokenizer.decode(res_list)
505505
for i in range(len(response)):

0 commit comments

Comments
 (0)