From d87f39e9a9dd149f5dd7a58b4d98b21f713827b6 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Wed, 24 Apr 2024 00:28:35 +0800
Subject: [PATCH] [Bugfix] Add init_cached_hf_modules to RayWorkerWrapper (#4286)

---
 vllm/executor/ray_gpu_executor.py | 2 ++
 vllm/worker/worker_base.py        | 7 ++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index d0b5e682bb6f7..e69f104e7d5a4 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -100,6 +100,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             )(RayWorkerWrapper).remote(
                 worker_module_name="vllm.worker.worker",
                 worker_class_name="Worker",
+                trust_remote_code=self.model_config.trust_remote_code,
             )
 
             worker_ip = ray.get(worker.get_node_ip.remote())
@@ -110,6 +111,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                 self.driver_worker = RayWorkerWrapper(
                     worker_module_name="vllm.worker.worker",
                     worker_class_name="Worker",
+                    trust_remote_code=self.model_config.trust_remote_code,
                 )
             else:
                 # Else, added to the list of workers.
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index bcd04e0f98db6..b5dade0a770a0 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -103,10 +103,15 @@ class WorkerWrapperBase:
 
     def __init__(self,
                  worker_module_name=None,
-                 worker_class_name=None) -> None:
+                 worker_class_name=None,
+                 trust_remote_code: bool = False) -> None:
         self.worker_module_name = worker_module_name
         self.worker_class_name = worker_class_name
         self.worker = None
+        if trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
 
     @staticmethod
     def update_environment_variables(envs: Dict[str, str]) -> None: