[Bugfix] Make spec. decode respect per-request seed. #6034
base: main
Conversation
Signed-off-by: Thomas Parnell <[email protected]>
Nice :)
uniform_rand[i,:] = torch.rand(1, k, dtype=self.probs_dtype, device=target_probs.device, generator=generator)
else:
    uniform_rand = torch.rand(batch_size,
Pre-existing code, but this could be simplified as torch.rand_like(selected_target_probs).
done
if any(generator is not None for generator in generators):
    uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype, device=target_probs.device)
    for i, generator in enumerate(generators):
        uniform_rand[i,:] = torch.rand(1, k, dtype=self.probs_dtype, device=target_probs.device, generator=generator)
I think it might be preferable here in the mixed case to do a single rand for the non-seeded requests, since it is likely that most requests don't have a seed. This can be done via something like:
uniform_rand[non_seed_indices] = torch.rand(len(non_seed_indices), k, ...
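A minimal sketch of the suggested mixed-case fill, assuming a list of per-request generators where None means "not seeded": all non-seeded rows are drawn in one vectorized call, and only seeded rows loop. The names (batch_size, k, generators, uniform_rand) mirror the diff, but the values here are illustrative, not the PR's code.

```python
import torch

batch_size, k = 4, 3
# Rows 1 and 3 are seeded; the rest use the default global generator.
generators = [None, torch.Generator().manual_seed(42), None,
              torch.Generator().manual_seed(7)]

uniform_rand = torch.empty(batch_size, k)

non_seed_indices = [i for i, g in enumerate(generators) if g is None]
seed_indices = [i for i, g in enumerate(generators) if g is not None]

# One vectorized call covers every non-seeded row.
uniform_rand[non_seed_indices] = torch.rand(len(non_seed_indices), k)

# Seeded rows still need a per-row call so each consumes its own generator.
for i in seed_indices:
    uniform_rand[i, :] = torch.rand(1, k, generator=generators[i])
```

Rerunning with the same seeds reproduces the seeded rows exactly, while the non-seeded rows vary run to run.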
I agree. Sort-of-related question: do you know whether we can safely assume that the length of each SequenceGroup is exactly 1 here? In the general sampler code there is some logic to handle the more general case. Maybe we need that here too?
Will double check. I think it's only >1 for beam search, which I think is explicitly unsupported with spec decode right now anyhow. Though it might also be the case for parallel decoding (n>1, i.e. multiple output seqs per single input seq).
I actually just tested that case: combining n>1 with spec. decode, and it seems to fail on main:
File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 96, in forward
  sample_results, maybe_sampled_tokens_tensor = _sample(
  ^^^^^^^^
File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 658, in _sample
  return _sample_with_torch(
  ^^^^^^^^^^^^^^^^^^^
File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 528, in _sample_with_torch
  sampled_token_ids_tensor[
RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]
If that case isn't supported by spec decode currently we should at least give a reasonable message back.
Made the suggested change regarding uniform_rand; will create a separate PR to address the error message for n>1.
if any(generator is not None for generator in generators):
    for i, generator in enumerate(generators):
        q[i].exponential_(1.0, generator=generator)
Similar comment to above about doing a single exponential_ for all non-seeded rows.
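Applied to the exponential draw, the same idea looks roughly like the following. This is a sketch under the same assumption (generators is a per-row list with None for unseeded rows); the probs values are placeholders, not the sampler's real inputs.

```python
import torch

probs = torch.rand(4, 5)  # stand-in for the probs tensor in the sampler
generators = [None, torch.Generator().manual_seed(0), None, None]

q = torch.empty_like(probs)
non_seed = [i for i, g in enumerate(generators) if g is None]

# Fill all non-seeded rows with a single exponential_ call...
q[non_seed] = torch.empty(len(non_seed), probs.shape[1]).exponential_(1.0)

# ...and seeded rows individually, so each row uses its own generator state.
for i, g in enumerate(generators):
    if g is not None:
        q[i].exponential_(1.0, generator=g)
```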
done
device=target_probs.device)

if any(generator is not None for generator in generators):
    uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype, device=target_probs.device)
Could simplify as torch.empty_like(selected_target_probs).
done
vllm/spec_decode/batch_expansion.py
generator = torch.Generator(
    device=seq_group_metadata.state.generator.device
)
generator.set_state(
    seq_group_metadata.state.generator.get_state()
)
state = SequenceGroupState(
    generator=generator,
)
Could the same Generator just be used here?
My first attempt just shallow copied the generator but it led to incorrect results, so I think we need a "deep" copy here.
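The difference between the two approaches can be demonstrated in isolation. A shallow copy shares one generator object, so draws on the copy advance the original's state; copying via get_state/set_state gives an independent generator. This is a standalone illustration, not the PR's batch_expansion code.

```python
import torch

original = torch.Generator().manual_seed(123)

# "Deep" copy: a new Generator seeded with the original's exact state.
copy = torch.Generator(device="cpu")
copy.set_state(original.get_state())

a = torch.rand(3, generator=copy)      # advances only the copy
b = torch.rand(3, generator=original)  # original still starts from seed state

# Both draws start from identical state, so they produce the same values.
assert torch.equal(a, b)
```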
Can we add some e2e tests?
Ideally unit tests in rejection sampler too!
q = torch.empty_like(probs).exponential_(1.0)

q = torch.empty_like(probs)
if any(generator is not None for generator in generators):
Can we see the latency impact when there are multiple generators? Want to make sure rejection sampling is still reasonable
Attached some analysis in comments below
@cadedaniel I've been digging into the effect on latency at the level of the rejection sampler. Firstly, I wanted to check that if we never pass any generators, the performance is similar to main. I call the forward method of the rejection sampler 1000 times under different conditions:

k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.37 ms, diff: 5.71%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.38 ms, t_elap (new): 0.40 ms, diff: 5.26% Next, I wanted to investigate how increasing the fraction of requests in the batch that are seeded affects the performance ( k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.34 ms
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.34 ms, diff: -0.22%
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.34 ms, diff: -0.89%
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.38 ms, diff: +10.64%
k: 1, vocab_size: 30000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.38 ms, diff: +11.24%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.36 ms
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.36 ms, diff: +0.32%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.36 ms, diff: +0.23%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.43 ms, diff: +20.67%
k: 1, vocab_size: 30000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.47 ms, diff: +31.69%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.36 ms
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.43 ms, diff: +19.74%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.47 ms, diff: +31.40%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.53 ms, diff: +46.26%
k: 1, vocab_size: 30000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.54 ms, diff: +50.61%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.37 ms
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.46 ms, diff: +27.00%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.61 ms, diff: +68.38%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.71 ms, diff: +93.97%
k: 1, vocab_size: 30000, batch_size: 32, frac_seeded: 1.00, t_elap: 0.99 ms, diff: +171.91%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.38 ms
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.76 ms, diff: +101.38%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap: 0.95 ms, diff: +150.41%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap: 1.66 ms, diff: +338.59%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap: 3.09 ms, diff: +718.98%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.38 ms
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.38 ms, diff: -0.13%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.38 ms, diff: -0.16%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.42 ms, diff: +11.83%
k: 1, vocab_size: 50000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.42 ms, diff: +11.69%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.40 ms
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.40 ms, diff: -0.17%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.47 ms, diff: +17.71%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.51 ms, diff: +29.08%
k: 1, vocab_size: 50000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.51 ms, diff: +28.90%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.39 ms
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.47 ms, diff: +19.12%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.40 ms, diff: +0.67%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.54 ms, diff: +35.82%
k: 1, vocab_size: 50000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.60 ms, diff: +51.98%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.40 ms
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.51 ms, diff: +27.04%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.60 ms, diff: +47.76%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.83 ms, diff: +104.82%
k: 1, vocab_size: 50000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.11 ms, diff: +175.25%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.42 ms
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.83 ms, diff: +97.62%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap: 0.97 ms, diff: +130.45%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap: 1.88 ms, diff: +347.14%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap: 3.09 ms, diff: +635.32%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: +0.17%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.39 ms, diff: -0.01%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.45 ms, diff: +13.64%
k: 3, vocab_size: 30000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.45 ms, diff: +13.29%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.03%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.49 ms, diff: +20.87%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.52 ms, diff: +28.37%
k: 3, vocab_size: 30000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.56 ms, diff: +37.89%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.15%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.53 ms, diff: +28.73%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.58 ms, diff: +42.66%
k: 3, vocab_size: 30000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.68 ms, diff: +66.24%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.55 ms, diff: +32.40%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.67 ms, diff: +61.56%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.99 ms, diff: +138.39%
k: 3, vocab_size: 30000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.40 ms, diff: +237.93%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.44 ms
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.79 ms, diff: +82.31%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.21 ms, diff: +177.38%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap: 2.66 ms, diff: +510.67%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap: 4.21 ms, diff: +866.02%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: +0.07%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.39 ms, diff: -0.03%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.45 ms, diff: +14.26%
k: 3, vocab_size: 50000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.45 ms, diff: +14.16%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.41 ms, diff: -0.04%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.41 ms, diff: -0.04%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.52 ms, diff: +27.43%
k: 3, vocab_size: 50000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.56 ms, diff: +37.24%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: -0.15%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.52 ms, diff: +28.21%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.58 ms, diff: +42.51%
k: 3, vocab_size: 50000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.68 ms, diff: +66.48%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.42 ms
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.63 ms, diff: +52.15%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.60 ms, diff: +44.58%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 0.50, t_elap: 0.95 ms, diff: +128.86%
k: 3, vocab_size: 50000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.39 ms, diff: +235.73%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.43 ms
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap: 0.85 ms, diff: +97.39%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.26 ms, diff: +190.45%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap: 2.33 ms, diff: +437.52%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap: 4.21 ms, diff: +871.77%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: +0.16%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.46 ms, diff: +15.72%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.46 ms, diff: +15.84%
k: 5, vocab_size: 30000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.46 ms, diff: +15.70%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.08%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.41 ms, diff: +0.10%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.59 ms, diff: +43.95%
k: 5, vocab_size: 30000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.61 ms, diff: +48.61%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.08%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.51 ms, diff: +24.68%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.67 ms, diff: +62.47%
k: 5, vocab_size: 30000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.76 ms, diff: +84.95%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.42 ms
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.72 ms, diff: +72.41%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.68 ms, diff: +62.23%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 0.50, t_elap: 1.18 ms, diff: +180.81%
k: 5, vocab_size: 30000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.68 ms, diff: +300.72%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.44 ms
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap: 1.15 ms, diff: +161.95%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.46 ms, diff: +230.81%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap: 3.44 ms, diff: +680.86%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap: 5.30 ms, diff: +1104.57%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.00, t_elap: 0.39 ms
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.10, t_elap: 0.39 ms, diff: -0.10%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.20, t_elap: 0.39 ms, diff: +0.07%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 0.50, t_elap: 0.39 ms, diff: -0.18%
k: 5, vocab_size: 50000, batch_size: 1, frac_seeded: 1.00, t_elap: 0.46 ms, diff: +15.94%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.10, t_elap: 0.54 ms, diff: +32.45%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.20, t_elap: 0.41 ms, diff: -0.24%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 0.50, t_elap: 0.58 ms, diff: +42.51%
k: 5, vocab_size: 50000, batch_size: 4, frac_seeded: 1.00, t_elap: 0.60 ms, diff: +46.50%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.00, t_elap: 0.41 ms
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.10, t_elap: 0.41 ms, diff: +0.41%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.20, t_elap: 0.41 ms, diff: +0.15%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 0.50, t_elap: 0.62 ms, diff: +52.06%
k: 5, vocab_size: 50000, batch_size: 8, frac_seeded: 1.00, t_elap: 0.75 ms, diff: +84.19%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.00, t_elap: 0.42 ms
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.10, t_elap: 0.53 ms, diff: +25.95%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.20, t_elap: 0.72 ms, diff: +71.41%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 0.50, t_elap: 1.02 ms, diff: +142.86%
k: 5, vocab_size: 50000, batch_size: 32, frac_seeded: 1.00, t_elap: 1.68 ms, diff: +300.73%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap: 0.45 ms
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap: 1.12 ms, diff: +151.11%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap: 1.64 ms, diff: +265.30%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap: 3.03 ms, diff: +575.36%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap: 5.32 ms, diff: +1087.33%

In particular, when the batch size is large, there can be up to a few ms of overhead added. I guess this may not be so significant in E2E measurements (I will do some now). I did some investigation into where the overhead is coming from, and it is really due to the random number generation part rather than any auxiliary stuff. It seems that if we want to make it reproducible, there will be some ms-level performance penalty for large batch sizes.
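For reference, the kind of microbenchmark described above can be sketched as follows: time only the random-number-generation path for a given fraction of seeded rows. The function name and parameters are hypothetical; they roughly mimic the measurement, not reproduce the exact harness used for the numbers above.

```python
import time
import torch

def time_uniform_rand(batch_size, k, frac_seeded, iters=100):
    """Time the uniform-draw path with a given fraction of seeded rows.

    Returns the mean wall-clock time per call in milliseconds.
    """
    n_seeded = int(frac_seeded * batch_size)
    generators = [torch.Generator().manual_seed(i) for i in range(n_seeded)]
    generators += [None] * (batch_size - n_seeded)

    start = time.perf_counter()
    for _ in range(iters):
        u = torch.empty(batch_size, k)
        non_seed = [i for i, g in enumerate(generators) if g is None]
        if non_seed:
            # One vectorized call for all non-seeded rows.
            u[non_seed] = torch.rand(len(non_seed), k)
        for i, g in enumerate(generators):
            if g is not None:
                # Per-row call for each seeded request.
                u[i, :] = torch.rand(1, k, generator=g)
    return (time.perf_counter() - start) / iters * 1e3
```

Sweeping frac_seeded from 0.0 to 1.0 at a fixed batch size shows how the per-row loop over seeded requests dominates as more of the batch is seeded.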
BTW I added both E2E tests, as well as unit tests for the rejection sampler.
It is probably also worth mentioning, though, that we would probably see something similar when increasing the percentage of random-seeded requests in a large batch even in the non-spec-decoding case, since there's also a loop for that.
@njhill yes, that's a good point |
Fixes #6038