
[Bug]: speculative decoding doesn't work with online mode #6967

Closed
stas00 opened this issue Jul 31, 2024 · 17 comments · Fixed by #7232

Labels
bug Something isn't working

Comments

@stas00
Contributor

stas00 commented Jul 31, 2024

Your current environment

Collecting environment information...
PyTorch version: 2.3.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.30.1
Libc version: glibc-2.31

Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-1014-gcp-tcpx-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3

Nvidia driver version: 535.104.05
cuDNN version: Probably one of the following:
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.4
/usr/local/cuda-12.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.4
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      52 bits physical, 57 bits virtual
CPU(s):                             208
On-line CPU(s) list:                0-207
Thread(s) per core:                 2
Core(s) per socket:                 52
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              143
Model name:                         Intel(R) Xeon(R) Platinum 8481C CPU @ 2.70GHz
Stepping:                           8
CPU MHz:                            2699.998
BogoMIPS:                           5399.99
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          4.9 MiB
L1i cache:                          3.3 MiB
L2 cache:                           208 MiB
L3 cache:                           210 MiB
NUMA node0 CPU(s):                  0-51,104-155
NUMA node1 CPU(s):                  52-103,156-207
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Syscall hardening, KVM SW loop
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 arat avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid cldemote movdiri movdir64b fsrm md_clear serialize amx_bf16 avx512_fp16 amx_tile amx_int8 arch_capabilities

Versions of relevant libraries:
[pip3] flake8==7.1.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] onnxruntime==1.18.1
[pip3] sentence-transformers==3.0.1
[pip3] torch==2.3.1
[pip3] torchvision==0.18.1
[pip3] transformers==4.43.3
[pip3] triton==2.3.1
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] sentence-transformers     3.0.1                    pypi_0    pypi
[conda] torch                     2.3.1                    pypi_0    pypi
[conda] torchvision               0.18.1                   pypi_0    pypi
[conda] transformers              4.43.3                   pypi_0    pypi
[conda] triton                    2.3.1                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.3.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NV18    NV18    NV18    NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU1    NV18     X      NV18    NV18    NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU2    NV18    NV18     X      NV18    NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU3    NV18    NV18    NV18     X      NV18    NV18    NV18    NV18    0-51,104-155    0               N/A
GPU4    NV18    NV18    NV18    NV18     X      NV18    NV18    NV18    52-103,156-207  1               N/A
GPU5    NV18    NV18    NV18    NV18    NV18     X      NV18    NV18    52-103,156-207  1               N/A
GPU6    NV18    NV18    NV18    NV18    NV18    NV18     X      NV18    52-103,156-207  1               N/A
GPU7    NV18    NV18    NV18    NV18    NV18    NV18    NV18     X      52-103,156-207  1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

I'm able to successfully run your offline speculative example
https://docs.vllm.ai/en/stable/models/spec_decode.html#speculating-with-a-draft-model
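
For reference, the offline script I ran is essentially the documented draft-model example - roughly the following (paraphrased from that doc, so exact parameter names may differ slightly between vLLM versions):

from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Draft-model speculative decoding: opt-125m proposes 5 tokens per step,
# opt-6.7b verifies them.
llm = LLM(
    model="facebook/opt-6.7b",
    tensor_parallel_size=1,
    speculative_model="facebook/opt-125m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")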

I'm trying to make the same thing work with the online (OpenAI API server) approach, and it keeps crashing.

I mimic the offline setup with this server launch:

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 \
--model facebook/opt-6.7b --seed 42 -tp 1 --speculative_model facebook/opt-125m \
--use-v2-block-manager --num_speculative_tokens 5 --gpu_memory_utilization 0.8

If I use the client example from https://docs.vllm.ai/en/stable/getting_started/examples/openai_completion_client.html, the server fails to respond and the engine crashes:

INFO:     127.0.0.1:51790 - "GET /v1/models HTTP/1.1" 200 OK
INFO 07-31 02:29:51 logger.py:36] Received request cmpl-cb63769a92d148299b1b20e8d7f27086-0: prompt: 'A robot may not injure a human being', params: SamplingParams(n=2, best_of=2, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=1.0, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16, min_tokens=0, logprobs=3, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: [2, 250, 9916, 189, 45, 36841, 10, 1050, 145], lora_request: None, prompt_adapter_request: None.
INFO 07-31 02:29:51 async_llm_engine.py:173] Added request cmpl-cb63769a92d148299b1b20e8d7f27086-0.
ERROR 07-31 02:29:51 async_llm_engine.py:56] Engine background task failed
ERROR 07-31 02:29:51 async_llm_engine.py:56] Traceback (most recent call last):
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return_value = task.result()
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 637, in run_engine_loop
ERROR 07-31 02:29:51 async_llm_engine.py:56]     result = task.result()
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 580, in engine_step
ERROR 07-31 02:29:51 async_llm_engine.py:56]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 253, in step_async
ERROR 07-31 02:29:51 async_llm_engine.py:56]     output = await self.model_executor.execute_model_async(
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/executor/gpu_executor.py", line 159, in execute_model_async
ERROR 07-31 02:29:51 async_llm_engine.py:56]     output = await make_async(self.driver_worker.execute_model
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/concurrent/futures/thread.py", line 58, in run
ERROR 07-31 02:29:51 async_llm_engine.py:56]     result = self.fn(*self.args, **self.kwargs)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return func(*args, **kwargs)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/spec_decode/spec_decode_worker.py", line 373, in execute_model
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return self._run_no_spec(execute_model_req,
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/contextlib.py", line 79, in inner
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return func(*args, **kwds)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/spec_decode/spec_decode_worker.py", line 454, in _run_no_spec
ERROR 07-31 02:29:51 async_llm_engine.py:56]     self.proposer_worker.execute_model(execute_model_req)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 272, in execute_model
ERROR 07-31 02:29:51 async_llm_engine.py:56]     output = self.model_runner.execute_model(
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return func(*args, **kwargs)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/spec_decode/draft_model_runner.py", line 335, in execute_model
ERROR 07-31 02:29:51 async_llm_engine.py:56]     self.model.sample(
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/model_executor/models/opt.py", line 324, in sample
ERROR 07-31 02:29:51 async_llm_engine.py:56]     next_tokens = self.sampler(logits, sampling_metadata)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return self._call_impl(*args, **kwargs)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return forward_call(*args, **kwargs)
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 133, in forward
ERROR 07-31 02:29:51 async_llm_engine.py:56]     sample_results, maybe_sampled_tokens_tensor = _sample(
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 706, in _sample
ERROR 07-31 02:29:51 async_llm_engine.py:56]     return _sample_with_torch(
ERROR 07-31 02:29:51 async_llm_engine.py:56]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/model_executor/layers/sampler.py", line 571, in _sample_with_torch
ERROR 07-31 02:29:51 async_llm_engine.py:56]     sampled_token_ids_tensor[
ERROR 07-31 02:29:51 async_llm_engine.py:56] RuntimeError: shape mismatch: value tensor of shape [2] cannot be broadcast to indexing result of shape [1, 1]
Exception in callback functools.partial(<function _log_task_completion at 0x7f6e1d8e0d30>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7f6e04f6b430>>)
handle: <Handle functools.partial(<function _log_task_completion at 0x7f6e1d8e0d30>, error_callback=<bound method AsyncLLMEngine._error_callback of <vllm.engine.async_llm_engine.AsyncLLMEngine object at 0x7f6e04f6b430>>)>

Could you please share an example of online speculative decoding that works?

Thank you!

@stas00 added the bug label Jul 31, 2024
@stas00
Contributor Author

stas00 commented Aug 6, 2024

Hi @cadedaniel, would you by chance know how to overcome this, or whom I should tag about it? Thank you very much!

@cadedaniel
Collaborator

Hi @stas00. Can you share which attention backend your online serving setup is using? The way you're using this is expected to work.

@stas00
Contributor Author

stas00 commented Aug 6, 2024

Thank you for the follow up, @cadedaniel

I launch the server with just the defaults:

python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 \
--model facebook/opt-6.7b --seed 42 -tp 1 --speculative_model facebook/opt-125m \
--use-v2-block-manager --num_speculative_tokens 5 --gpu_memory_utilization 0.8

Do I need to specify a specific attention backend? The offline script doesn't specify one: https://docs.vllm.ai/en/stable/models/spec_decode.html#speculating-with-a-draft-model

@stas00
Contributor Author

stas00 commented Aug 6, 2024

The client was just:

$ cat client-speculate.py
from openai import OpenAI

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="A robot may not injure a human being",
    echo=False,
    n=1,
    stream=stream,
    logprobs=3)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)

@cadedaniel
Collaborator

Yes, although I suspect the issue is in which attention backend is used. Can you share the full log from when you run the server? It should say which attention backend is automatically selected.

@stas00
Contributor Author

stas00 commented Aug 6, 2024

The log doesn't mention any attention backend. Attached is the log - I have just updated to 0.5.4 (same problem).

vllm-log.txt

@cadedaniel
Collaborator

Interesting, that log has a different failure:

ERROR 08-06 20:02:22 async_llm_engine.py:57] Engine background task failed
ERROR 08-06 20:02:22 async_llm_engine.py:57] Traceback (most recent call last):
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 47, in _log_task_completion
ERROR 08-06 20:02:22 async_llm_engine.py:57]     return_value = task.result()
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 642, in run_engine_loop
ERROR 08-06 20:02:22 async_llm_engine.py:57]     result = task.result()
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 585, in engine_step
ERROR 08-06 20:02:22 async_llm_engine.py:57]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/async_llm_engine.py", line 259, in step_async
ERROR 08-06 20:02:22 async_llm_engine.py:57]     request_outputs = self._process_model_outputs(
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 831, in _process_model_outputs
ERROR 08-06 20:02:22 async_llm_engine.py:57]     self.output_processor.process_outputs(seq_group, outputs)
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/output_processor/multi_step.py", line 90, in process_outputs
ERROR 08-06 20:02:22 async_llm_engine.py:57]     self._process_seq_outputs(seq, valid_samples,
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/engine/output_processor/multi_step.py", line 131, in _process_seq_outputs
ERROR 08-06 20:02:22 async_llm_engine.py:57]     new_char_count = self.detokenizer.decode_sequence_inplace(
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/transformers_utils/detokenizer.py", line 150, in decode_sequence_inplace
ERROR 08-06 20:02:22 async_llm_engine.py:57]     (_, new_text, _, _) = detokenize_incrementally(
ERROR 08-06 20:02:22 async_llm_engine.py:57]   File "/env/lib/conda/stas-inference/lib/python3.10/site-packages/vllm/transformers_utils/detokenizer.py", line 287, in detokenize_incrementally
ERROR 08-06 20:02:22 async_llm_engine.py:57]     if new_token_id >= len(tokenizer):
ERROR 08-06 20:02:22 async_llm_engine.py:57] TypeError: '>=' not supported between instances of 'NoneType' and 'int'

@stas00
Contributor Author

stas00 commented Aug 6, 2024

Yes, I noticed that too - as I said, probably something changed in 0.5.4. I can re-run with 0.5.3.post1 if it helps, but I don't think it mentions the attention backend either. Should I activate some debug flag to get it printed?

@njhill
Member

njhill commented Aug 6, 2024

@stas00 I noticed that you are passing logprobs=3 - could you try without logprobs?

@tjohnson31415 discovered a regression related to this and is working on a fix right now.

@stas00
Contributor Author

stas00 commented Aug 6, 2024

@njhill, your suggestion did the trick! Thank you!

It works if I remove logprobs=3 - I had just copied the example you provided.

Should it not be set for speculative decoding? Could you perhaps add a check if it must not be passed? Though I suppose the client doesn't know the server configuration, so it'd be tricky to know when not to allow it. Perhaps the client could query the server's capabilities/setup? Some sort of debug mode?
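
In case it helps anyone else hitting this, the working variant is just the client script above with the logprobs argument dropped:

# same client as above, minus logprobs=3
completion = client.completions.create(
    model=model,
    prompt="A robot may not injure a human being",
    echo=False,
    n=1,
    stream=stream)

print("Completion results:")
print(completion)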

@njhill
Member

njhill commented Aug 6, 2024

> Should it not be set for speculative decoding? Could you perhaps add a check if it must not be passed? Though I suppose the client doesn't know the server configuration, so it'd be tricky to know when not to allow it. Perhaps the client could query the server's capabilities/setup? Some sort of debug mode?

It's a regression (@tjohnson31415 thinks from #6485). I think logprobs should work with spec decoding, but they just don't in the current version.

I think prompt logprobs might intentionally not be supported with spec decoding currently, and we should return an appropriate error message if that's the case (not sure whether that's currently done or not). It would be nice to support these too, though - maybe that can also be fixed.

@cadedaniel
Collaborator

Oh, good catch @njhill! Likely that is the top-level issue.

@tjohnson31415
Contributor

tjohnson31415 commented Aug 6, 2024

Hi all.
I am still doing some investigation and looking to add/update some unit tests, but I just pushed up a draft PR with the changes that I made to resolve the TypeError: '>=' not supported between instances of 'NoneType' and 'int':
https://github.com/vllm-project/vllm/pull/7232/files

Even with my fix, I can't get logprobs to be returned for generated tokens, so there may be an additional change needed for that.

@cadedaniel
Collaborator

From a latency perspective, calculating logprobs costs a lot. We may just file an issue for the regression and guide users to run the model without speculative decoding for logprob calculation until it is fixed.

@stas00
Contributor Author

stas00 commented Aug 7, 2024

FWIW, in my particular case I didn't actually need logprobs - it was just there in the OpenAI client example of yours. So you already unblocked me by telling me to remove it.

Should this doc be updated to show the online version of it as well?
https://docs.vllm.ai/en/stable/models/spec_decode.html#speculating-with-a-draft-model

@cadedaniel
Collaborator

Yeah, if you are able to update the doc to have an online version, @stas00, that would be amazing.

@stas00
Contributor Author

stas00 commented Aug 7, 2024

Added here: #7243
