diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 7d40607e81791..af426e31591f2 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -62,6 +62,18 @@ def _get_worker_kwargs( observability_config=self.observability_config, ) + def _get_worker_module_and_class(self) -> Tuple[str, str]: + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" + elif self.speculative_config: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + return (worker_module_name, worker_class_name) + def _get_create_worker_kwargs( self, local_rank: int = 0, @@ -70,17 +82,10 @@ def _get_create_worker_kwargs( worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.scheduler_config.is_multi_step: - worker_kwargs.update( - worker_module_name="vllm.worker.multi_step_worker", - worker_class_name="MultiStepWorker") - elif self.speculative_config: - worker_kwargs.update( - worker_module_name="vllm.spec_decode.spec_decode_worker", - worker_class_name="create_spec_worker") - else: - worker_kwargs.update(worker_module_name="vllm.worker.worker", - worker_class_name="Worker") + (worker_module_name, + worker_class_name) = self._get_worker_module_and_class() + worker_kwargs.update(worker_module_name=worker_module_name, + worker_class_name=worker_class_name) return worker_kwargs diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4c38cd1cbd546..bddb95210dbc9 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -91,15 +91,8 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs def _get_worker_wrapper_args(self) -> Dict[str, Any]: - if self.speculative_config is not None: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - elif self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_worker" - worker_class_name = "MultiStepWorker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" + (worker_module_name, + worker_class_name) = self._get_worker_module_and_class() return dict( worker_module_name=worker_module_name, @@ -107,6 +100,10 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: trust_remote_code=self.model_config.trust_remote_code, ) + # child class could overwrite this to return actual env vars. + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if (self.parallel_config.tensor_parallel_size == 1 @@ -231,8 +228,12 @@ def sort_by_driver_then_worker_ip(worker): "VLLM_TRACE_FUNCTION": str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + self._run_workers("update_environment_variables", - all_args=all_args_to_update_environment_variables) + all_args=self._get_env_vars_to_be_updated()) if len(node_gpus) == 1: # in single node case, we don't need to get the IP address.