[Bugfix] Make spec. decode respect per-request seed. #6034
Conversation
Nice :)
```python
if any(generator is not None for generator in generators):
    uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype, device=target_probs.device)
    for i, generator in enumerate(generators):
        uniform_rand[i, :] = torch.rand(1, k, dtype=self.probs_dtype, device=target_probs.device, generator=generator)
```
I think it might be preferable here, in the mixed case, to do a single rand for the non-seeded requests, since it's likely that most requests don't have a seed. This can be done via something like

```python
uniform_rand[non_seed_indices] = torch.rand(len(non_seed_indices), k, ...
```
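For illustration, a minimal self-contained sketch of this mixed path (the helper name and signature are hypothetical, not vLLM's actual code): one batched `torch.rand` covers all unseeded rows, and only the seeded rows fall back to per-generator draws.

```python
import torch

def mixed_uniform_rand(generators, k, probs_dtype=torch.float32, device="cpu"):
    """Hypothetical helper: batched draw for unseeded rows, per-generator for seeded."""
    batch_size = len(generators)
    uniform_rand = torch.empty(batch_size, k, dtype=probs_dtype, device=device)
    non_seed_indices = [i for i, g in enumerate(generators) if g is None]
    if non_seed_indices:
        # Single batched draw covers every request without a seed.
        uniform_rand[non_seed_indices] = torch.rand(
            len(non_seed_indices), k, dtype=probs_dtype, device=device)
    for i, g in enumerate(generators):
        if g is not None:
            # Seeded requests each need their own generator for reproducibility.
            uniform_rand[i, :] = torch.rand(
                1, k, dtype=probs_dtype, device=device, generator=g)
    return uniform_rand
```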
I agree. Sort-of-related question: do you know whether we can safely assume that the length of each SequenceGroup is exactly 1 here? In the general sampler code there is some logic to handle the more general case. Maybe we need that here too?
Will double check. I think it's only >1 for beam search, which I believe is explicitly unsupported with spec decode right now anyhow. Though it might also be the case for parallel decoding (n>1, i.e. multiple output seqs per single input seq).
I actually just tested that case: combining n>1 with spec. decode, and it seems to fail on main:

```
File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 96, in forward
    sample_results, maybe_sampled_tokens_tensor = _sample(
                                                  ^^^^^^^^
File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 658, in _sample
    return _sample_with_torch(
           ^^^^^^^^^^^^^^^^^^^
File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 528, in _sample_with_torch
    sampled_token_ids_tensor[
RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]
```
If that case isn't supported by spec decode currently we should at least give a reasonable message back.
Made the suggested change regarding `uniform_rand`; will create a separate PR to address the error msg for n>1.
vllm/spec_decode/batch_expansion.py (outdated)
```python
generator = torch.Generator(
    device=seq_group_metadata.state.generator.device
)
generator.set_state(
    seq_group_metadata.state.generator.get_state()
)
state = SequenceGroupState(
    generator=generator,
)
```
Could the same Generator just be used here?
My first attempt just shallow copied the generator but it led to incorrect results, so I think we need a "deep" copy here.
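To illustrate the distinction (a standalone sketch, not the PR's code): aliasing a `torch.Generator` shares its RNG state, whereas cloning the state via `get_state()`/`set_state()` gives an independent stream.

```python
import torch

g = torch.Generator().manual_seed(42)

shallow = g                       # alias: same underlying RNG state
deep = torch.Generator()
deep.set_state(g.get_state())     # "deep" copy: independent snapshot of the state

_ = torch.rand(3, generator=shallow)   # advances g too, since `shallow` is g
a = torch.rand(3, generator=g)         # continues from the advanced state
b = torch.rand(3, generator=deep)      # replays the draw from the original state
print(torch.equal(a, b))               # False: the alias diverged from the copy
```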
It's still not obvious to me why this doesn't work; we could look into it as a follow-on, since it's not ideal to create a new generator every iteration.
Can we add some e2e tests?
Ideally unit tests in rejection sampler too!
```diff
-q = torch.empty_like(probs).exponential_(1.0)
+q = torch.empty_like(probs)
+if any(generator is not None for generator in generators):
```
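For context, a hedged sketch of the shape the seeded path in this diff could take (assumed form, not the exact vLLM code): `q ~ Exp(1)` feeds a Gumbel-style multinomial sample via `argmax(probs / q)`, and each seeded row fills its own `q` with its generator.

```python
import torch

def fill_exponential(probs: torch.Tensor, generators) -> torch.Tensor:
    # Assumed sketch: per-row exponential_(1.0) draws so seeded requests are
    # reproducible; a single batched fill when no request is seeded.
    q = torch.empty_like(probs)
    if any(g is not None for g in generators):
        for i, g in enumerate(generators):
            # generator=None falls back to the global RNG for unseeded rows.
            q[i].exponential_(1.0, generator=g)
    else:
        q.exponential_(1.0)
    return q
```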
Can we see the latency impact when there are multiple generators? Want to make sure rejection sampling is still reasonable
Attached some analysis in comments below
@cadedaniel I've been digging into the effect on latency at the level of the rejection sampler. Firstly, I wanted to check that if we never pass any generators, the performance is similar to main. I call the forward method of the rejection sampler 1000 times under different conditions:

```
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.37 ms, diff: 5.71%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.38 ms, t_elap (new): 0.40 ms, diff: 5.26%
```

Next, I wanted to investigate how increasing the fraction of requests in the batch that are seeded (`frac_seeded`) affects the performance:

```
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.34 ms
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.34 ms, diff: -0.22%
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.34 ms, diff: -0.89%
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.38 ms, diff: +10.64%
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.38 ms, diff: +11.24%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.36 ms
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.36 ms, diff: +0.32%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.36 ms, diff: +0.23%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.43 ms, diff: +20.67%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.47 ms, diff: +31.69%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.36 ms
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.43 ms, diff: +19.74%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.47 ms, diff: +31.40%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.53 ms, diff: +46.26%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.54 ms, diff: +50.61%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.37 ms
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.46 ms, diff: +27.00%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.61 ms, diff: +68.38%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.71 ms, diff: +93.97%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 1.00, t_elap: 0.99 ms, diff: +171.91%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.38 ms
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.76 ms, diff: +101.38%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap: 0.95 ms, diff: +150.41%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap: 1.66 ms, diff: +338.59%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap: 3.09 ms, diff: +718.98%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.38 ms
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.38 ms, diff: -0.13%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.38 ms, diff: -0.16%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.42 ms, diff: +11.83%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.42 ms, diff: +11.69%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.40 ms
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.40 ms, diff: -0.17%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.47 ms, diff: +17.71%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.51 ms, diff: +29.08%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.51 ms, diff: +28.90%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.39 ms
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.47 ms, diff: +19.12%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.40 ms, diff: +0.67%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.54 ms, diff: +35.82%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.60 ms, diff: +51.98%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.40 ms
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.51 ms, diff: +27.04%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.60 ms, diff: +47.76%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.83 ms, diff: +104.82%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.11 ms, diff: +175.25%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.42 ms
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.83 ms, diff: +97.62%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap: 0.97 ms, diff: +130.45%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap: 1.88 ms, diff: +347.14%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap: 3.09 ms, diff: +635.32%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: +0.17%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.39 ms, diff: -0.01%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.45 ms, diff: +13.64%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.45 ms, diff: +13.29%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.03%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.49 ms, diff: +20.87%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.52 ms, diff: +28.37%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.56 ms, diff: +37.89%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.15%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.53 ms, diff: +28.73%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.58 ms, diff: +42.66%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.68 ms, diff: +66.24%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.55 ms, diff: +32.40%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.67 ms, diff: +61.56%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.99 ms, diff: +138.39%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.40 ms, diff: +237.93%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.44 ms
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.79 ms, diff: +82.31%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.21 ms, diff: +177.38%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap: 2.66 ms, diff: +510.67%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap: 4.21 ms, diff: +866.02%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: +0.07%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.39 ms, diff: -0.03%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.45 ms, diff: +14.26%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.45 ms, diff: +14.16%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.41 ms, diff: -0.04%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.41 ms, diff: -0.04%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.52 ms, diff: +27.43%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.56 ms, diff: +37.24%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: -0.15%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.52 ms, diff: +28.21%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.58 ms, diff: +42.51%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.68 ms, diff: +66.48%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.42 ms
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.63 ms, diff: +52.15%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.60 ms, diff: +44.58%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.95 ms, diff: +128.86%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.39 ms, diff: +235.73%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.43 ms
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.85 ms, diff: +97.39%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.26 ms, diff: +190.45%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap: 2.33 ms, diff: +437.52%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap: 4.21 ms, diff: +871.77%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: +0.16%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.46 ms, diff: +15.72%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.46 ms, diff: +15.84%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.46 ms, diff: +15.70%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.08%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.41 ms, diff: +0.10%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.59 ms, diff: +43.95%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.61 ms, diff: +48.61%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.08%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.51 ms, diff: +24.68%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.67 ms, diff: +62.47%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.76 ms, diff: +84.95%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.42 ms
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.72 ms, diff: +72.41%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.68 ms, diff: +62.23%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.50, t_elap: 1.18 ms, diff: +180.81%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.68 ms, diff: +300.72%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.44 ms
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap: 1.15 ms, diff: +161.95%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.46 ms, diff: +230.81%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap: 3.44 ms, diff: +680.86%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap: 5.30 ms, diff: +1104.57%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: -0.10%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.39 ms, diff: +0.07%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.39 ms, diff: -0.18%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.46 ms, diff: +15.94%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.54 ms, diff: +32.45%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.41 ms, diff: -0.24%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.58 ms, diff: +42.51%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.60 ms, diff: +46.50%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.41%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.41 ms, diff: +0.15%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.62 ms, diff: +52.06%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.75 ms, diff: +84.19%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.42 ms
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.53 ms, diff: +25.95%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.72 ms, diff: +71.41%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.50, t_elap: 1.02 ms, diff: +142.86%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.68 ms, diff: +300.73%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.45 ms
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap: 1.12 ms, diff: +151.11%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.64 ms, diff: +265.30%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap: 3.03 ms, diff: +575.36%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap: 5.32 ms, diff: +1087.33%
```

In particular, when the batch size is large, there can be up to a few ms of overhead added. I guess this may not be so significant in E2E measurements (I will do some now). I did some investigation into where the overhead is coming from, and it is really due to the random number generation itself rather than any auxiliary work. It seems that if we want to make the sampling reproducible, there will be some ms-level performance penalty for large batch sizes.
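For reference, a rough reconstruction of the kind of micro-benchmark harness used here (the sampler call signature is assumed; a real measurement should use the actual `RejectionSampler.forward` arguments):

```python
import time
import torch

def bench(sampler_fn, batch_size, k, vocab_size, frac_seeded,
          n_iters=1000, device="cuda"):
    """Hypothetical harness: time sampler_fn with a given fraction of seeded rows."""
    target_probs = torch.rand(batch_size, k, vocab_size, device=device).softmax(-1)
    draft_probs = torch.rand(batch_size, k, vocab_size, device=device).softmax(-1)
    draft_ids = torch.randint(vocab_size, (batch_size, k), device=device)
    bonus_ids = torch.randint(vocab_size, (batch_size, 1), device=device)
    n_seeded = int(frac_seeded * batch_size)
    generators = [torch.Generator(device=device).manual_seed(i)
                  if i < n_seeded else None for i in range(batch_size)]
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iters):
        # Assumed argument order, mirroring the tensors built above.
        sampler_fn(target_probs, bonus_ids, draft_probs, draft_ids, generators)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iters * 1e3  # ms per call
```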
BTW I added both E2E tests and unit tests for the rejection sampler.
It is probably also worth mentioning, though, that we would likely see a similar effect when increasing the percentage of random-seeded requests in a large batch even in the non-spec-decoding case, since there's also a loop for that.
@njhill yes, that's a good point
@njhill @cadedaniel anything else you would like to see addressed here?
Can we find a more optimized way to do this? Basically this defeats the purpose of performance benchmarking w/seeds unless the overhead is small.
The overheads come from having to iterate through the generators and draw random numbers row by row. I agree that if we want optimal performance numbers, we may want to not set seeds. However, it is pretty important on our side to be able to get reproducible output when running our integration tests etc. @njhill can probably comment more on that.
Antoni was able to get pretty low-overhead per-seed sampling. Any thoughts, @Yard1?
And I think this is admissible if it's for correctness testing. But if it's for users or for performance testing, then we need to understand the tradeoff of optimizing it vs. simply not allowing it, so the quality bar is high.
@cadedaniel it will require a custom sampling kernel, and invoking it through the Triton JIT had a massive overhead. That being said, this PR has something interesting: 391d761; maybe it will work (though I see the author removed it later).
As an alternative, we could simplify this problem to batched generation of uniform random tensors with a different seed per row. This is definitely easy to write in Triton (but again, overheads), though perhaps writing it in pure CUDA would be just as simple. See the sketch below.
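A hedged sketch of that alternative (it assumes a recent Triton where `tl.rand(seed, offset)` provides counter-based Philox randomness; vLLM's `ops/rand.py`, mentioned later in the thread, takes a similar approach): one program per row, each row using its own seed, all in a single kernel launch.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def seeded_uniform_kernel(out_ptr, seed_ptr, n_cols, BLOCK: tl.constexpr):
    # One program per row; each row draws from its own Philox seed.
    row = tl.program_id(0)
    seed = tl.load(seed_ptr + row)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    vals = tl.rand(seed, cols)  # deterministic given (seed, offset)
    tl.store(out_ptr + row * n_cols + cols, vals, mask=mask)

def batched_seeded_uniform(seeds: torch.Tensor, n_cols: int) -> torch.Tensor:
    """Hypothetical wrapper: uniform [n_rows, n_cols] tensor, one int seed per row."""
    n_rows = seeds.numel()
    out = torch.empty(n_rows, n_cols, device=seeds.device, dtype=torch.float32)
    seeded_uniform_kernel[(n_rows,)](out, seeds, n_cols,
                                     BLOCK=triton.next_power_of_2(n_cols))
    return out
```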
I don't see why we would not allow seed even if its use incurs some performance overhead. It's no different from how seed is already implemented for the non-spec-decode case; there won't be any "extra" overhead from a seeded request beyond what there is now. Each generated token for each seeded request requires using the generator to perform a separate exponential sample, whether or not that token was speculated.

In practice I don't think most users care about or use the seed, so in most contexts the overall overhead would not be significant. But it means that those who want more repeatable/stable outputs can still make use of it.

Service providers are still free to block this parameter if they want (or we could add an env var for that, so that we optionally fail requests with seed set). But this would be something done irrespective of the use of spec decoding. Adding this to the spec decoding path is just filling a current functional gap/inconsistency imo, and whether/how to use it in performance benchmarking is a separate question.
Alright, this doesn't meet my quality bar, but I won't block y'all since it's off by default. LGTM.
@cadedaniel do you mean quality in terms of the specific code changes in this PR, or just the fact that overhead exists for requests that include a seed? The latter may be a valid concern but is a preexisting one, and I think orthogonal to spec/non-spec decoding?
> @cadedaniel do you mean quality in terms of the specific code changes in this PR, or just the fact that overhead exists for requests that include a seed? The latter may be a valid concern but is a preexisting one, and I think orthogonal to spec/non-spec decoding?
I mean that users may expect vLLM to give low latency for the features spec decode supports, but per-user seed as is in this PR will not give them that. I am OK to merge if we add a comment explaining the state, or even create an issue asking for someone to optimize it, e.g. as Antoni suggests.
That said thanks for adding what's here! Will be a good start to optimizing later.
Thanks @tdoublep!
Looks like Triton 3.0.0 will reduce kernel launch overhead, meaning that we can move towards Triton kernels for batched seeded random generation! triton-lang/triton#3503

We also already have this kernel in the vLLM code (https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/ops/rand.py), so it would be just a matter of using it to generate the random uniform tensor.
Fixes #6038