diff --git a/vllm/config.py b/vllm/config.py index 5fbfbd769296f..1adf830ffcc12 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -285,10 +285,12 @@ def __init__( pipeline_parallel_size: int, tensor_parallel_size: int, worker_use_ray: bool, + max_parallel_loading_workers: Optional[int] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray + self.max_parallel_loading_workers = max_parallel_loading_workers self.world_size = pipeline_parallel_size * tensor_parallel_size if self.world_size > 1: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cc425a2c079e7..c7e476c704740 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -22,6 +22,7 @@ class EngineArgs: worker_use_ray: bool = False pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None block_size: int = 16 swap_space: int = 4 # GiB gpu_memory_utilization: float = 0.90 @@ -128,6 +129,12 @@ def add_cli_args( type=int, default=EngineArgs.tensor_parallel_size, help='number of tensor parallel replicas') + parser.add_argument( + '--max-parallel-loading-workers', + type=int, + help='load model sequentially in multiple batches, ' + 'to avoid RAM OOM when using tensor ' + 'parallel and large models') # KV cache arguments parser.add_argument('--block-size', type=int, @@ -195,7 +202,8 @@ def create_engine_configs( getattr(model_config.hf_config, 'sliding_window', None)) parallel_config = ParallelConfig(self.pipeline_parallel_size, self.tensor_parallel_size, - self.worker_use_ray) + self.worker_use_ray, + self.max_parallel_loading_workers) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ec7a4587ffe4b..e33d8aa2a2131 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -143,6 +143,12 @@ def _init_workers(self, distributed_init_method: str): "init_model", get_all_outputs=True, ) + self._run_workers( + "load_model", + get_all_outputs=True, + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): @@ -182,6 +188,12 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "init_model", get_all_outputs=True, ) + self._run_workers( + "load_model", + get_all_outputs=True, + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) @@ -682,16 +694,15 @@ def _check_stop(self, seq: Sequence, seq.status = SequenceStatus.FINISHED_STOPPED return - def _run_workers( + def _run_workers_in_batch( self, + workers, method: str, *args, - get_all_outputs: bool = False, **kwargs, - ) -> Any: - """Runs the given method on all workers.""" + ): all_outputs = [] - for worker in self.workers: + for worker in workers: if self.parallel_config.worker_use_ray: executor = partial(worker.execute_method.remote, method) else: @@ -699,9 +710,31 @@ def _run_workers( output = executor(*args, **kwargs) all_outputs.append(output) - if self.parallel_config.worker_use_ray: all_outputs = ray.get(all_outputs) + return all_outputs + + def _run_workers( + self, + method: str, + *args, + get_all_outputs: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + all_outputs = [] + if max_concurrent_workers: + work_groups = [ + self.workers[i:i + max_concurrent_workers] + for i in range(0, len(self.workers), max_concurrent_workers) + ] + else: + work_groups = [self.workers] + + for workers in work_groups: + all_outputs.extend( + self._run_workers_in_batch(workers, method, *args, **kwargs)) if get_all_outputs: return all_outputs diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4fcd179ce85f2..702767ebd8d09 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -67,6 +67,8 @@ def init_model(self): # Initialize the model. set_random_seed(self.model_config.seed) + + def load_model(self): self.model = get_model(self.model_config) @torch.inference_mode()