Skip to content

[Distributed] Fix new token's shape #1254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 2, 2024
Merged

[Distributed] Fix new token's shape #1254

merged 1 commit into from
Oct 2, 2024

Conversation

kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Oct 2, 2024

Issue

TP-only case is broken due to the following error:

[rank1]:   File "/home/kw2501/local/torchchat/torchchat/model.py", line 815, in forward
[rank1]:     bsz, seqlen, _ = x.shape
[rank1]: ValueError: not enough values to unpack (expected 3, got 2)
[rank3]:   File "/home/kw2501/local/torchchat/dist_run.py", line 477, in main
[rank3]:     output = decorder.step(new_token, **kwargs)
[rank3]:   File "/home/kw2501/local/pytorch/torch/distributed/pipelining/schedules.py", line 610, in step
[rank3]:     self._step_microbatches(args_split, kwargs_split, targets_split, losses)
[rank3]:   File "/home/kw2501/local/pytorch/torch/distributed/pipelining/schedules.py", line 710, in _step_microbatches
[rank3]:     output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]
[rank3]:   File "/home/kw2501/local/pytorch/torch/distributed/pipelining/stage.py", line 595, in forward_one_chunk
[rank3]:     raise RuntimeError(exc_msg) from e
[rank3]: RuntimeError: 
[rank3]:             [Stage 0] failed to run forward:
[rank3]:             args: ('Tensor(torch.Size([4]), grad=False, dtype=torch.int64)',)
[rank3]:             kwargs: {'input_pos': 'Tensor(torch.Size([1]), grad=False, dtype=torch.int64)', 'cache_lane': '0'}

It suggests that in the decoding phase, our input_ids (i.e. new_tokens) is flattened rather than being 2D (batch_size, 1).

The flattening happens here:

    # Argmax (deterministic) TODO: add temperature
    next_token = torch.argmax(next_token_logits, dim=-1)

Fix

The fix is simple, we just add a keepdim=True flag to torch.argmax.
With that, the unsqueeze op in decode_in_flight can be also saved.

Copy link

pytorch-bot bot commented Oct 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1254

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1f8ff93 with merge base 8fcb3ba (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 2, 2024
dist_run.py Outdated
# Make a 2D tensor with ids on row dimension
unsqueezed = torch.unsqueeze(token, 1)
token_str = tokenizer.decode(unsqueezed.tolist())
token_str = tokenizer.decode(token.tolist())
Copy link
Contributor

@lessw2020 lessw2020 Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find that this does not work as-is.
However, adding in a one liner before the tokenizer.decode line (421):

token = token.squeeze(1)

And now it works on both llama2 and llama3.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, makes sense. Tokenizer difference :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems I cannot unconditionally squeeze the tensor.
If I do, some tokenizers will output:
responses ====>>>> is Christmasiving
istead of
responses ====>>>> ['is', 'Christmas', 'iving']

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am:
using tokenizer = sentencepiece.SentencePieceProcessor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I am adding an if there:

    # `token` is a tensor of shape (batch_size, 1).
    # For TiktokenTokenizer, we need to squeeze it to 1D.
    # For SentencePieceProcessor, we don't.
    if isinstance(tokenizer, TiktokenTokenizer):
        token = torch.squeeze(token, dim=1)
    token_str = tokenizer.decode(token.tolist())

Copy link
Contributor

@lessw2020 lessw2020 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for adding!
I find that there is a missing line for this PR to work but with that line, verified working on llama2 and llama3.
Stamping to land, please add in the squeeze line.

@kwen2501 kwen2501 changed the base branch from main to meta_init October 2, 2024 20:24
@kwen2501 kwen2501 changed the base branch from meta_init to main October 2, 2024 20:25
@kwen2501 kwen2501 merged commit 5952bd1 into main Oct 2, 2024
52 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants