
[Bug]: speculative decoding dies: IndexError: index 0 is out of bounds for dimension 0 with size 0 #7047

Closed
pseudotensor opened this issue Aug 1, 2024 · 6 comments
Labels: bug (Something isn't working)

Comments

@pseudotensor

pseudotensor commented Aug 1, 2024

Your current environment

docker pull vllm/vllm-openai:latest
docker run -d --restart=always \
    --runtime=nvidia \
    --gpus '"device=1"' \
    --shm-size=10.24gb \
    -p 5001:5001 \
        -e NCCL_IGNORE_DISABLED_P2P=1 \
        -e VLLM_NCCL_SO_PATH=/usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib/libnccl.so.2 \
    -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
    -v /etc/passwd:/etc/passwd:ro \
    -v /etc/group:/etc/group:ro \
    -u `id -u`:`id -g` \
    -v "${HOME}"/.cache:$HOME/.cache/ -v "${HOME}"/.config:$HOME/.config/   -v "${HOME}"/.triton:$HOME/.triton/  \
    --network host \
    --name phi3mini \
    vllm/vllm-openai:latest \
        --port=5001 \
        --host=0.0.0.0 \
        --model=microsoft/Phi-3-mini-128k-instruct \
        --seed 1234 \
        --trust-remote-code \
        --tensor-parallel-size=1 \
        --max-num-batched-tokens=131072 --max-log-len=100 \
        --max-model-len=131072 \
        --max-num-seqs=17 \
        --use-v2-block-manager \
        --num-speculative-tokens=5 \
        --ngram-prompt-lookup-max=4 \
        --speculative-model="[ngram]" \
        --download-dir=$HOME/.cache/huggingface/hub &>> logs.vllm_server.phi3.txt

🐛 Describe the bug

ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 375, in execute_model
ERROR 08-01 21:27:03 async_llm_engine.py:56]     return self._run_speculative_decoding_step(execute_model_req,
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
ERROR 08-01 21:27:03 async_llm_engine.py:56]     return func(*args, **kwds)
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 538, in _run_speculative_decoding_step
ERROR 08-01 21:27:03 async_llm_engine.py:56]     accepted_token_ids, target_logprobs = self._verify_tokens(
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/lib/python3.10/contextlib.py", line 79, in inner
ERROR 08-01 21:27:03 async_llm_engine.py:56]     return func(*args, **kwds)
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/spec_decode/spec_decode_worker.py", line 609, in _verify_tokens
ERROR 08-01 21:27:03 async_llm_engine.py:56]     accepted_token_ids = self.spec_decode_sampler(
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
ERROR 08-01 21:27:03 async_llm_engine.py:56]     return self._call_impl(*args, **kwargs)
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
ERROR 08-01 21:27:03 async_llm_engine.py:56]     return forward_call(*args, **kwargs)
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/rejection_sampler.py", line 82, in forward
ERROR 08-01 21:27:03 async_llm_engine.py:56]     self._batch_modified_rejection_sampling(
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/rejection_sampler.py", line 119, in _batch_modified_rejection_sampling
ERROR 08-01 21:27:03 async_llm_engine.py:56]     accepted = self._get_accepted(target_probs, draft_probs,
ERROR 08-01 21:27:03 async_llm_engine.py:56]   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/rejection_sampler.py", line 190, in _get_accepted
ERROR 08-01 21:27:03 async_llm_engine.py:56]     uniform_rand[idx, :] = torch.rand(1,
ERROR 08-01 21:27:03 async_llm_engine.py:56] IndexError: index 0 is out of bounds for dimension 0 with size 0

With the very first message to the model, "Who are you?", I got back "I" and then it died.

pseudotensor added the bug label on Aug 1, 2024
@ShangmingCai
Contributor

Maybe you can change your speculative model, or set spec_decoding_acceptance_method to typical_acceptance_sampler. When using '[ngram]', there is a bug in the RejectionSampler source code: it cannot handle draft_probs with shape (0, k).
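
For concreteness, here is a minimal sketch (mine, not taken from the vLLM code) of that failure mode: assigning into row 0 of a tensor whose first dimension has size 0 raises exactly the IndexError shown in the traceback above.

# Sketch only: mimics rejection_sampler.py writing into uniform_rand[idx, :]
# when the batch of draft probabilities is empty, i.e. shaped (0, k).
python3 -c "import torch; t = torch.zeros(0, 5); t[0, :] = torch.rand(1, 5)"
# IndexError: index 0 is out of bounds for dimension 0 with size 0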

@ShangmingCai
Contributor

Is anyone fixing this bug?
cc @cadedaniel

@pseudotensor
Author

I'm happy to try other options. It was working well for someone else, but not for me with the phi-3-mini-128k model; it failed instantly. I'll probably wait until this bug is fixed before trying again.

The hope is a speed-up for structured output; others are getting quite good results, e.g. about a 5x improvement for guided_json/JSON output with a 7B model. Sounds great, but it just crashes for me.
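
For context, the kind of request I mean looks roughly like this (a sketch against the server started above; the guided_json field and the schema are illustrative, assuming the OpenAI-compatible endpoint accepts guided_json as an extra parameter):

# Hypothetical structured-output request; the schema is made up for illustration.
curl http://localhost:5001/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "microsoft/Phi-3-mini-128k-instruct",
        "messages": [{"role": "user", "content": "Who are you? Reply as JSON."}],
        "guided_json": {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}
      }'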

@ShangmingCai
Contributor

> I'm happy to try other options. It was working well for someone else, but not for me with the phi-3-mini-128k model; it failed instantly. I'll probably wait until this bug is fixed before trying again.
>
> The hope is a speed-up for structured output; others are getting quite good results, e.g. about a 5x improvement for guided_json/JSON output with a 7B model. Sounds great, but it just crashes for me.

Did you try adding --spec-decoding-acceptance-method='typical_acceptance_sampler'? That works for me to avoid the crash.
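
i.e., in the docker command from the report above, the tail of the server arguments would become something like this (a sketch; everything else unchanged):

        --speculative-model="[ngram]" \
        --spec-decoding-acceptance-method='typical_acceptance_sampler' \
        --download-dir=$HOME/.cache/huggingface/hub &>> logs.vllm_server.phi3.txt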

@ShangmingCai
Contributor

> I'm happy to try other options. It was working well for someone else, but not for me with the phi-3-mini-128k model; it failed instantly. I'll probably wait until this bug is fixed before trying again.
>
> The hope is a speed-up for structured output; others are getting quite good results, e.g. about a 5x improvement for guided_json/JSON output with a 7B model. Sounds great, but it just crashes for me.

FYI, you can build from the source code of the main branch. I guess the container you are using was built with vLLM v0.5.3 or v0.5.3.post1; #6698 has fixed this bug. Alternatively, you can wait for the release of v0.5.4, which should not hit this crash again.
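
For reference, one way to check which vLLM version the container actually ships (a sketch; assumes the image lets you override the entrypoint):

# Print the installed vLLM version inside the image.
docker run --rm --entrypoint python3 vllm/vllm-openai:latest \
    -c "import vllm; print(vllm.__version__)"

# Once a v0.5.4 tag is published, pulling the pinned image would look like:
docker pull vllm/vllm-openai:v0.5.4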

@pseudotensor
Author

0.5.4 seems to fix the issue.
