
[Distributed] Fix new token's shape #1254

Open · wants to merge 1 commit into main
Conversation

@kwen2501 (Contributor) commented Oct 2, 2024

Issue

The TP-only case is broken with the following error:

[rank1]:   File "/home/kw2501/local/torchchat/torchchat/model.py", line 815, in forward
[rank1]:     bsz, seqlen, _ = x.shape
[rank1]: ValueError: not enough values to unpack (expected 3, got 2)
[rank3]:   File "/home/kw2501/local/torchchat/dist_run.py", line 477, in main
[rank3]:     output = decorder.step(new_token, **kwargs)
[rank3]:   File "/home/kw2501/local/pytorch/torch/distributed/pipelining/schedules.py", line 610, in step
[rank3]:     self._step_microbatches(args_split, kwargs_split, targets_split, losses)
[rank3]:   File "/home/kw2501/local/pytorch/torch/distributed/pipelining/schedules.py", line 710, in _step_microbatches
[rank3]:     output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i])  # type: ignore[index]
[rank3]:   File "/home/kw2501/local/pytorch/torch/distributed/pipelining/stage.py", line 595, in forward_one_chunk
[rank3]:     raise RuntimeError(exc_msg) from e
[rank3]: RuntimeError: 
[rank3]:             [Stage 0] failed to run forward:
[rank3]:             args: ('Tensor(torch.Size([4]), grad=False, dtype=torch.int64)',)
[rank3]:             kwargs: {'input_pos': 'Tensor(torch.Size([1]), grad=False, dtype=torch.int64)', 'cache_lane': '0'}

This suggests that, in the decoding phase, our input_ids (i.e. new_token) are flattened to 1D rather than being 2D (batch_size, 1).
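
A minimal sketch, using hypothetical embedding sizes, of why a 1D token tensor trips the shape unpack in model.py:

    import torch

    tok_embeddings = torch.nn.Embedding(32000, 256)  # hypothetical vocab size and embedding dim

    flat_tokens = torch.zeros(4, dtype=torch.long)   # shape (4,): what the decoder stage received
    x = tok_embeddings(flat_tokens)                  # shape (4, 256): only 2 dims
    # bsz, seqlen, _ = x.shape  -> ValueError: not enough values to unpack (expected 3, got 2)

    tokens_2d = torch.zeros(4, 1, dtype=torch.long)  # shape (4, 1): what the model expects
    x = tok_embeddings(tokens_2d)                    # shape (4, 1, 256)
    bsz, seqlen, _ = x.shape                         # unpacks cleanly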

The flattening happens here:

    # Argmax (deterministic) TODO: add temperature
    next_token = torch.argmax(next_token_logits, dim=-1)

Fix

The fix is simple: add keepdim=True to the torch.argmax call so the token dimension is preserved.
With that, the unsqueeze op in decode_in_flight can also be dropped.
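
A rough illustration of the resulting shape difference, assuming next_token_logits has shape (batch_size, vocab_size):

    import torch

    next_token_logits = torch.randn(4, 32000)  # (batch_size, vocab_size); sizes are hypothetical

    flat = torch.argmax(next_token_logits, dim=-1)                # shape (4,): current behavior
    kept = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape (4, 1): with the fix

    print(flat.shape, kept.shape)  # torch.Size([4]) torch.Size([4, 1])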

pytorch-bot commented Oct 2, 2024

✅ No failures as of commit 08b3f09 with merge base edaa15c. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1254.
Reviewed change in decode_in_flight:

    - # Make a 2D tensor with ids on row dimension
    - unsqueezed = torch.unsqueeze(token, 1)
    - token_str = tokenizer.decode(unsqueezed.tolist())
    + token_str = tokenizer.decode(token.tolist())
@lessw2020 (Contributor) commented Oct 2, 2024
I find that this does not work as-is.
However, adding a one-liner before the tokenizer.decode line (421):

    token = token.squeeze(1)

makes it work on both llama2 and llama3.

@kwen2501 (Contributor, Author) replied:

Ah, makes sense. Tokenizer difference :)

@lessw2020 (Contributor) left a review:
Thanks for adding!
I find that there is a missing line needed for this PR to work, but with that line it is verified working on llama2 and llama3.
Stamping to land; please add in the squeeze line.
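
For illustration, a minimal sketch of the squeeze line in question, assuming token is the (batch_size, 1) tensor produced by the keepdim=True argmax (the tokenizer object itself is not shown):

    import torch

    token = torch.tensor([[42], [17], [8], [3]])  # (batch_size, 1) after the keepdim=True fix

    # Squeeze back to 1D before decoding, so the tokenizer gets a flat list of ids:
    token = token.squeeze(1)                      # shape (4,)
    # token_str = tokenizer.decode(token.tolist())
    print(token.tolist())                         # [42, 17, 8, 3]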
