
[Bugfix] Make spec. decode respect per-request seed. #6034

Merged
16 commits merged on Jul 19, 2024

Conversation

@tdoublep (Member) commented on Jul 1, 2024:

Fixes #6038

@njhill (Member) left a comment:

Nice :)

vllm/model_executor/layers/rejection_sampler.py (resolved review comment, outdated)
Comment on lines 173 to 176:

```python
if any(generator is not None for generator in generators):
    uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype,
                               device=target_probs.device)
    for i, generator in enumerate(generators):
        uniform_rand[i, :] = torch.rand(1, k, dtype=self.probs_dtype,
                                        device=target_probs.device,
                                        generator=generator)
```
@njhill (Member) commented on Jul 1, 2024:

I think it might be preferable here in the mixed case to do a single rand for the non-seeded requests, since it's likely that most requests won't have a seed. This can be done via something like

uniform_rand[non_seed_indices] = torch.rand(len(non_seed_indices), k, ...
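For illustration, a fuller sketch of what this suggestion could look like (a hedged sketch; names such as `non_seed_indices` are illustrative rather than the final PR code):

```python
uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype,
                           device=target_probs.device)

# One batched draw covers every unseeded request.
non_seed_indices = [i for i, g in enumerate(generators) if g is None]
if non_seed_indices:
    uniform_rand[non_seed_indices] = torch.rand(
        len(non_seed_indices), k, dtype=self.probs_dtype,
        device=target_probs.device)

# Per-request draws are only needed for the seeded requests.
for i, generator in enumerate(generators):
    if generator is not None:
        uniform_rand[i, :] = torch.rand(
            1, k, dtype=self.probs_dtype,
            device=target_probs.device, generator=generator)
```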

@tdoublep (Member, PR author) replied:

I agree. Sort-of-related question: do you know whether we can safely assume that the length of each SequenceGroup is exactly 1 here? In the general sampler code there is some logic to handle the more general case. Maybe we need that here too?

@njhill (Member) replied:

Will double-check. I think it's only >1 for beam search, which I believe is explicitly unsupported with spec decode right now anyhow. Though it might also be the case for parallel decoding (n>1, i.e. multiple output seqs per single input seq).

@tdoublep (Member, PR author) replied:

I actually just tested that case (combining n>1 with spec. decode), and it seems to fail on main:

    |   File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 96, in forward
    |     sample_results, maybe_sampled_tokens_tensor = _sample(
    |                                                   ^^^^^^^^
    |   File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 658, in _sample
    |     return _sample_with_torch(
    |            ^^^^^^^^^^^^^^^^^^^
    |   File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 528, in _sample_with_torch
    |     sampled_token_ids_tensor[
    | RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]

If that case isn't supported by spec decode currently we should at least give a reasonable message back.

@tdoublep (Member, PR author) commented on Jul 3, 2024:

Made the suggested change regarding uniform_rand; will create a separate PR to address the error message for n>1.

vllm/model_executor/layers/rejection_sampler.py (resolved review comment, outdated)
vllm/model_executor/layers/rejection_sampler.py (resolved review comment, outdated)
Comment on lines 297 to 305:

```python
generator = torch.Generator(
    device=seq_group_metadata.state.generator.device
)
generator.set_state(
    seq_group_metadata.state.generator.get_state()
)
state = SequenceGroupState(
    generator=generator,
)
```
@njhill (Member) commented:

Could the same Generator just be used here?

@tdoublep (Member, PR author) replied:

My first attempt just shallow-copied the generator, but it led to incorrect results, so I think we need a "deep" copy here.
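A minimal standalone illustration of the state interference a shared generator can cause, which is presumably why a fresh generator with copied state is needed here (a hedged sketch, not the PR code):

```python
import torch

g = torch.Generator()
g.manual_seed(42)

# "Deep" copy: an independent generator initialized from g's current state.
g2 = torch.Generator()
g2.set_state(g.get_state())

a = torch.rand(3, generator=g)   # advances g's state
b = torch.rand(3, generator=g2)  # g2 is unaffected by the draw above
assert torch.equal(a, b)         # both draws start from the same copied state

# Had both consumers shared the one generator, the second draw would have
# continued the random stream instead of reproducing it.
```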

@njhill (Member) replied:

It's still not obvious to me why this doesn't work; we could look into it as a follow-on, since it's not ideal to create a new generator every iteration.

@tdoublep tdoublep changed the title [WIP] Make spec. decode respect per-request seed. [Bugfix] Make spec. decode respect per-request seed. Jul 1, 2024
@tdoublep tdoublep marked this pull request as ready for review July 1, 2024 17:12
@cadedaniel (Collaborator) commented:

Can we add some e2e tests?

@cadedaniel (Collaborator) commented:

Ideally unit tests in rejection sampler too!

```diff
- q = torch.empty_like(probs).exponential_(1.0)
+ q = torch.empty_like(probs)
+ if any(generator is not None for generator in generators):
```
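For context, a hedged sketch of what the seeded path introduced by this diff might look like in full (the exact PR code may differ):

```python
q = torch.empty_like(probs)
if any(generator is not None for generator in generators):
    # Integer indexing returns a view, so exponential_ fills q in place.
    for i, generator in enumerate(generators):
        if generator is None:
            q[i].exponential_(1.0)
        else:
            q[i].exponential_(1.0, generator=generator)
else:
    q.exponential_(1.0)
```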
@cadedaniel (Collaborator) commented:

Can we see the latency impact when there are multiple generators? Want to make sure rejection sampling is still reasonable.

@tdoublep (Member, PR author) replied:

Attached some analysis in the comments below.

@tdoublep (Member, PR author) commented on Jul 3, 2024:

@cadedaniel I've been digging into the effect on latency at the level of the rejection sampler.

First, I wanted to check that if we never pass any generators, the performance is similar to main. I call the forward method of the rejection sampler 1000 times under different conditions using the main branch and the spec-decode-seed branch and compare the difference (final column). It looks OK imo.

k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.37 ms, diff: 5.71%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.38 ms, t_elap (new): 0.40 ms, diff: 5.26%

Next, I wanted to investigate how increasing the fraction of requests in the batch that are seeded affects the performance (frac_seeded = fraction of generators that are not None). These results are a bit more disappointing:

k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.34 ms
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.34 ms, diff:    -0.22%
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.34 ms, diff:    -0.89%
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.38 ms, diff:   +10.64%
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.38 ms, diff:   +11.24%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.36 ms
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.36 ms, diff:    +0.32%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.36 ms, diff:    +0.23%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.43 ms, diff:   +20.67%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.47 ms, diff:   +31.69%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.36 ms
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.43 ms, diff:   +19.74%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.47 ms, diff:   +31.40%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.53 ms, diff:   +46.26%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.54 ms, diff:   +50.61%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.37 ms
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.46 ms, diff:   +27.00%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.61 ms, diff:   +68.38%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.71 ms, diff:   +93.97%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 1.00, t_elap:   0.99 ms, diff:  +171.91%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.38 ms
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.76 ms, diff:  +101.38%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap:   0.95 ms, diff:  +150.41%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap:   1.66 ms, diff:  +338.59%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap:   3.09 ms, diff:  +718.98%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.38 ms
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.38 ms, diff:    -0.13%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.38 ms, diff:    -0.16%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.42 ms, diff:   +11.83%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.42 ms, diff:   +11.69%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.40 ms
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.40 ms, diff:    -0.17%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.47 ms, diff:   +17.71%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.51 ms, diff:   +29.08%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.51 ms, diff:   +28.90%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.39 ms
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.47 ms, diff:   +19.12%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.40 ms, diff:    +0.67%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.54 ms, diff:   +35.82%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.60 ms, diff:   +51.98%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.40 ms
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.51 ms, diff:   +27.04%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.60 ms, diff:   +47.76%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.83 ms, diff:  +104.82%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.11 ms, diff:  +175.25%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.42 ms
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.83 ms, diff:   +97.62%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap:   0.97 ms, diff:  +130.45%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap:   1.88 ms, diff:  +347.14%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap:   3.09 ms, diff:  +635.32%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    +0.17%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.39 ms, diff:    -0.01%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.45 ms, diff:   +13.64%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.45 ms, diff:   +13.29%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.03%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.49 ms, diff:   +20.87%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.52 ms, diff:   +28.37%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.56 ms, diff:   +37.89%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.15%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.53 ms, diff:   +28.73%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.58 ms, diff:   +42.66%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.68 ms, diff:   +66.24%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.55 ms, diff:   +32.40%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.67 ms, diff:   +61.56%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.99 ms, diff:  +138.39%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.40 ms, diff:  +237.93%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.44 ms
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.79 ms, diff:   +82.31%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.21 ms, diff:  +177.38%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap:   2.66 ms, diff:  +510.67%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap:   4.21 ms, diff:  +866.02%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    +0.07%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.39 ms, diff:    -0.03%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.45 ms, diff:   +14.26%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.45 ms, diff:   +14.16%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    -0.04%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    -0.04%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.52 ms, diff:   +27.43%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.56 ms, diff:   +37.24%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    -0.15%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.52 ms, diff:   +28.21%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.58 ms, diff:   +42.51%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.68 ms, diff:   +66.48%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.42 ms
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.63 ms, diff:   +52.15%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.60 ms, diff:   +44.58%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.95 ms, diff:  +128.86%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.39 ms, diff:  +235.73%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.43 ms
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.85 ms, diff:   +97.39%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.26 ms, diff:  +190.45%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap:   2.33 ms, diff:  +437.52%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap:   4.21 ms, diff:  +871.77%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    +0.16%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.46 ms, diff:   +15.72%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.46 ms, diff:   +15.84%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.46 ms, diff:   +15.70%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.08%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    +0.10%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.59 ms, diff:   +43.95%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.61 ms, diff:   +48.61%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.08%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.51 ms, diff:   +24.68%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.67 ms, diff:   +62.47%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.76 ms, diff:   +84.95%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.42 ms
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.72 ms, diff:   +72.41%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.68 ms, diff:   +62.23%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.50, t_elap:   1.18 ms, diff:  +180.81%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.68 ms, diff:  +300.72%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.44 ms
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap:   1.15 ms, diff:  +161.95%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.46 ms, diff:  +230.81%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap:   3.44 ms, diff:  +680.86%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap:   5.30 ms, diff: +1104.57%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    -0.10%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.39 ms, diff:    +0.07%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.39 ms, diff:    -0.18%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.46 ms, diff:   +15.94%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.54 ms, diff:   +32.45%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    -0.24%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.58 ms, diff:   +42.51%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.60 ms, diff:   +46.50%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.41%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    +0.15%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.62 ms, diff:   +52.06%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.75 ms, diff:   +84.19%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.42 ms
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.53 ms, diff:   +25.95%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.72 ms, diff:   +71.41%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.50, t_elap:   1.02 ms, diff:  +142.86%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.68 ms, diff:  +300.73%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.45 ms
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap:   1.12 ms, diff:  +151.11%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.64 ms, diff:  +265.30%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap:   3.03 ms, diff:  +575.36%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap:   5.32 ms, diff: +1087.33%

In particular, when the batch size is large, up to a few milliseconds of overhead can be added. I guess this may not be so significant in E2E measurements (I will do some now). I did some investigation into where the overhead is coming from, and it is really due to the random-number generation itself rather than any auxiliary work. It seems that if we want reproducible output, there will be some ms-level performance penalty for large batch sizes.
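For reference, a rough sketch of the kind of timing loop described above (the rejection sampler's forward signature and argument names are assumptions here, not the exact benchmark code):

```python
import time
import torch

def time_rejection_sampler(sampler, target_probs, bonus_token_ids,
                           draft_probs, draft_token_ids, generators,
                           iters: int = 1000) -> float:
    """Return the mean forward latency in milliseconds over `iters` calls."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids,
                generators)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1e3
```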

@tdoublep (Member, PR author) commented on Jul 3, 2024:

BTW, I added both E2E tests and unit tests for the rejection sampler.

@njhill (Member) commented on Jul 3, 2024:

It is probably also worth mentioning, though, that we would likely see something similar when increasing the percentage of seeded requests in a large batch even in the non-spec-decoding case, since there is a per-generator loop there too.

@tdoublep (Member, PR author) commented on Jul 3, 2024:

@njhill yes, that's a good point

@tdoublep (Member, PR author) commented:

@njhill @cadedaniel anything else you would like to see addressed here?

@cadedaniel (Collaborator) commented:

Can we find a more optimized way to do this? Basically this defeats the purpose of performance benchmarking w/seeds unless the overhead is small.

@tdoublep (Member, PR author) replied:

The overheads come from having to iterate through the torch.Generators over large batches. As long as all requests in the batch can have different seeds, I don't see how this can be avoided. We have the same overheads when using seeds via the non-speculative path.

I agree that if we want optimal performance numbers, we may want to not set seeds. However, it is pretty important on our side to be able to get reproducible output when running our integration tests etc. @njhill can probably comment more on that.

@cadedaniel (Collaborator) commented:

Antoni was able to get pretty low-overhead per-seed sampling. Any thoughts, @Yard1?

@cadedaniel (Collaborator) commented:

And I think this is admissible if it's for correctness testing. But if it's for users or for performance testing, then we need to understand the tradeoff of optimizing it vs. simply not allowing it, so the quality bar is high.

@Yard1 (Collaborator) commented on Jul 15, 2024:

@cadedaniel it will require a custom sampling kernel, and invoking it through the Triton JIT had massive overhead. That being said, this PR has something interesting: 391d761; maybe it will work

(though I see the author removed it later)

@Yard1 (Collaborator) commented on Jul 15, 2024:

As an alternative, we could simplify this problem to batched generation of uniform random tensors with a different seed per row. This is definitely easy to write in Triton (but again, overheads), though perhaps writing it in pure CUDA would be just as simple.
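For concreteness, a minimal sketch of what per-row-seeded uniform generation could look like in Triton (an illustrative sketch only; the kernel layout and names are assumptions, and vLLM's ops/rand.py kernel mentioned below is the more complete version):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _seeded_uniform_kernel(out_ptr, seed_ptr, n_cols, stride,
                           BLOCK: tl.constexpr):
    # One program instance per row; each row draws from its own Philox stream.
    row = tl.program_id(0)
    seed = tl.load(seed_ptr + row)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    vals = tl.rand(seed, cols)  # uniform floats in [0, 1)
    tl.store(out_ptr + row * stride + cols, vals, mask=mask)

def seeded_uniform(seeds: torch.Tensor, n_cols: int) -> torch.Tensor:
    """One row of uniform randoms per seed, in a single kernel launch."""
    out = torch.empty(seeds.numel(), n_cols,
                      device=seeds.device, dtype=torch.float32)
    BLOCK = triton.next_power_of_2(n_cols)
    _seeded_uniform_kernel[(seeds.numel(),)](
        out, seeds, n_cols, out.stride(0), BLOCK=BLOCK)
    return out
```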

@njhill (Member) commented on Jul 15, 2024:

> but if it's for users or for performance testing then we need to understand tradeoff of optimizing it vs. simply not allowing it, so quality bar is high

I don't see why we would not allow seed even if its use incurs some performance overhead. It's no different from how seed is already implemented for the non-spec-decode case; there won't be any more "extra" overhead from a seeded request than there is now. Each generated token for each seeded request requires using the generator to perform a separate exponential sample, whether or not it was a speculated token.

In practice I don't think most users care about or use the seed, so in most contexts the overall overhead would not be significant. But it means that those who want more repeatable/stable outputs can still make use of it. Service providers are still free to block this parameter if they want (or we could add an env var for that - so that we optionally fail requests with seed set). But this would be something done irrespective of use of spec decoding.

Adding this to the spec decoding path is just filling a current functional gap/inconsistency imo, and whether/how to use it in performance benchmarking is a separate question.

@cadedaniel (Collaborator) commented:

Alright this doesn't meet my quality bar but I won't block y'all since it's off by default. LGTM.

@njhill (Member) commented on Jul 16, 2024:

> Alright this doesn't meet my quality bar but I won't block y'all since it's off by default. LGTM.

@cadedaniel do you mean quality in terms of the specific code changes in this PR, or just the fact that overhead exists from requests that include a seed? The latter may be a valid concern but is a preexisting one and I think orthogonal to spec/non-spec decoding?

vllm/model_executor/layers/rejection_sampler.py (resolved review comment, outdated)
vllm/spec_decode/batch_expansion.py (resolved review comment, outdated)
vllm/model_executor/layers/rejection_sampler.py (resolved review comment, outdated)
@cadedaniel (Collaborator) left a comment:

> @cadedaniel do you mean quality in terms of the specific code changes in this PR, or just the fact that overhead exists from requests that include a seed? The latter may be a valid concern but is a preexisting one and I think orthogonal to spec/non-spec decoding?

I mean that users may expect vLLM to give low latency for the features spec decode supports, but per-request seed as implemented in this PR will not give them that. I am OK to merge if we add a comment explaining the current state, or even create an issue asking for someone to optimize it, e.g. as Antoni suggests.

That said, thanks for adding what's here! It will be a good starting point for optimizing later.

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 18, 2024
@njhill (Member) left a comment:

Thanks @tdoublep!

@njhill njhill merged commit d4201e0 into vllm-project:main Jul 19, 2024
72 checks passed
@Yard1 (Collaborator) commented on Jul 22, 2024:

Looks like Triton 3.0.0 will reduce kernel launch overhead, meaning that we can move towards Triton kernels for batched seeded random generation! triton-lang/triton#3503

We also already have this kernel in the vLLM code (https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/ops/rand.py), so it would just be a matter of using it to generate the uniform random tensor that replaces the exponential call in multinomial sampling (the uniform noise will need to be turned into exponential noise, but that's a simple math operation that can be put in the kernel itself). Alternatively, we could use the fused random sampling kernel in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/ops/sample.py directly (though it doesn't support models with very large vocab sizes that well).
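For reference, the uniform-to-exponential conversion mentioned above is just the inverse-CDF transform; a quick hedged sketch of the elementwise math (not the kernel code itself):

```python
import torch

def uniform_to_exponential(u: torch.Tensor) -> torch.Tensor:
    # If u ~ Uniform(0, 1), then -log(u) ~ Exponential(1).
    # Clamp away from 0 so the boundary does not produce log(0) = -inf.
    return -torch.log(u.clamp_min(torch.finfo(u.dtype).tiny))

u = torch.rand(4, 8)
q = uniform_to_exponential(u)  # exponential noise for the multinomial sampling trick
```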

@tdoublep tdoublep deleted the spec-decode-seed branch July 22, 2024 19:48
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
gnpinkert pushed a commit to gnpinkert/vllm that referenced this pull request Jul 26, 2024
cduk pushed a commit to cduk/vllm-pascal that referenced this pull request Aug 6, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Speculative decoding does not respect per-request seed
4 participants