
[Bugfix] Fix illegal memory access error with chunked prefill, prefix caching, block manager v2 and xformers enabled together #9532

Merged
merged 6 commits, Oct 31, 2024

Conversation

Contributor

@sasha0552 sasha0552 commented Oct 20, 2024

Deleted

This PR (in its current form) is intended simply to reproduce, on CI (which has a supported GPU), a crash I've observed locally on a P40 with a particular request sequence, so don't merge it.

Eventually (regardless of reproducibility on other hardware) I hope to find the cause and fix the crash, although this may take a long time since I'm not familiar with these parts of vLLM, so any help is welcome.

For certain sequences of requests (you can find one in temp.py to reproduce manually; it works through the OpenAI-compatible API), vLLM with

  1. Chunked prefill
  2. Prefix caching
  3. Block manager V2
  4. XFormers

enabled together crashes somewhere with "CUDA error: an illegal memory access was encountered" (sometimes, for example, in the prefix caching kernel).
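
For reference, a server launch roughly along these lines triggers it on my setup. This is a sketch: the flags mirror my local reproduction, and on Pascal the xformers backend is selected automatically, so the environment variable only matters if you want to force it elsewhere.

VLLM_ATTENTION_BACKEND=XFORMERS vllm serve mistralai/Ministral-8B-Instruct-2410 \
    --tokenizer-mode mistral --load-format mistral --config-format mistral \
    --max-model-len 4096 --enforce-eager \
    --enable-chunked-prefill --enable-prefix-caching --use-v2-block-manager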

I don't believe this is a problem with Triton for Pascal (or with my hardware in general), as I have found similar issues reporting crashes with prefix caching on other hardware.

The crash looks like this...
(VllmWorkerProcess pid=14578) WARNING 10-17 12:17:31 model_runner_base.py:143] Failed to pickle inputs of failed execution: CUDA error: an illegal memory access was encountered
(VllmWorkerProcess pid=14578) WARNING 10-17 12:17:31 model_runner_base.py:143] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(VllmWorkerProcess pid=14578) WARNING 10-17 12:17:31 model_runner_base.py:143]
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231] Exception in worker VllmWorkerProcess while processing method execute_model: Error in model execution: Triton Error [CUDA]: an illegal memory access was encountered, Traceback (most recent call last):
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return func(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1665, in execute_model
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     hidden_or_intermediate_states = model_executable(
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]                                     ^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/model_executor/models/llama.py", line 556, in forward
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     model_output = self.model(input_ids, positions, kv_caches,
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/model_executor/models/llama.py", line 345, in forward
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     hidden_states, residual = layer(positions, hidden_states,
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/model_executor/models/llama.py", line 257, in forward
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     hidden_states = self.self_attn(positions=positions,
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/model_executor/models/llama.py", line 187, in forward
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/attention/layer.py", line 100, in forward
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return self.impl.forward(query,
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/attention/backends/xformers.py", line 620, in forward
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     out = PagedAttention.forward_prefix(
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/attention/ops/paged_attn.py", line 211, in forward_prefix
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     context_attention_fwd(
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return func(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/attention/ops/prefix_prefill.py", line 811, in context_attention_fwd
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     _fwd_kernel[grid](
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/triton/runtime/jit.py", line 691, in run
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     self.launch(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231] RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231] The above exception was the direct cause of the following exception:
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231] Traceback (most recent call last):
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 224, in _run_worker_process
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]              ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/worker_base.py", line 327, in execute_model
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     return func(*args, **kwargs)
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/model_runner_base.py", line 146, in _wrapper
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]     raise type(err)(f"Error in model execution: "
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231] RuntimeError: Error in model execution: Triton Error [CUDA]: an illegal memory access was encountered
(VllmWorkerProcess pid=14578) ERROR 10-17 12:17:31 multiproc_worker_utils.py:231]
... or this ...
ERROR 10-17 12:43:16 engine.py:160] RuntimeError('CUDA error: an illegal memory access was encountered\nCUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1\nCompile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n')
ERROR 10-17 12:43:16 engine.py:160] Traceback (most recent call last):
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py", line 158, in start
ERROR 10-17 12:43:16 engine.py:160]     self.run_engine_loop()
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py", line 221, in run_engine_loop
ERROR 10-17 12:43:16 engine.py:160]     request_outputs = self.engine_step()
ERROR 10-17 12:43:16 engine.py:160]                       ^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py", line 239, in engine_step
ERROR 10-17 12:43:16 engine.py:160]     raise e
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/engine/multiprocessing/engine.py", line 230, in engine_step
ERROR 10-17 12:43:16 engine.py:160]     return self.engine.step()
ERROR 10-17 12:43:16 engine.py:160]            ^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/engine/llm_engine.py", line 1386, in step
ERROR 10-17 12:43:16 engine.py:160]     outputs = self.model_executor.execute_model(
ERROR 10-17 12:43:16 engine.py:160]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/executor/gpu_executor.py", line 134, in execute_model
ERROR 10-17 12:43:16 engine.py:160]     output = self.driver_worker.execute_model(execute_model_req)
ERROR 10-17 12:43:16 engine.py:160]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/worker_base.py", line 303, in execute_model
ERROR 10-17 12:43:16 engine.py:160]     inputs = self.prepare_input(execute_model_req)
ERROR 10-17 12:43:16 engine.py:160]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/worker_base.py", line 291, in prepare_input
ERROR 10-17 12:43:16 engine.py:160]     return self._get_driver_input_and_broadcast(execute_model_req)
ERROR 10-17 12:43:16 engine.py:160]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/worker_base.py", line 253, in _get_driver_input_and_broadcast
ERROR 10-17 12:43:16 engine.py:160]     self.model_runner.prepare_model_input(
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1593, in prepare_model_input
ERROR 10-17 12:43:16 engine.py:160]     model_input = self._prepare_model_input_tensors(
ERROR 10-17 12:43:16 engine.py:160]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1200, in _prepare_model_input_tensors
ERROR 10-17 12:43:16 engine.py:160]     return builder.build()  # type: ignore
ERROR 10-17 12:43:16 engine.py:160]            ^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 871, in build
ERROR 10-17 12:43:16 engine.py:160]     attn_metadata = self.attn_metadata_builder.build(
ERROR 10-17 12:43:16 engine.py:160]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/attention/backends/utils.py", line 227, in build
ERROR 10-17 12:43:16 engine.py:160]     block_tables = make_tensor_with_pad(
ERROR 10-17 12:43:16 engine.py:160]                    ^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160]   File "/home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/vllm/utils.py", line 857, in make_tensor_with_pad
ERROR 10-17 12:43:16 engine.py:160]     tensor = torch.from_numpy(padded_x).to(device)
ERROR 10-17 12:43:16 engine.py:160]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 10-17 12:43:16 engine.py:160] RuntimeError: CUDA error: an illegal memory access was encountered
ERROR 10-17 12:43:16 engine.py:160] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
ERROR 10-17 12:43:16 engine.py:160] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
ERROR 10-17 12:43:16 engine.py:160] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
ERROR 10-17 12:43:16 engine.py:160]
ERROR:    Exception in ASGI application

... or as in the similar issues mentioned above.

Based on this (and the test I wrote), I assume that something in block manager V2 is corrupting CUDA memory.

| block manager | chunked prefill | prefix caching | works |
|---------------|-----------------|----------------|-------|
| v1            | -               | -              | +     |
| v2            | -               | -              | +     |
| v1            | +               | -              | +     |
| v2            | +               | -              | +     |
| v1            | -               | +              | +     |
| v2            | -               | +              | +     |
| v1            | +               | +              | +     |
| v2            | +               | +              | -     |
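
For reference, a matrix like this can be driven by a parametrized test roughly along the following lines. This is a hypothetical sketch rather than the actual repro_fin.py (which, judging by the test IDs below, uses a different set of parameters); it reuses the anonymized token sequence from the repro script later in this thread.

import gc

import pytest
import torch

from vllm import LLM, SamplingParams, TokensPrompt

# Anonymized prompts: all five share a 588-token prefix; the first four also
# share the next 1332 tokens.
PROMPTS = [
    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1),
    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50),
    ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95),
    ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174),
    ([0] * 588) + ([8] * 1539),
]


@pytest.mark.parametrize("use_v2_block_manager", [False, True])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("enable_prefix_caching", [False, True])
def test_vllm(use_v2_block_manager, enable_chunked_prefill, enable_prefix_caching):
    llm = LLM(
        model="mistralai/Ministral-8B-Instruct-2410",
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        dtype="float16",
        max_model_len=4096,
        enforce_eager=True,
        swap_space=0,
        enable_chunked_prefill=enable_chunked_prefill,
        enable_prefix_caching=enable_prefix_caching,
        use_v2_block_manager=use_v2_block_manager,
    )
    # The crash (if any) surfaces while generating over this request sequence.
    for prompt_token_ids in PROMPTS:
        llm.generate(TokensPrompt(prompt_token_ids=prompt_token_ids),
                     SamplingParams(max_tokens=1, seed=42))
    # Rough cleanup so the next parametrization can re-initialize the engine.
    del llm
    gc.collect()
    torch.cuda.empty_cache()
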
Test output on P40
tori@torilinux ~ % HF_HOME=/mnt/hf HF_HUB_OFFLINE=1 pytest -s repro_fin.py
================================================================================ test session starts =================================================================================
platform linux -- Python 3.11.8, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/tori
plugins: anyio-4.6.2.post1
collected 8 items

repro_fin.py INFO 10-20 04:09:12 config.py:1670] Downcasting torch.float32 to torch.float16.
WARNING 10-20 04:09:16 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:09:16 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=False, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:09:17 selector.py:224] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 10-20 04:09:17 selector.py:115] Using XFormers backend.
INFO 10-20 04:09:17 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:09:18 selector.py:224] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 10-20 04:09:18 selector.py:115] Using XFormers backend.
INFO 10-20 04:09:18 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:09:18 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.35s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.35s/it]

INFO 10-20 04:11:33 model_runner.py:1071] Loading model weights took 14.9693 GB
INFO 10-20 04:11:43 gpu_executor.py:122] # GPU blocks: 2042, # CPU blocks: 0
INFO 10-20 04:11:43 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 7.98x
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:08<00:00,  8.54s/it, est. speed input: 228.38 toks/s, output: 0.12 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.18s/it, est. speed input: 214.98 toks/s, output: 0.11 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:09<00:00,  9.42s/it, est. speed input: 217.21 toks/s, output: 0.11 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.09s/it, est. speed input: 207.83 toks/s, output: 0.10 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.18s/it, est. speed input: 208.94 toks/s, output: 0.10 toks/s]
.INFO 10-20 04:12:31 config.py:1670] Downcasting torch.float32 to torch.float16.
WARNING 10-20 04:12:31 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:12:31 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:12:32 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:12:33 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:12:33 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.66s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.66s/it]

INFO 10-20 04:14:47 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:14:57 gpu_executor.py:122] # GPU blocks: 2105, # CPU blocks: 0
INFO 10-20 04:14:57 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 8.22x
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:25<00:00, 25.69s/it, est. speed input: 75.95 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00, 26.14s/it, est. speed input: 75.46 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:29<00:00, 29.42s/it, est. speed input: 69.51 toks/s, output: 0.03 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:28<00:00, 28.05s/it, est. speed input: 74.75 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:30<00:00, 30.72s/it, est. speed input: 69.23 toks/s, output: 0.03 toks/s]
.INFO 10-20 04:17:17 config.py:1670] Downcasting torch.float32 to torch.float16.
INFO 10-20 04:17:17 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=512.
WARNING 10-20 04:17:17 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:17:17 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=False, num_scheduler_steps=1, chunked_prefill_enabled=True multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:17:18 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:17:18 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:17:18 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.61s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.62s/it]

INFO 10-20 04:19:33 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:19:36 gpu_executor.py:122] # GPU blocks: 2454, # CPU blocks: 0
INFO 10-20 04:19:36 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 9.59x
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:23<00:00, 23.92s/it, est. speed input: 81.57 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:18<00:00, 18.79s/it, est. speed input: 104.99 toks/s, output: 0.05 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:27<00:00, 27.82s/it, est. speed input: 73.50 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:20<00:00, 20.58s/it, est. speed input: 101.89 toks/s, output: 0.05 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.07s/it, est. speed input: 96.38 toks/s, output: 0.05 toks/s]
.INFO 10-20 04:21:30 config.py:1670] Downcasting torch.float32 to torch.float16.
INFO 10-20 04:21:30 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=512.
WARNING 10-20 04:21:30 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:21:30 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=True multi_step_stream_outputs=True, enable_prefix_caching=False, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:21:31 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:21:31 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:21:31 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.09s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.10s/it]

INFO 10-20 04:23:45 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:23:51 gpu_executor.py:122] # GPU blocks: 2454, # CPU blocks: 0
INFO 10-20 04:23:51 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 9.59x
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.02s/it, est. speed input: 88.62 toks/s, output: 0.05 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:21<00:00, 21.44s/it, est. speed input: 92.01 toks/s, output: 0.05 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.91s/it, est. speed input: 89.25 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:20<00:00, 20.53s/it, est. speed input: 102.14 toks/s, output: 0.05 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.13s/it, est. speed input: 96.11 toks/s, output: 0.05 toks/s]
.INFO 10-20 04:25:40 config.py:1670] Downcasting torch.float32 to torch.float16.
WARNING 10-20 04:25:40 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:25:40 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=False, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=True, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:25:41 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:25:41 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:25:41 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:16<00:00, 136.90s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:16<00:00, 136.92s/it]

INFO 10-20 04:27:58 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:28:08 gpu_executor.py:122] # GPU blocks: 2105, # CPU blocks: 0
INFO 10-20 04:28:08 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 8.22x
INFO 10-20 04:28:08 block_manager_v1.py:263] Automatic prefix caching is enabled.
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.40s/it, est. speed input: 79.96 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.74s/it, est. speed input: 1135.02 toks/s, output: 0.58 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.27s/it, est. speed input: 626.31 toks/s, output: 0.31 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.10s/it, est. speed input: 410.85 toks/s, output: 0.20 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:29<00:00, 29.94s/it, est. speed input: 71.05 toks/s, output: 0.03 toks/s]
.INFO 10-20 04:29:13 config.py:1670] Downcasting torch.float32 to torch.float16.
WARNING 10-20 04:29:13 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:29:13 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=False multi_step_stream_outputs=True, enable_prefix_caching=True, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:29:14 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:29:14 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:29:14 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:13<00:00, 133.83s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:13<00:00, 133.84s/it]

INFO 10-20 04:31:28 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:31:38 gpu_executor.py:122] # GPU blocks: 2105, # CPU blocks: 0
INFO 10-20 04:31:38 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 8.22x
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:26<00:00, 26.08s/it, est. speed input: 74.81 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.73s/it, est. speed input: 1138.99 toks/s, output: 0.58 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.30s/it, est. speed input: 619.15 toks/s, output: 0.30 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.36s/it, est. speed input: 390.89 toks/s, output: 0.19 toks/s]
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:31<00:00, 31.64s/it, est. speed input: 67.23 toks/s, output: 0.03 toks/s]
.INFO 10-20 04:32:46 config.py:1670] Downcasting torch.float32 to torch.float16.
INFO 10-20 04:32:46 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=512.
WARNING 10-20 04:32:46 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:32:46 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=False, num_scheduler_steps=1, chunked_prefill_enabled=True multi_step_stream_outputs=True, enable_prefix_caching=True, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:32:47 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:32:47 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:32:47 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.58s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:14<00:00, 134.58s/it]

INFO 10-20 04:35:02 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:35:08 gpu_executor.py:122] # GPU blocks: 2454, # CPU blocks: 0
INFO 10-20 04:35:08 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 9.59x
INFO 10-20 04:35:08 block_manager_v1.py:263] Automatic prefix caching is enabled.
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:21<00:00, 21.49s/it, est. speed input: 90.78 toks/s, output: 0.05 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.32s/it, est. speed input: 593.78 toks/s, output: 0.30 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.90s/it, est. speed input: 417.39 toks/s, output: 0.20 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.31s/it, est. speed input: 394.64 toks/s, output: 0.19 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:20<00:00, 20.02s/it, est. speed input: 106.24 toks/s, output: 0.05 toks/s]
.INFO 10-20 04:36:03 config.py:1670] Downcasting torch.float32 to torch.float16.
INFO 10-20 04:36:03 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=512.
WARNING 10-20 04:36:03 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 10-20 04:36:03 llm_engine.py:237] Initializing an LLM engine (v999.999.999) with config: model='mistralai/Ministral-8B-Instruct-2410', speculative_config=None, tokenizer='mistralai/Ministral-8B-Instruct-2410', skip_tokenizer_init=False, tokenizer_mode=mistral, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.MISTRAL, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=mistralai/Ministral-8B-Instruct-2410, use_v2_block_manager=True, num_scheduler_steps=1, chunked_prefill_enabled=True multi_step_stream_outputs=True, enable_prefix_caching=True, use_async_output_proc=False, use_cached_outputs=False, mm_processor_kwargs=None)
INFO 10-20 04:36:04 model_runner.py:1060] Starting to load model mistralai/Ministral-8B-Instruct-2410...
INFO 10-20 04:36:04 weight_utils.py:243] Using model weights format ['consolidated*.safetensors', '*.pt']
INFO 10-20 04:36:04 weight_utils.py:288] No consolidated.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:15<00:00, 135.10s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [02:15<00:00, 135.11s/it]

INFO 10-20 04:38:19 model_runner.py:1071] Loading model weights took 14.9381 GB
INFO 10-20 04:38:22 gpu_executor.py:122] # GPU blocks: 2454, # CPU blocks: 0
INFO 10-20 04:38:22 gpu_executor.py:126] Maximum concurrency for 4096 tokens per request: 9.59x
Processed prompts: 100%|███████████████████████████████████████████████████████████████████████████| 1/1 [00:22<00:00, 22.35s/it, est. speed input: 87.30 toks/s, output: 0.04 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.72s/it, est. speed input: 725.20 toks/s, output: 0.37 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.82s/it, est. speed input: 424.19 toks/s, output: 0.21 toks/s]
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.47s/it, est. speed input: 383.18 toks/s, output: 0.18 toks/s]
Processed prompts:   0%|                                                                                    | 0/1 [00:04<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
.

================================================================================== warnings summary ==================================================================================
repro_fin.py::test_vllm[False-False-False-False]
  /home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
    @torch.library.impl_abstract("xformers_flash::flash_fwd")

repro_fin.py::test_vllm[False-False-False-False]
  /home/tori/.local/share/pipx/venvs/vllm-pascal/lib/python3.11/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
    @torch.library.impl_abstract("xformers_flash::flash_bwd")

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
===================================================================== 8 passed, 2 warnings in 1797.88s (0:29:57) =====================================================================
HF_HOME=/mnt/hf HF_HUB_OFFLINE=1 pytest -s repro_fin.py  969.67s user 135.33s system 61% cpu 30:03.72 total


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@sasha0552 sasha0552 changed the title [Bugfix] Fix illegal memory access error with chunked prefill, prefix caching and block manager v2 enabled together [WIP] [Bugfix] Fix illegal memory access error with chunked prefill, prefix caching and block manager v2 enabled together Oct 20, 2024
@sasha0552 sasha0552 marked this pull request as ready for review October 20, 2024 04:46
@sasha0552
Contributor Author

(Sorry for the inconvenience; I took the PR out of draft to be able to run fastcheck.)

@sasha0552 sasha0552 marked this pull request as draft October 20, 2024 06:22
@sasha0552
Contributor Author

It seems that either the L4 (or perhaps non-Pascal GPUs in general) is unaffected, or the triggering request sequence differs per GPU architecture (this reproduces on three P40s tested independently). Anyway, I'll try to fix it now.

@zyfyyzyf

Note that, according to the official documentation (https://docs.vllm.ai/en/latest/models/engine_args.html), block manager v1 has been removed and SelfAttnBlockSpaceManager (i.e. block manager v2) is now the default. Setting the --use-v2-block-manager flag to True or False has no effect on vLLM behavior.

@StevenTang1998

@simon-mo @DarkLight1337 @robertgshaw2-neuralmagic
I wonder whether there is any new progress on this issue.

@sasha0552
Contributor Author

sasha0552 commented Oct 28, 2024

@StevenTang1998

Please don't ping the reviewers; this PR is still a draft. I haven't found the cause yet. I think block manager v2 simply doesn't allocate enough memory (or any at all) in some cases.

Compute sanitizer output
========= COMPUTE-SANITIZER
========= Invalid __global__ read of size 2 bytes
=========     at _fwd_kernel+0x2e98 in /usr/local/lib/python3.12/dist-packages/vllm/attention/ops/prefix_prefill.py:122
=========     by thread (32,0,0) in block (0,11,2)
=========     Address 0x7830b0b12020 is out of bounds
=========     and is 336666657 bytes after the nearest allocation at 0x783094000000 of size 144703488 bytes
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x2dcdbf]
=========                in /usr/lib/x86_64-linux-gnu/libcuda.so.1
=========     Host Frame:launch [0x2db2]
=========                in /root/.triton/cache/b4703c4adfc4af27c3e3b286e52b8e5e8a63075a6a5cb45895f44d0947d862b5/__triton_launcher.so
=========     Host Frame: [0x17893e]
=========                in /usr/bin/python3
=========     ...
=========     Host Frame:_start [0x210b54]
=========                in /usr/bin/python3

For now, if you need chunked prefill + prefix caching, you can copy block manager v1 from an older version of vLLM. It works even without modifications.

You just need to replace

            from vllm.core.block_manager import SelfAttnBlockSpaceManager
            return SelfAttnBlockSpaceManager

with

            from vllm.core.block_manager_v1 import BlockSpaceManagerV1
            return BlockSpaceManagerV1

in vllm/core/interfaces.py.

@StevenTang1998

Thanks for your help! Hope this issue can be solved soon.

@StevenTang1998

Hi @sasha0552, I have tried disabling chunked_prefill and using block manager v1, but both still failed. It only works with enable_prefix_caching=False. Maybe it is due to the prefix caching?

@sasha0552
Contributor Author

What error do you see? Can you reproduce it consistently by sending the same sequence of requests? What GPU(s) do you have and what model are you using?

If it reproduces consistently, you can anonymize the prompts (if they are confidential) like I did and post them here; that may help identify the underlying problem. If they are not confidential, you can just send them as is.

You can anonymize the prompts by converting them to tokens using the /tokenize endpoint, and then replace repeated sequences of tokens with static tokens. Here's my example that consistently causes illegal memory access on P40:

repro.py
from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(
  config_format="mistral",
  dtype="float16",
  enable_chunked_prefill=True,
  enable_prefix_caching=True,
  enforce_eager=True,
  load_format="mistral",
  max_model_len=4096,
  model="mistralai/Ministral-8B-Instruct-2410",
  swap_space=0,
  tensor_parallel_size=1,
  tokenizer_mode="mistral",
  use_v2_block_manager=True,
)

llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([2] * 30  ) + ([3] * 1   )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([4] * 3   ) + ([5] * 50  )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([2] * 30  ) + ([6] * 95  )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([4] * 3   ) + ([7] * 174 )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([8] * 1539)                              ), SamplingParams(max_tokens=1, seed=42))

All five requests share the same 588-token prefix, and the first four additionally share the next 1332 tokens, while the last request diverges after the first 588. Within the first four, the first and third continue with the same 30 tokens, and the second and fourth continue with the same 3 tokens (different from the first and third).
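
For the anonymization itself, something along these lines can be used against the running server. This is a sketch: the /tokenize request and response fields ("model", "prompt", "tokens") are assumptions about the OpenAI-compatible server I'm running, so adjust them to whatever your version exposes.

import requests

# Tokenize a confidential prompt through the server's /tokenize endpoint
# (field names are assumed; see the note above).
resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "mistralai/Ministral-8B-Instruct-2410",
        "prompt": "your confidential prompt here",
    },
)
tokens = resp.json()["tokens"]

# Keep only the structure: replace real token IDs with static placeholders,
# preserving lengths and which requests share which prefixes
# (e.g. [0] * 588 for a prefix shared by all requests).
anonymized = [0] * len(tokens)
print(len(tokens), anonymized[:10])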

Additionally, you could try running vLLM with compute-sanitizer if it is illegal memory access to see where the failure occurs.

compute-sanitizer --launch-timeout=60 --log-file=compute-sanitizer.log --target-processes=application-only --tool=memcheck vllm serve ...

@StevenTang1998

Hi @sasha0552, I have sent the prompt and my code to your email.

I can reproduce the error consistently. Whether I disable chunked_prefill or use block manager v1, I hit the following error:

[rank4]:[E1029 17:25:15.306246334 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 4] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
CL INFO Channel 10/0 : 4[4] -> 3[3] via P2P/IPC
sh-l20y-cn-0056:45:492 [4] NCCL INFO Channel 11/0 : 4[4] -> 3[3] via P2P/IPC
sh-l20y-cn-0056:45:492 [4] NCCL INFO Channel 12/0 : 4[4] -> 3[3] via P2P/IPC
sh-l20y-cn-0056:45:492 [4] NCCL INFO Channel 13/0 : 4[4] -> 3[3] via P2P/IPC
sh-l20y-cn-0056:45:492 [4] NCCL INFO Channel 14/0 : 4[4] -> 3[3] via P2P/IPC
sh-l20y-cn-0056:45:492 [4] NCCL INFO Channel 15/0 : 4[4] -> 3[3] via P2P/IPC
sh-l20y-cn-0056:45:492 [4] NCCL INFO Connected all trees
sh-l20y-cn-0056:45:492 [4] NCCL INFO NVLS comm 0x55e6f7308e10 headRank 4 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
sh-l20y-cn-0056:45:492 [4] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
sh-l20y-cn-0056:45:492 [4] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
sh-l20y-cn-0056:45:492 [4] NCCL INFO comm 0x55e6f7308e10 rank 4 nranks 8 cudaDev 4 nvmlDev 4 busId 8b000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 00/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 01/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 02/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 03/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 04/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 05/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 06/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 07/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 08/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 09/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 10/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 11/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 12/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 13/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 14/1 : 4[4] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:45:519 [4] NCCL INFO Channel 15/1 : 4[4] -> 0[0] via P2P/IPC
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 3 Rank 4] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

INFO 10-29 17:25:15 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20241029-172515.pkl...
FO Channel 12/0 : 0[0] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:30:486 [0] NCCL INFO Channel 13/0 : 0[0] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:30:486 [0] NCCL INFO Channel 14/0 : 0[0] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:30:486 [0] NCCL INFO Channel 15/0 : 0[0] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:30:486 [0] NCCL INFO Connected all rings
sh-l20y-cn-0056:30:486 [0] NCCL INFO Connected all trees
sh-l20y-cn-0056:30:486 [0] NCCL INFO NVLS comm 0x55e6f72f8f20 headRank 0 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
sh-l20y-cn-0056:30:486 [0] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
sh-l20y-cn-0056:30:486 [0] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
sh-l20y-cn-0056:30:486 [0] NCCL INFO comm 0x55e6f72f8f20 rank 0 nranks 8 cudaDev 0 nvmlDev 0 busId 19000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
[rank0]:[E1029 17:25:15.350809944 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
WARNING 10-29 17:25:15 model_runner_base.py:143] Failed to pickle inputs of failed execution: CUDA error: an illegal memory access was encountered
WARNING 10-29 17:25:15 model_runner_base.py:143] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
WARNING 10-29 17:25:15 model_runner_base.py:143] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
WARNING 10-29 17:25:15 model_runner_base.py:143] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
WARNING 10-29 17:25:15 model_runner_base.py:143] 
  what():  [PG 3 Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

[rank3]:[E1029 17:25:15.355723142 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 3] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
CL INFO Channel 10/0 : 3[3] -> 2[2] via P2P/IPC
sh-l20y-cn-0056:44:489 [3] NCCL INFO Channel 11/0 : 3[3] -> 2[2] via P2P/IPC
sh-l20y-cn-0056:44:489 [3] NCCL INFO Channel 12/0 : 3[3] -> 2[2] via P2P/IPC
sh-l20y-cn-0056:44:489 [3] NCCL INFO Channel 13/0 : 3[3] -> 2[2] via P2P/IPC
sh-l20y-cn-0056:44:489 [3] NCCL INFO Channel 14/0 : 3[3] -> 2[2] via P2P/IPC
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
sh-l20y-cn-0056:44:489 [3] NCCL INFO Channel 15/0 : 3[3] -> 2[2] via P2P/IPC
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
sh-l20y-cn-0056:44:489 [3] NCCL INFO Connected all trees
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
sh-l20y-cn-0056:44:489 [3] NCCL INFO NVLS comm 0x55e6f7308d10 headRank 3 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
sh-l20y-cn-0056:44:489 [3] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512

sh-l20y-cn-0056:44:489 [3] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
sh-l20y-cn-0056:44:489 [3] NCCL INFO comm 0x55e6f7308d10 rank 3 nranks 8 cudaDev 3 nvmlDev 3 busId 5d000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 00/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 01/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 02/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 03/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 04/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 05/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 06/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 07/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 08/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 09/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 10/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 11/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 12/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 13/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 14/1 : 3[3] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:44:518 [3] NCCL INFO Channel 15/1 : 3[3] -> 0[0] via P2P/IPC
CL INFO Channel 10/0 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:487 [1] NCCL INFO Channel 11/0 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:487 [1] NCCL INFO Channel 12/0 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:487 [1] NCCL INFO Channel 13/0 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:487 [1] NCCL INFO Channel 14/0 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:487 [1] NCCL INFO Channel 15/0 : 1[1] -> 0[0] via P2P/IPC
[rank1]:[E1029 17:25:15.355741990 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 1] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
terminate called after throwing an instance of 'c10::DistBackendError'
sh-l20y-cn-0056:42:487 [1] NCCL INFO Connected all trees
sh-l20y-cn-0056:42:487 [1] NCCL INFO NVLS comm 0x55e6f7308a20 headRank 1 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
sh-l20y-cn-0056:42:487 [1] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
sh-l20y-cn-0056:42:487 [1] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
sh-l20y-cn-0056:42:487 [1] NCCL INFO comm 0x55e6f7308a20 rank 1 nranks 8 cudaDev 1 nvmlDev 1 busId 3b000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 00/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 01/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 02/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 03/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 04/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 05/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 06/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 07/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 08/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 09/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 10/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 11/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 12/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 13/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 14/1 : 1[1] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:42:521 [1] NCCL INFO Channel 15/1 : 1[1] -> 0[0] via P2P/IPC
[rank2]:[E1029 17:25:15.355889087 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 2] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CL INFO Channel 10/0 : 2[2] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:43:488 [2] NCCL INFO Channel 11/0 : 2[2] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:43:488 [2] NCCL INFO Channel 12/0 : 2[2] -> 1[1] via P2P/IPC
sh-l20y-cn-0056:43:488 [2] NCCL INFO Channel 13/0 : 2[2] -> 1[1] via P2P/IPC
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
sh-l20y-cn-0056:43:488 [2] NCCL INFO Channel 14/0 : 2[2] -> 1[1] via P2P/IPC
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
sh-l20y-cn-0056:43:488 [2] NCCL INFO Channel 15/0 : 2[2] -> 1[1] via P2P/IPC
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
sh-l20y-cn-0056:43:488 [2] NCCL INFO Connected all trees

sh-l20y-cn-0056:43:488 [2] NCCL INFO NVLS comm 0x55e6f7308d30 headRank 2 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
sh-l20y-cn-0056:43:488 [2] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
sh-l20y-cn-0056:43:488 [2] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
sh-l20y-cn-0056:43:488 [2] NCCL INFO comm 0x55e6f7308d30 rank 2 nranks 8 cudaDev 2 nvmlDev 2 busId 4c000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 00/1 : 2[2] -> 0[0] via P2P/IPC
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 01/1 : 2[2] -> 0[0] via P2P/IPC
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 02/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 03/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 04/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 05/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 06/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 07/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 08/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 09/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 10/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 11/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 12/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 13/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 14/1 : 2[2] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:43:523 [2] NCCL INFO Channel 15/1 : 2[2] -> 0[0] via P2P/IPC
[rank6]:[E1029 17:25:15.356128519 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 6] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

CL INFO Channel 10/0 : 6[6] -> 5[5] via P2P/IPC
sh-l20y-cn-0056:47:491 [6] NCCL INFO Channel 11/0 : 6[6] -> 5[5] via P2P/IPC
sh-l20y-cn-0056:47:491 [6] NCCL INFO Channel 12/0 : 6[6] -> 5[5] via P2P/IPC
sh-l20y-cn-0056:47:491 [6] NCCL INFO Channel 13/0 : 6[6] -> 5[5] via P2P/IPC
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
sh-l20y-cn-0056:47:491 [6] NCCL INFO Channel 14/0 : 6[6] -> 5[5] via P2P/IPC
sh-l20y-cn-0056:47:491 [6] NCCL INFO Channel 15/0 : 6[6] -> 5[5] via P2P/IPC
sh-l20y-cn-0056:47:491 [6] NCCL INFO Connected all trees
sh-l20y-cn-0056:47:491 [6] NCCL INFO NVLS comm 0x55e6f7309580 headRank 6 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
sh-l20y-cn-0056:47:491 [6] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
sh-l20y-cn-0056:47:491 [6] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
sh-l20y-cn-0056:47:491 [6] NCCL INFO comm 0x55e6f7309580 rank 6 nranks 8 cudaDev 6 nvmlDev 6 busId dd000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 00/1 : 6[6] -> 0[0] via P2P/IPC
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 01/1 : 6[6] -> 0[0] via P2P/IPC
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 02/1 : 6[6] -> 0[0] via P2P/IPC
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 03/1 : 6[6] -> 0[0] via P2P/IPC
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 04/1 : 6[6] -> 0[0] via P2P/IPC
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 05/1 : 6[6] -> 0[0] via P2P/IPC
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 06/1 : 6[6] -> 0[0] via P2P/IPC
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 07/1 : 6[6] -> 0[0] via P2P/IPC

sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 08/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 09/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 10/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 11/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 12/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 13/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 14/1 : 6[6] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:47:520 [6] NCCL INFO Channel 15/1 : 6[6] -> 0[0] via P2P/IPC
terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 3 Rank 3] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

  what():  [PG 3 Rank 1] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

  what():  [PG 3 Rank 2] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

  what():  [PG 3 Rank 6] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

PC
sh-l20y-cn-0056:48:490 [7] NCCL INFO Channel 10/0 : 7[7] -> 6[6] via P2P/IPC
[rank7]:[E1029 17:25:15.363365220 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 7] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
sh-l20y-cn-0056:48:490 [7] NCCL INFO Channel 11/0 : 7[7] -> 6[6] via P2P/IPC
sh-l20y-cn-0056:48:490 [7] NCCL INFO Channel 12/0 : 7[7] -> 6[6] via P2P/IPC
sh-l20y-cn-0056:48:490 [7] NCCL INFO Channel 13/0 : 7[7] -> 6[6] via P2P/IPC
sh-l20y-cn-0056:48:490 [7] NCCL INFO Channel 14/0 : 7[7] -> 6[6] via P2P/IPC
sh-l20y-cn-0056:48:490 [7] NCCL INFO Channel 15/0 : 7[7] -> 6[6] via P2P/IPC
sh-l20y-cn-0056:48:490 [7] NCCL INFO Connected all trees
sh-l20y-cn-0056:48:490 [7] NCCL INFO NVLS comm 0x55e6f72fb090 headRank 7 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
sh-l20y-cn-0056:48:490 [7] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
sh-l20y-cn-0056:48:490 [7] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
sh-l20y-cn-0056:48:490 [7] NCCL INFO comm 0x55e6f72fb090 rank 7 nranks 8 cudaDev 7 nvmlDev 7 busId e4000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 00/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 01/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 02/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 03/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 04/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 05/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 06/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 07/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 08/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 09/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 10/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 11/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 12/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 13/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 14/1 : 7[7] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:48:525 [7] NCCL INFO Channel 15/1 : 7[7] -> 0[0] via P2P/IPC
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
  what():  [PG 3 Rank 7] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

[rank5]:[E1029 17:25:15.609097879 ProcessGroupNCCL.cpp:1515] [PG 3 Rank 5] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

terminate called after throwing an instance of 'c10::DistBackendError'
CL INFO Channel 10/0 : 5[5] -> 4[4] via P2P/IPC
sh-l20y-cn-0056:46:493 [5] NCCL INFO Channel 11/0 : 5[5] -> 4[4] via P2P/IPC
sh-l20y-cn-0056:46:493 [5] NCCL INFO Channel 12/0 : 5[5] -> 4[4] via P2P/IPC
sh-l20y-cn-0056:46:493 [5] NCCL INFO Channel 13/0 : 5[5] -> 4[4] via P2P/IPC
sh-l20y-cn-0056:46:493 [5] NCCL INFO Channel 14/0 : 5[5] -> 4[4] via P2P/IPC
sh-l20y-cn-0056:46:493 [5] NCCL INFO Channel 15/0 : 5[5] -> 4[4] via P2P/IPC
sh-l20y-cn-0056:46:493 [5] NCCL INFO Connected all trees
sh-l20y-cn-0056:46:493 [5] NCCL INFO NVLS comm 0x55e6f73098a0 headRank 5 nHeads 8 buffSize 4194304 memSize 2097152 nvlsPerRankSize 201326592 nvlsTotalSize 1610612736
sh-l20y-cn-0056:46:493 [5] NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
sh-l20y-cn-0056:46:493 [5] NCCL INFO 16 coll channels, 0 collnet channels, 16 nvls channels, 16 p2p channels, 16 p2p channels per peer
sh-l20y-cn-0056:46:493 [5] NCCL INFO comm 0x55e6f73098a0 rank 5 nranks 8 cudaDev 5 nvmlDev 5 busId d6000 commId 0x11c8bbd57e1dc5db - Init COMPLETE
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 00/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 01/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 02/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 03/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 04/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 05/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 06/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 07/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 08/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 09/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 10/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 11/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 12/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 13/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 14/1 : 5[5] -> 0[0] via P2P/IPC
sh-l20y-cn-0056:46:522 [5] NCCL INFO Channel 15/1 : 5[5] -> 0[0] via P2P/IPC
  what():  [PG 3 Rank 5] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f8112f26d10 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f81133cbf08 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7f80c49c33e6 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0xa0 (0x7f80c49c8600 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x1da (0x7f80c49cf2ba in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7f80c49d16fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1521 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x96 (0x7f8112f77f86 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe5aa84 (0x7f80c465aa84 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7f81126b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7f8113e94ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: <unknown function> + 0x126a40 (0x7f8113f26a40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

/usr/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

@sasha0552
Contributor Author

It looks like there is something wrong with your email; it was rejected by Cloudflare due to missing SPF, DMARC, and DKIM records. Could you try resending it from a different address? You could also try sending it to [email protected].

@StevenTang1998

Hi @sasha0552, I have resent them. If anything is unclear, feel free to contact me.

@sasha0552 sasha0552 marked this pull request as ready for review October 30, 2024 20:40

mergify bot commented Oct 30, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @sasha0552 please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@sasha0552 sasha0552 changed the title [WIP] [Bugfix] Fix illegal memory access error with chunked prefill, prefix caching and block manager v2 enabled together [WIP] [Bugfix] Fix illegal memory access error with chunked prefill, prefix caching, block manager v2 and xformers enabled together Oct 30, 2024
@sasha0552
Contributor Author

sasha0552 commented Oct 30, 2024

I have reproduced, on CI, the illegal memory access error that I observed locally. It only happens with the xformers attention backend.

https://buildkite.com/vllm/fastcheck/builds/6857#0192df28-9c8c-4244-b850-f3ccdbf7ca2e

Also, I pushed a fix; it seems that copying the MetadataBuilder code from flash-attn helps (related: #7018).

repro.py for manual testing
from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(
  dtype="float16",
  enable_chunked_prefill=True,
  enable_prefix_caching=True,
  enforce_eager=True,
  gpu_memory_utilization=0.20,
  max_model_len=4096,
  model="Qwen/Qwen2.5-0.5B-Instruct",
  swap_space=0,
  tensor_parallel_size=1,
)

llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([2] * 30  ) + ([3] * 1   )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([4] * 3   ) + ([5] * 50  )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([2] * 30  ) + ([6] * 95  )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([1] * 1332) + ([4] * 3   ) + ([7] * 174 )), SamplingParams(max_tokens=1, seed=42))
llm.generate(TokensPrompt(prompt_token_ids=([0] * 588 ) + ([8] * 1539)                              ), SamplingParams(max_tokens=1, seed=42))

@StevenTang1998 I tried to reproduce the illegal memory access using your prompts, and it does not crash. It looks like your illegal memory access is unrelated to mine, i.e. it has a different cause. Mine happens only with the xformers backend, which you are probably not using since you are on H100s. I suggest you reproduce the crash with compute-sanitizer running and submit the compute-sanitizer output as a separate issue for the vLLM team to sort out. Unfortunately, I can't fix yours, as I don't have access to such hardware.

@sasha0552 sasha0552 force-pushed the cuda-illegal-memory-access-fix branch from fef6ade to a2d2a8b Compare October 30, 2024 21:29
@mergify mergify bot removed the needs-rebase label Oct 30, 2024
@comaniac
Copy link
Collaborator

Is this ready for review?

@sasha0552 sasha0552 force-pushed the cuda-illegal-memory-access-fix branch from 10a8040 to 9233866 Compare October 30, 2024 21:45
Signed-off-by: sasha0552 <[email protected]>
@sasha0552 sasha0552 force-pushed the cuda-illegal-memory-access-fix branch from 9233866 to 63c87eb Compare October 30, 2024 21:47
@@ -384,9 +392,166 @@ def _get_seq_len_block_table_args(
raise AttributeError(f"Invalid attention type {str(attn_type)}")


-class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]):
+class XFormersMetadataBuilder(AttentionMetadataBuilder[XFormersMetadata]):
Copy link
Collaborator

Why do we need to re-implement the metadata builder? What's the difference between this implementation and the common one?

Copy link
Contributor Author

@sasha0552 sasha0552 Oct 30, 2024

It has #7018 and

            if prefix_cache_hit:
                # NOTE(woosuk): For xformers, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                if curr_sliding_window_block == 0:
                    block_table = block_tables[seq_id]
                else:
                    block_table = block_tables[seq_id][
                        -curr_sliding_window_block:]

from flash-attn. However, it is based on the common one, not just copied from flash-attn. Theoretically, I could modify the common one, but that might affect other attention backends.
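
For readability, here is a minimal, self-contained sketch of that selection logic as a standalone function. The function name select_block_table and the explicit parameters are my own framing for illustration; in vLLM this state lives on the metadata builder itself rather than being passed in like this.

from typing import Dict, List, Optional

def select_block_table(
    block_tables: Optional[Dict[int, List[int]]],
    seq_id: int,
    prefix_cache_hit: bool,
    chunked_prefill_enabled: bool,
    is_prompt: bool,
    curr_sliding_window_block: int,
) -> Optional[List[int]]:
    block_table: Optional[List[int]] = None
    if prefix_cache_hit:
        # For xformers, the block table must also cover the incoming
        # prefill tokens, so the full per-sequence table is used.
        block_table = block_tables[seq_id]
    elif ((chunked_prefill_enabled or not is_prompt)
          and block_tables is not None):
        if curr_sliding_window_block == 0:
            block_table = block_tables[seq_id]
        else:
            # With a sliding window, only the most recent blocks are needed.
            block_table = block_tables[seq_id][-curr_sliding_window_block:]
    return block_table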

Copy link
Contributor Author

In fact, it works even without #7018, so I'll remove the changes from #7018.

Copy link
Collaborator

Can you try to modify the common one? We should avoid introducing this divergence as much as possible.

Copy link
Contributor Author

Done

@StevenTang1998
Copy link

OK, thanks for your help! BTW, could you tell me how to use compute-sanitizer with my code?

@sasha0552
Copy link
Contributor Author

sasha0552 commented Oct 31, 2024

You can run vLLM under compute-sanitizer as follows:

compute-sanitizer --launch-timeout=60 --log-file=compute-sanitizer.log --print-limit=0 --target-processes=application-only --tool=memcheck vllm serve ...

However, compute-sanitizer needs to be installed first. On Ubuntu or in the vLLM Docker container, this can be done as follows:

apt-get update && apt-get install -y cuda-sanitizer-12-6 (replace 12-6 with the cuda version, which you can find out with apt list --installed | grep cuda).

In general, inside Docker I run vLLM with compute-sanitizer as follows:

docker run                                                \
  --cap-add=SYS_ADMIN                                     \
  --cap-add=SYS_PTRACE                                    \
  --entrypoint sh                                         \
  --env VLLM_NO_USAGE_STATS=1                             \
  --gpus all                                              \
  --ipc host                                              \
  --privileged                                            \
  --restart no                                            \
  --rm                                                    \
  --runtime nvidia                                        \
  --security-opt seccomp=unconfined                       \
  --volume ./hf_cache:/root/.cache/huggingface            \
  --volume ./repro.py:/repro.py                           \
  --volume ./logs:/logs                                   \
    vllm-temp2:latest                                     \
      -c "                                                \
        compute-sanitizer                                 \
          --launch-timeout=60                             \
          --log-file=/logs/l1.log                         \
          --padding=32                                    \
          --print-limit=0                                 \
          --save=/logs/l1.xml                             \
          --save-session-details                          \
          --target-processes=application-only             \
          --tool=memcheck                                 \
          --xml                                           \
            python3                                       \
              /repro.py                                   \
      "

I created vllm-temp2:latest by running docker run -it --rm --entrypoint sh vllm/vllm-openai:latest, installing compute-sanitizer inside the container, and then running docker commit <container id> vllm-temp2:latest in a second terminal.
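
For reference, the same flow as explicit commands (the container name vllm-sanitizer is just an illustrative stand-in for the container id, and 12-6 should match your CUDA version as noted above):

# Terminal 1: start a throwaway container and install the sanitizer inside it
docker run -it --rm --name vllm-sanitizer --entrypoint sh vllm/vllm-openai:latest
apt-get update && apt-get install -y cuda-sanitizer-12-6

# Terminal 2 (while the container is still running): snapshot it as a reusable image
docker commit vllm-sanitizer vllm-temp2:latest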

@StevenTang1998
Copy link

Hi @sasha0552, thanks for your help! Sorry, I am still a little confused. As far as I know, vllm serve handles online queries one by one; however, my code does offline batching. (I have verified that processing queries one by one does not trigger the error, but generating in a batch does.)

@sasha0552
Copy link
Contributor Author

You can simply replace vllm serve with python3 yourscript.py.
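
For example, reusing the memcheck invocation from above with an offline script:

compute-sanitizer --launch-timeout=60 --log-file=compute-sanitizer.log --print-limit=0 --target-processes=application-only --tool=memcheck python3 yourscript.py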

@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 31, 2024
Signed-off-by: sasha0552 <[email protected]>
@sasha0552
Copy link
Contributor Author

@comaniac All tests passed; could you please merge this PR?

@comaniac comaniac merged commit 55650c8 into vllm-project:main Oct 31, 2024
59 checks passed
@sasha0552 sasha0552 deleted the cuda-illegal-memory-access-fix branch October 31, 2024 18:48
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Signed-off-by: Linkun Chen <[email protected]>
bigPYJ1151 pushed a commit to bigPYJ1151/vllm that referenced this pull request Nov 5, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
hissu-hyvarinen pushed a commit to ROCm/vllm that referenced this pull request Nov 6, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
sergeykochetkov pushed a commit to sergeykochetkov/vllm_spec_decoding that referenced this pull request Dec 27, 2024
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Signed-off-by: s.kochetkov <[email protected]>
whyiug pushed a commit to whyiug/vllm that referenced this pull request Jan 4, 2025
…ix caching, block manager v2 and xformers enabled together (vllm-project#9532)

Signed-off-by: sasha0552 <[email protected]>
Labels
ci/build, ready (ONLY add when PR is ready to merge/full CI is needed)
Projects
None yet
Development


4 participants