[Distributed] Fix new token's shape #1254
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1254
✅ No failures as of commit 1f8ff93 with merge base 8fcb3ba. This comment was automatically generated by Dr. CI and updates every 15 minutes.
dist_run.py (Outdated)

-    # Make a 2D tensor with ids on row dimension
-    unsqueezed = torch.unsqueeze(token, 1)
-    token_str = tokenizer.decode(unsqueezed.tolist())
+    token_str = tokenizer.decode(token.tolist())
I find that this does not work as-is.
However, adding a one-liner before the tokenizer.decode line (421):
token = token.squeeze(1)
And now it works on both llama2 and llama3.
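A minimal sketch of what that one-liner does, with hypothetical values (the batch contents are illustrative, not from the PR):

import torch

token = torch.tensor([[42], [7]])  # (batch_size, 1): one new token id per row
token = token.squeeze(1)           # -> (batch_size,): flat 1D tensor
print(token.tolist())              # [42, 7]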
Ah, makes sense. Tokenizer difference :)
It seems I cannot unconditionally squeeze the tensor.
If I do, some tokenizers will output:
responses ====>>>> is Christmasiving
instead of
responses ====>>>> ['is', 'Christmas', 'iving']
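A hedged sketch of why the unconditional squeeze breaks this case, assuming a SentencePiece model file is at hand (the model path and token ids are hypothetical): SentencePieceProcessor.decode returns one string per row for a 2D list of ids, but fuses a 1D list into a single string.

import sentencepiece

sp = sentencepiece.SentencePieceProcessor(model_file="tokenizer.model")  # hypothetical path

ids_2d = [[306], [4966], [292]]  # hypothetical ids, one row per batch element
print(sp.decode(ids_2d))         # list of strings, e.g. ['is', 'Christmas', 'iving']

ids_1d = [306, 4966, 292]        # the same ids after an unconditional squeeze
print(sp.decode(ids_1d))         # one fused string, e.g. 'is Christmasiving'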
I am using tokenizer = sentencepiece.SentencePieceProcessor.
Okay, I am adding an if there:
# `token` is a tensor of shape (batch_size, 1).
# For TiktokenTokenizer, we need to squeeze it to 1D.
# For SentencePieceProcessor, we don't.
if isinstance(tokenizer, TiktokenTokenizer):
    token = torch.squeeze(token, dim=1)
token_str = tokenizer.decode(token.tolist())
Thanks for adding!
I found that a line was missing for this PR to work, but with that line, verified working on llama2 and llama3.
Stamping to land; please add in the squeeze line.
Force-pushed from 08b3f09 to 4cdf355.
Force-pushed from 4cdf355 to 1f8ff93.
Issue
The TP-only case is broken: an error suggests that in the decoding phase, our input_ids (i.e. new_tokens) is flattened rather than being 2D (batch_size, 1). The flattening happens at the torch.argmax call (see Fix below).
Fix
The fix is simple: we just add a keepdim=True flag to torch.argmax. With that, the unsqueeze op in decode_in_flight can also be saved.
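A minimal sketch of the effect (tensor sizes are hypothetical): with keepdim=True, torch.argmax keeps the reduced dimension, so new tokens stay 2D with shape (batch_size, 1) and no later unsqueeze is needed.

import torch

batch_size, vocab_size = 2, 32000            # hypothetical sizes
probs = torch.randn(batch_size, vocab_size)  # last-position scores per batch row

flat = torch.argmax(probs, dim=-1)                # shape (batch_size,): flattened
kept = torch.argmax(probs, dim=-1, keepdim=True)  # shape (batch_size, 1): stays 2D

print(flat.shape)  # torch.Size([2])
print(kept.shape)  # torch.Size([2, 1])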