diff --git a/examples/offline_inference.py b/examples/offline_inference.py
index 9b758fa2479f6..23cc6e8539431 100644
--- a/examples/offline_inference.py
+++ b/examples/offline_inference.py
@@ -19,4 +19,4 @@
 for output in outputs:
     prompt = output.prompt
     generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
\ No newline at end of file
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 7072a8bbc5b3e..39da131fcae5e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -130,7 +130,7 @@ def _init_workers(self):
 
         assert self.parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
-
+        self._init_single_gpu_config()
        self.workers: List[Worker] = []
        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
        self.driver_worker = Worker(
@@ -917,3 +917,45 @@ def _run_workers(
             ray_worker_outputs = ray.get(ray_worker_outputs)
 
         return [driver_worker_output] + ray_worker_outputs
+
+    def _init_single_gpu_config(self) -> None:
+        """Use monkey patching to skip distributed setup for a single GPU.
+
+        Details:
+        - Step 1: As shown in the code below, monkey patch
+          `get_tensor_model_parallel_rank`,
+          `get_tensor_model_parallel_world_size`, and
+          `get_tensor_model_parallel_group`.
+        - Step 2: Because of Python's import mechanism, the modules listed
+          in `_NEED_RELOAD_MODULES` must be reloaded so that the monkey
+          patching in Step 1 takes effect.
+        - Step 3: Monkey patch `_init_distributed_environment` in the
+          `vllm.worker.worker` module so that no process group is created.
+        """
+        _NEED_RELOAD_MODULES = [
+            "vllm.model_executor.parallel_utils.communication_op",
+            "vllm.model_executor.layers.linear",
+            "vllm.model_executor.layers.activation",
+            "vllm.model_executor.layers.sampler",
+            "vllm.model_executor.layers.vocab_parallel_embedding",
+        ]
+        import importlib
+        import sys
+        import vllm.model_executor.parallel_utils.parallel_state as ps_module
+        assert self.parallel_config.world_size == 1, (
+            "world_size must be 1 to skip distributed initialization.")
+        # Step 1: patch the tensor parallel accessors for the single-GPU case.
+        ps_module.get_tensor_model_parallel_rank = lambda *args, **kwargs: 0
+        ps_module.get_tensor_model_parallel_world_size = lambda *args, **kwargs: 1
+        # Placeholder group: collectives are bypassed when world_size == 1.
+        ps_module.get_tensor_model_parallel_group = lambda *args, **kwargs: 1
+        # Step 2: reload modules that bound the patched names at import time.
+        for module_name in _NEED_RELOAD_MODULES:
+            if module_name in sys.modules:
+                module_before = sys.modules[module_name]
+                importlib.reload(module_before)  # returns the reloaded module
+        # Step 3: make the worker skip distributed initialization entirely.
+        module_worker_name = "vllm.worker.worker"
+        module_worker = sys.modules.get(module_worker_name, None)
+        assert module_worker is not None
+        module_worker._init_distributed_environment = lambda *args, **kwargs: None
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index 37543d8c9838e..40f441fc9cb50 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -71,4 +71,4 @@ def get_model(model_config: ModelConfig) -> nn.Module:
             # Load the weights from the cached or downloaded files.
             model.load_weights(model_config.model, model_config.download_dir,
                                model_config.load_format, model_config.revision)
-    return model.eval()
+    return model.eval()
\ No newline at end of file
diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py
index 8bf04f3d1f056..6b95fd91a4e87 100644
--- a/vllm/model_executor/parallel_utils/communication_op.py
+++ b/vllm/model_executor/parallel_utils/communication_op.py
@@ -82,7 +82,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
 
 def broadcast(input_, src=0):
     """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
@@ -95,7 +95,7 @@
 
 def broadcast_object_list(obj_list, src=0):
     """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
+    world_size = get_tensor_model_parallel_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"
 
     # Bypass the function if we are using only 1 GPU.
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index c2a2ac148085b..db41eac9f2631 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -241,4 +241,4 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
             f"{compute_capability[0]}.{compute_capability[1]}. "
             "You can use float16 instead by explicitly setting the"
-            "`dtype` flag in CLI, for example: --dtype=half.")
+            "`dtype` flag in CLI, for example: --dtype=half.")
\ No newline at end of file
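
Note on Step 2 of `_init_single_gpu_config`: the modules in `_NEED_RELOAD_MODULES` bind
`get_tensor_model_parallel_*` via `from ... import ...` at import time, so patching
`parallel_state` alone does not reach them; they must be reloaded. The sketch below is a
minimal, self-contained illustration of that behavior. It is not part of the patch and does
not touch vLLM; the module names in it (fake_parallel_state, fake_linear_layer) are made up
purely for the demonstration.

# Why monkey patching needs a reload: a `from module import name` binding is
# captured when the importer first runs, so later patches to `module.name`
# stay invisible until the importer is reloaded. All names here are fake.
import importlib
import sys
import tempfile
from pathlib import Path

tmp = Path(tempfile.mkdtemp())
(tmp / "fake_parallel_state.py").write_text(
    "def get_tensor_model_parallel_rank():\n"
    "    return 'distributed rank'\n")
(tmp / "fake_linear_layer.py").write_text(
    "from fake_parallel_state import get_tensor_model_parallel_rank\n"
    "def describe():\n"
    "    return get_tensor_model_parallel_rank()\n")
sys.path.insert(0, str(tmp))

import fake_linear_layer
import fake_parallel_state

print(fake_linear_layer.describe())  # -> 'distributed rank'

# Step 1 analogue: patch the provider module.
fake_parallel_state.get_tensor_model_parallel_rank = lambda *args, **kwargs: 0
print(fake_linear_layer.describe())  # still 'distributed rank': stale binding

# Step 2 analogue: reload the consumer so its import binding is refreshed.
importlib.reload(fake_linear_layer)
print(fake_linear_layer.describe())  # -> 0, the patch is now visible

Reloading re-executes the consumer's top-level imports against the already-patched module in
`sys.modules`, which is why only the modules that captured the names before the patch need to
appear in `_NEED_RELOAD_MODULES`.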