
[Bugfix] Make spec. decode respect per-request seed. #6034

Open · wants to merge 8 commits into main
Conversation

@tdoublep (Contributor) commented Jul 1, 2024

Fixes #6038

@njhill (Member) left a comment

Nice :)

        uniform_rand[i,:] = torch.rand(1, k, dtype=self.probs_dtype, device=target_probs.device, generator=generator)
else:
    uniform_rand = torch.rand(batch_size,
@njhill (Member):

Preexisting code, but this could be simplified as torch.rand_like(selected_target_probs).

@tdoublep (Contributor Author):

done

Comment on lines 173 to 176:

if any(generator is not None for generator in generators):
    uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype, device=target_probs.device)
    for i, generator in enumerate(generators):
        uniform_rand[i,:] = torch.rand(1, k, dtype=self.probs_dtype, device=target_probs.device, generator=generator)
@njhill (Member) commented Jul 1, 2024:

I think it might be preferable here in the mixed case to do a single rand for the non-seeded requests, since it's likely that most requests don't have a seed. This can be done via something like

uniform_rand[non_seed_indices] = torch.rand(len(non_seed_indices), k, ...
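Spelled out, that mixed-case fill might look like the following sketch (standalone, with illustrative function and variable names rather than the PR's exact code):

import torch

def mixed_uniform_rand(batch_size, k, generators, dtype, device):
    """One batched rand for unseeded rows; per-generator rand for seeded rows."""
    uniform_rand = torch.empty(batch_size, k, dtype=dtype, device=device)
    non_seed_indices = [i for i, g in enumerate(generators) if g is None]
    if non_seed_indices:
        # A single batched draw covers every request without a seed.
        uniform_rand[non_seed_indices] = torch.rand(
            len(non_seed_indices), k, dtype=dtype, device=device)
    for i, generator in enumerate(generators):
        if generator is not None:
            # Seeded requests each draw from their own generator.
            uniform_rand[i, :] = torch.rand(
                1, k, dtype=dtype, device=device, generator=generator)
    return uniform_rand

Assignment through a list index (uniform_rand[non_seed_indices] = ...) writes back into the tensor, so the unseeded rows really do come from one batched draw.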

@tdoublep (Contributor Author):

I agree. Sort-of-related question: do you know whether we can safely assume that the length of each SequenceGroup is exactly 1 here? In the general sampler code there is some logic to handle the more general case. Maybe we need that here too?

@njhill (Member):

I'll double check; I think it's only >1 for beam search, which I believe is explicitly unsupported with spec decode right now anyhow. Though it might also be the case for parallel decoding (n>1, i.e. multiple output seqs per single input seq).

@tdoublep (Contributor Author):

I actually just tested that case (combining n>1 with spec. decode) and it seems to fail on main:

    |   File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 96, in forward
    |     sample_results, maybe_sampled_tokens_tensor = _sample(
    |                                                   ^^^^^^^^
    |   File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 658, in _sample
    |     return _sample_with_torch(
    |            ^^^^^^^^^^^^^^^^^^^
    |   File "/home/zrltpa/vllm/vllm/model_executor/layers/sampler.py", line 528, in _sample_with_torch
    |     sampled_token_ids_tensor[
    | RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]

If that case isn't currently supported by spec decode, we should at least return a reasonable error message.
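A hypothetical guard of that shape (the names here are assumptions for illustration, not vLLM's actual identifiers) could fail fast instead of surfacing the shape mismatch deep inside the sampler:

def check_spec_decode_supported(sampling_params, speculative_config):
    # Hypothetical early check: reject unsupported sampling params with a
    # clear message rather than a RuntimeError during sampling.
    if speculative_config is not None and sampling_params.n > 1:
        raise ValueError(
            "Speculative decoding does not currently support n>1 per request.")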

@tdoublep (Contributor Author) commented Jul 3, 2024:

Made the suggested change regarding uniform_rand; will create a separate PR to address the error message for n>1.

Comment on lines 277 to 279:

if any(generator is not None for generator in generators):
    for i, generator in enumerate(generators):
        q[i].exponential_(1.0, generator=generator)
@njhill (Member):

Similar comment to above about doing a single exponential_ for all non-seeded rows.
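As a standalone sketch of what that could look like (illustrative names and sizes, not the PR's code): draw exponentials once over the whole tensor, then overwrite only the seeded rows. An in-place exponential_ through a list index would not work here, since advanced indexing returns a copy.

import torch

batch_size, k = 8, 5
generators = [torch.Generator().manual_seed(i) if i % 2 == 0 else None
              for i in range(batch_size)]

q = torch.empty(batch_size, k)
# One batched draw covers all rows; unseeded rows keep these values.
# Note q[indices].exponential_() with a list of indices would be lost,
# because advanced indexing returns a copy rather than a view.
q.exponential_(1.0)
# q[i] with an integer index is a view, so the in-place fill writes through.
for i, g in enumerate(generators):
    if g is not None:
        q[i].exponential_(1.0, generator=g)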

@tdoublep (Contributor Author):

done

device=target_probs.device)

if any(generator is not None for generator in generators):
    uniform_rand = torch.empty(batch_size, k, dtype=self.probs_dtype, device=target_probs.device)
@njhill (Member):

could simplify as torch.empty_like(selected_target_probs)

@tdoublep (Contributor Author):

done

Comment on lines 297 to 305:

generator = torch.Generator(
    device=seq_group_metadata.state.generator.device
)
generator.set_state(
    seq_group_metadata.state.generator.get_state()
)
state = SequenceGroupState(
    generator=generator,
)
@njhill (Member):

Could the same Generator just be used here?

@tdoublep (Contributor Author):

My first attempt just shallow-copied the generator, but it led to incorrect results, so I think we need a "deep" copy here.
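For illustration, a standalone sketch of the difference (CPU generators here; get_state/set_state is the same snapshot pattern as the code quoted above):

import torch

g1 = torch.Generator()
g1.manual_seed(42)

# Shallow copy: both names alias one object, so drawing through the
# alias advances g1's state as well.
g_alias = g1
_ = torch.rand(3, generator=g_alias)

# "Deep" copy: a fresh Generator carrying a snapshot of g1's current
# state, so subsequent draws through the two are independent but identical.
g2 = torch.Generator(device=g1.device)
g2.set_state(g1.get_state())
assert torch.equal(torch.rand(3, generator=g2),
                   torch.rand(3, generator=g1))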

@tdoublep changed the title from "[WIP] Make spec. decode respect per-request seed." to "[Bugfix] Make spec. decode respect per-request seed." on Jul 1, 2024
@tdoublep marked this pull request as ready for review on Jul 1, 2024 at 17:12
@cadedaniel (Collaborator):

Can we add some e2e tests?

@cadedaniel (Collaborator):

Ideally unit tests in rejection sampler too!

q = torch.empty_like(probs).exponential_(1.0)

q = torch.empty_like(probs)
if any(generator is not None for generator in generators):
@cadedaniel (Collaborator):

Can we see the latency impact when there are multiple generators? Want to make sure rejection sampling is still reasonable

@tdoublep (Contributor Author):

Attached some analysis in comments below

@tdoublep (Contributor Author) commented Jul 3, 2024:

@cadedaniel I've been digging into the effect on latency at the level of the rejection sampler.

First, I wanted to check that if we never pass any generators, the performance is similar to main. I call the forward method of the rejection sampler 1000 times under different conditions on the main branch and the spec-decode-seed branch, and compare the difference (final column). It looks OK imo.

k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.34 ms, t_elap (new): 0.34 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.36 ms, diff: 0.00%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.36 ms, diff: 2.86%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.37 ms, diff: 5.71%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.38 ms, diff: 5.56%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.36 ms, t_elap (new): 0.37 ms, diff: 2.78%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.40 ms, diff: 8.11%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap (main): 0.35 ms, t_elap (new): 0.35 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.37 ms, diff: 0.00%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap (main): 0.37 ms, t_elap (new): 0.38 ms, diff: 2.70%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap (main): 0.38 ms, t_elap (new): 0.40 ms, diff: 5.26%

Next, I wanted to investigate how increasing the fraction of requests in the batch that are seeded affects the performance (frac_seeded = fraction of generators that are not None). These results are a bit more disappointing:

k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.34 ms
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.34 ms, diff:    -0.22%
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.34 ms, diff:    -0.89%
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.38 ms, diff:   +10.64%
k: 1, vocab_size: 30000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.38 ms, diff:   +11.24%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.36 ms
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.36 ms, diff:    +0.32%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.36 ms, diff:    +0.23%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.43 ms, diff:   +20.67%
k: 1, vocab_size: 30000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.47 ms, diff:   +31.69%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.36 ms
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.43 ms, diff:   +19.74%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.47 ms, diff:   +31.40%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.53 ms, diff:   +46.26%
k: 1, vocab_size: 30000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.54 ms, diff:   +50.61%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.37 ms
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.46 ms, diff:   +27.00%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.61 ms, diff:   +68.38%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.71 ms, diff:   +93.97%
k: 1, vocab_size: 30000, batch_size:  32, frac_seeded: 1.00, t_elap:   0.99 ms, diff:  +171.91%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.38 ms
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.76 ms, diff:  +101.38%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap:   0.95 ms, diff:  +150.41%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap:   1.66 ms, diff:  +338.59%
k: 1, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap:   3.09 ms, diff:  +718.98%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.38 ms
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.38 ms, diff:    -0.13%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.38 ms, diff:    -0.16%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.42 ms, diff:   +11.83%
k: 1, vocab_size: 50000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.42 ms, diff:   +11.69%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.40 ms
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.40 ms, diff:    -0.17%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.47 ms, diff:   +17.71%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.51 ms, diff:   +29.08%
k: 1, vocab_size: 50000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.51 ms, diff:   +28.90%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.39 ms
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.47 ms, diff:   +19.12%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.40 ms, diff:    +0.67%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.54 ms, diff:   +35.82%
k: 1, vocab_size: 50000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.60 ms, diff:   +51.98%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.40 ms
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.51 ms, diff:   +27.04%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.60 ms, diff:   +47.76%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.83 ms, diff:  +104.82%
k: 1, vocab_size: 50000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.11 ms, diff:  +175.25%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.42 ms
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.83 ms, diff:   +97.62%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap:   0.97 ms, diff:  +130.45%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap:   1.88 ms, diff:  +347.14%
k: 1, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap:   3.09 ms, diff:  +635.32%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    +0.17%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.39 ms, diff:    -0.01%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.45 ms, diff:   +13.64%
k: 3, vocab_size: 30000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.45 ms, diff:   +13.29%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.03%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.49 ms, diff:   +20.87%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.52 ms, diff:   +28.37%
k: 3, vocab_size: 30000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.56 ms, diff:   +37.89%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.15%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.53 ms, diff:   +28.73%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.58 ms, diff:   +42.66%
k: 3, vocab_size: 30000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.68 ms, diff:   +66.24%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.55 ms, diff:   +32.40%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.67 ms, diff:   +61.56%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.99 ms, diff:  +138.39%
k: 3, vocab_size: 30000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.40 ms, diff:  +237.93%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.44 ms
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.79 ms, diff:   +82.31%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.21 ms, diff:  +177.38%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap:   2.66 ms, diff:  +510.67%
k: 3, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap:   4.21 ms, diff:  +866.02%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    +0.07%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.39 ms, diff:    -0.03%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.45 ms, diff:   +14.26%
k: 3, vocab_size: 50000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.45 ms, diff:   +14.16%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    -0.04%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    -0.04%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.52 ms, diff:   +27.43%
k: 3, vocab_size: 50000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.56 ms, diff:   +37.24%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    -0.15%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.52 ms, diff:   +28.21%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.58 ms, diff:   +42.51%
k: 3, vocab_size: 50000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.68 ms, diff:   +66.48%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.42 ms
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.63 ms, diff:   +52.15%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.60 ms, diff:   +44.58%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 0.50, t_elap:   0.95 ms, diff:  +128.86%
k: 3, vocab_size: 50000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.39 ms, diff:  +235.73%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.43 ms
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap:   0.85 ms, diff:   +97.39%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.26 ms, diff:  +190.45%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap:   2.33 ms, diff:  +437.52%
k: 3, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap:   4.21 ms, diff:  +871.77%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    +0.16%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.46 ms, diff:   +15.72%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.46 ms, diff:   +15.84%
k: 5, vocab_size: 30000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.46 ms, diff:   +15.70%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.08%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    +0.10%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.59 ms, diff:   +43.95%
k: 5, vocab_size: 30000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.61 ms, diff:   +48.61%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.08%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.51 ms, diff:   +24.68%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.67 ms, diff:   +62.47%
k: 5, vocab_size: 30000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.76 ms, diff:   +84.95%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.42 ms
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.72 ms, diff:   +72.41%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.68 ms, diff:   +62.23%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 0.50, t_elap:   1.18 ms, diff:  +180.81%
k: 5, vocab_size: 30000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.68 ms, diff:  +300.72%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.44 ms
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.10, t_elap:   1.15 ms, diff:  +161.95%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.46 ms, diff:  +230.81%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 0.50, t_elap:   3.44 ms, diff:  +680.86%
k: 5, vocab_size: 30000, batch_size: 128, frac_seeded: 1.00, t_elap:   5.30 ms, diff: +1104.57%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.00, t_elap:   0.39 ms
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.10, t_elap:   0.39 ms, diff:    -0.10%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.20, t_elap:   0.39 ms, diff:    +0.07%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 0.50, t_elap:   0.39 ms, diff:    -0.18%
k: 5, vocab_size: 50000, batch_size:   1, frac_seeded: 1.00, t_elap:   0.46 ms, diff:   +15.94%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.10, t_elap:   0.54 ms, diff:   +32.45%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    -0.24%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 0.50, t_elap:   0.58 ms, diff:   +42.51%
k: 5, vocab_size: 50000, batch_size:   4, frac_seeded: 1.00, t_elap:   0.60 ms, diff:   +46.50%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.00, t_elap:   0.41 ms
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.10, t_elap:   0.41 ms, diff:    +0.41%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.20, t_elap:   0.41 ms, diff:    +0.15%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 0.50, t_elap:   0.62 ms, diff:   +52.06%
k: 5, vocab_size: 50000, batch_size:   8, frac_seeded: 1.00, t_elap:   0.75 ms, diff:   +84.19%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.00, t_elap:   0.42 ms
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.10, t_elap:   0.53 ms, diff:   +25.95%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.20, t_elap:   0.72 ms, diff:   +71.41%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 0.50, t_elap:   1.02 ms, diff:  +142.86%
k: 5, vocab_size: 50000, batch_size:  32, frac_seeded: 1.00, t_elap:   1.68 ms, diff:  +300.73%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.00, t_elap:   0.45 ms
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.10, t_elap:   1.12 ms, diff:  +151.11%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.20, t_elap:   1.64 ms, diff:  +265.30%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 0.50, t_elap:   3.03 ms, diff:  +575.36%
k: 5, vocab_size: 50000, batch_size: 128, frac_seeded: 1.00, t_elap:   5.32 ms, diff: +1087.33%

In particular, when the batch size is large, there can be up to a few ms of added overhead. I suspect this may not be so significant in E2E measurements (I will do some now). I did some investigation into where the overhead is coming from, and it is really due to the random number generation itself rather than any auxiliary work. It seems that if we want sampling to be reproducible, there will be a ms-level performance penalty for large batch sizes.
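For context, a standalone microbenchmark of roughly this shape isolates the per-row generator cost (illustrative only, not the exact harness used above; it times just the uniform-noise fill on CPU, and GPU timing would additionally need torch.cuda.synchronize()):

import time
import torch

def time_fill(batch_size, k, frac_seeded, iters=1000):
    """Average time to fill the uniform noise tensor, mixed seeded/unseeded."""
    n_seeded = int(batch_size * frac_seeded)
    generators = [torch.Generator() if i < n_seeded else None
                  for i in range(batch_size)]
    start = time.perf_counter()
    for _ in range(iters):
        if any(g is not None for g in generators):
            uniform_rand = torch.empty(batch_size, k)
            non_seeded = [i for i, g in enumerate(generators) if g is None]
            if non_seeded:
                uniform_rand[non_seeded] = torch.rand(len(non_seeded), k)
            for i, g in enumerate(generators):
                if g is not None:
                    # The per-row draws dominate at large batch sizes.
                    uniform_rand[i, :] = torch.rand(1, k, generator=g)
        else:
            uniform_rand = torch.rand(batch_size, k)
    return (time.perf_counter() - start) / iters

for frac in (0.0, 0.1, 0.5, 1.0):
    print(f"frac_seeded={frac:.2f}: {time_fill(128, 5, frac) * 1e3:.3f} ms")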

@tdoublep (Contributor Author) commented Jul 3, 2024:

BTW I added both E2E tests and unit tests for the rejection sampler.

@njhill (Member) commented Jul 3, 2024:

It's probably also worth mentioning that we would likely see similar overhead when increasing the percentage of seeded requests in a large batch even in the non-spec-decode case, since there's a per-request loop there too.

@tdoublep (Contributor Author) commented Jul 3, 2024:

@njhill yes, that's a good point


Successfully merging this pull request may close these issues.

[Bug]: Speculative decoding does not respect per-request seed
3 participants