@@ -203,7 +203,7 @@ def _batch_decode_next_tokens(
     # Take the next token logits for each prompt
     next_token_logits = output[:, pos, :]
     # Argmax (deterministic) TODO: add temperature
-    next_token = torch.argmax(next_token_logits, dim=-1)
+    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
     # Token ids in int tensor form
     return next_token
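A minimal sketch (batch size and vocab size are assumed for illustration, not taken from the actual run) of what `keepdim=True` changes here: each decode step now yields a `[batch, 1]` column of token ids instead of a flat `[batch]` vector.

```python
import torch

# Assumed shapes for illustration: batch of 2 prompts, vocab of 5 tokens.
next_token_logits = torch.randn(2, 5)

flat = torch.argmax(next_token_logits, dim=-1)                # shape [2]
kept = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape [2, 1]

print(flat.shape, kept.shape)  # torch.Size([2]) torch.Size([2, 1])
```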
@@ -418,9 +418,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:

     # Decode token id into string and print it
     def decode_in_flight(token):
-        # Make a 2D tensor with ids on row dimension
-        unsqueezed = torch.unsqueeze(token, 1)
-        token_str = tokenizer.decode(unsqueezed.tolist())
+        token_str = tokenizer.decode(token.tolist())
         if tp_rank == 0:
             logger.info(
                 f"{color.green} responses ====>>>> "
@@ -498,8 +496,10 @@ def decode_in_flight(token):

     # output formatted response via last pp group and tp rank 0
     if pp_rank == last_pp_rank and tp_rank == 0:
-        # `res` is a list of tensors, each being a batch of generated token ids
-        res = torch.stack(res, dim=1)
+        # `res` is a list of tensors, each being a batch of generated token ids.
+        # We need to concatenate them to get the full sequence of generated
+        # token ids. Thus cat'ing along dim 1.
+        res = torch.cat(res, dim=1)
         res_list = res.tolist()
         response = tokenizer.decode(res_list)
         for i in range(len(response)):
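And a sketch (again with made-up ids) of the `stack` vs. `cat` distinction: once each step contributes a `[batch, 1]` tensor, concatenating along dim 1 gives one flat row of generated ids per prompt, whereas stacking would insert an extra dimension.

```python
import torch

# Assumed data for illustration: 3 decode steps for a batch of 2 prompts, each step [2, 1].
res = [torch.tensor([[1], [4]]), torch.tensor([[2], [5]]), torch.tensor([[3], [6]])]

cat = torch.cat(res, dim=1)    # shape [2, 3]    -- one sequence of generated ids per prompt
stk = torch.stack(res, dim=1)  # shape [2, 3, 1] -- an unwanted extra dimension

print(cat.tolist())  # [[1, 2, 3], [4, 5, 6]]
print(stk.shape)     # torch.Size([2, 3, 1])
```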