[ci][distributed] add distributed test gptq_marlin with tp = 2 #6010

Open

llmpros wants to merge 1 commit into main from add_test

Conversation

@llmpros (Contributor) commented Jul 1, 2024

Follow-up PR of #6007.

@llmpros llmpros force-pushed the add_test branch 3 times, most recently from 6a2004c to 8e128d2 Compare July 1, 2024 03:56
@youkaichao (Member)

Thanks for the PR! You need to move the test from models to distributed:

https://github.com/vllm-project/vllm/blob/main/.buildkite/test-pipeline.yaml

In addition, because of some limitations, you should only test the tp=2 case; it is not safe to run two vLLM instances together.
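
(A minimal sketch of the suggested setup, with hypothetical test and parameter names; the test command itself would also move under the distributed section of .buildkite/test-pipeline.yaml:)

import pytest

# Only the tp=2 case is parametrized (sketch, hypothetical names); spinning up
# two vLLM instances side by side in a single test is not considered safe.
@pytest.mark.parametrize("tensor_parallel_size", [2])
def test_distributed_gptq_marlin(tensor_parallel_size: int):
    ...  # build one engine with tensor_parallel_size=2 and check its outputs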

@llmpros llmpros force-pushed the add_test branch 3 times, most recently from 63b9545 to 49141fb Compare July 1, 2024 05:30
@llmpros llmpros changed the title add tp>1 test coverage for gptq_marlin [ci][distributed] move test gptq_marlin to distributed with tp = 2 Jul 1, 2024
@DarkLight1337 (Collaborator) commented Jul 1, 2024

Imo we should keep the original tp=1 test and add a new file in distributed tests for the tp=2 case.

@llmpros (Contributor, Author) commented Jul 1, 2024

> Imo we should keep the original tp=1 test and add a new file in distributed tests for the tp=2 case.

Makes sense. So is it better to abstract the common test code below into a shared helper (e.g. test_gptq_marlin_common) to avoid duplication, so that the tp=1 test (under tests/models) and the tp=2 test (under tests/distributed) each call test_gptq_marlin_common, or to just copy the original tests/models/test_gptq_marlin.py to tests/distributed/ and only change tp=2 (with a small amount of duplication)?

    # test_gptq_marlin_common()
    # Run marlin.
    with vllm_runner(model_name=model_name,
                     revision=revision,
                     dtype=dtype,
                     quantization="marlin",
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=2,
                     distributed_executor_backend=distributed_executor_backend
                     ) as gptq_marlin_model:

        gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs(
            example_prompts[:-1], max_tokens, num_logprobs)
    _ROPE_DICT.clear()  # clear rope cache to avoid rope dtype error

    # Run gptq.
    # The naive gptq kernel doesn't support bf16 yet.
    # Here we always compare the fp16/bf16 gptq marlin kernel
    # to fp16 gptq kernel.
    with vllm_runner(model_name=model_name,
                     revision=revision,
                     dtype="half",
                     quantization="gptq",
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=2,
                     distributed_executor_backend=distributed_executor_backend
                     ) as gptq_model:
        gptq_outputs = gptq_model.generate_greedy_logprobs(
            example_prompts[:-1], max_tokens, num_logprobs)
    return [gptq_marlin_outputs, gptq_outputs]
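
(For context, a sketch of how the caller would then compare the two output lists returned above, using the check_logprobs_close helper that the existing test already imports; the keyword argument names below follow the current tests/models conventions and are an assumption:)

# Sketch: compare the marlin run against the gptq run returned by the helper
# above; check_logprobs_close asserts that the greedy tokens / top logprobs of
# the two runs stay close.
from .utils import check_logprobs_close  # same import as the existing test

gptq_marlin_outputs, gptq_outputs = test_gptq_marlin_common(...)  # hypothetical call

check_logprobs_close(
    outputs_0_lst=gptq_outputs,
    outputs_1_lst=gptq_marlin_outputs,
    name_0="gptq",
    name_1="gptq_marlin",
)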

@DarkLight1337 (Collaborator)

Let's abstract out the code (similar to what I did for the multimodal distributed tests)
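
(A minimal sketch of such an abstraction, with hypothetical file and helper names: the shared runner lives next to the existing tp=1 test, and a thin wrapper under tests/distributed only changes the parallelism-related arguments:)

# tests/distributed/test_distributed_gptq_marlin.py (sketch; the shared helper
# run_gptq_marlin_test is assumed to live in tests/models/test_gptq_marlin.py
# alongside the existing tp=1 test).
import pytest

from ..models.test_gptq_marlin import run_gptq_marlin_test  # assumed helper


@pytest.mark.parametrize("model", ["<gptq model id>"])  # placeholder model id
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
def test_distributed_gptq_marlin(vllm_runner, example_prompts, model, dtype,
                                 distributed_executor_backend):
    # Delegate to the shared body; only the tensor-parallel size and the
    # executor backend differ from the tests/models variant.
    run_gptq_marlin_test(vllm_runner, example_prompts, model, dtype,
                         max_tokens=32, num_logprobs=4,
                         tensor_parallel_size=2,
                         distributed_executor_backend=distributed_executor_backend)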

@llmpros llmpros force-pushed the add_test branch 2 times, most recently from 64c0686 to f12288d Compare July 1, 2024 18:29
@llmpros llmpros changed the title [ci][distributed] move test gptq_marlin to distributed with tp = 2 [ci][distributed] add distributed test gptq_marlin with tp = 2 Jul 1, 2024
@@ -17,8 +18,6 @@

from .utils import check_logprobs_close

os.environ["TOKENIZERS_PARALLELISM"] = "true"
Collaborator

Please keep this line, as it avoids unnecessary warnings from HuggingFace.

@llmpros (Contributor, Author) commented Jul 2, 2024

@DarkLight1337 it looks like the new unit test (test_distributed_gptq_marlin with tp=2) failed with the following error. I may grab a box with 2 GPUs and install the current main branch to test in a real environment.


[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301) Process VllmWorkerProcess:
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301) Traceback (most recent call last):
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     self.run()
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     self._target(*self._args, **self._kwargs)
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/multiproc_worker_utils.py", line 210, in _run_worker_process
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     worker = worker_factory()
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 67, in _create_worker
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 311, in init_worker
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     self.worker = worker_class(*args, **kwargs)
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 87, in __init__
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 196, in __init__
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     self.attn_backend = get_attn_backend(
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/selector.py", line 45, in get_attn_backend
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/vllm/attention/selector.py", line 151, in which_attn_to_use
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     if torch.cuda.get_device_capability()[0] < 8:
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 430, in get_device_capability
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     prop = get_device_properties(device)
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 444, in get_device_properties
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     _lazy_init()  # will define _get_device_properties
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 279, in _lazy_init
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301)     raise RuntimeError(
[2024-07-02T04:45:17Z] (VllmWorkerProcess pid=22301) RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method


@DarkLight1337 (Collaborator) commented Jul 2, 2024

This happens because you initialized CUDA too early (probably indirectly via imports). Try to avoid importing torch-related stuff in the top level code of your test.
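
(A minimal sketch of that fix, under the assumption that the failure comes from CUDA being initialized in the pytest process before vLLM forks its workers: keep torch/CUDA out of the module top level and query devices with a helper that does not create a CUDA context.)

# Sketch (hypothetical helper). Anything at module level that touches
# torch.cuda runs at import/collection time in the parent pytest process, and
# the forked vLLM workers then fail with
# "Cannot re-initialize CUDA in forked subprocess".
import pytest


def skip_if_too_few_gpus(required: int = 2) -> None:
    # Called from inside a test body, never at import time.
    # cuda_device_count_stateless queries the device count without creating a
    # CUDA context in this process (assumption: available in the checked-out
    # vLLM tree).
    from vllm.utils import cuda_device_count_stateless
    if cuda_device_count_stateless() < required:
        pytest.skip(f"Need at least {required} GPUs for tensor parallelism.")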

@DarkLight1337 (Collaborator)

If the issue persists, #6056 should help you.
