diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e15e83b792829..97aeb713e9f52 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -92,7 +92,7 @@ def _initialize_kv_caches(self,
             num_gpu_blocks = num_gpu_blocks_override
 
         num_cpu_blocks = 0
-        self.model_executor.initialize_cache(num_gpu_blocks)
+        self.model_executor.initialize(num_gpu_blocks)
         return num_gpu_blocks, num_cpu_blocks
 
     def add_request(self, request: EngineCoreRequest):
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 45aa95757a2c7..5c5558be8a9e6 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -93,14 +93,6 @@ def __init__(self, vllm_config: VllmConfig) -> None:
         model_output_mq_handle = self.workers[0].model_output_mq_handle
         self.model_output_mq = MessageQueue.create_from_handle(
             model_output_mq_handle, 0)
-        self.workers_in_busy_loop = False
-
-    def start_workers(self):
-        for w in self.workers:
-            w.start_busy_loop()
-        self.scheduler_output_mq.wait_until_ready()
-        self.model_output_mq.wait_until_ready()
-        self.workers_in_busy_loop = True
 
     def run_on_workers(self, fn: str, *args) -> List:
         with ThreadPoolExecutor() as executor:
@@ -111,12 +103,21 @@ def run_on_workers(self, fn: str, *args) -> List:
             result = [f.result() for f in futures]  # Wait for all to complete
         return result
 
-    def initialize_cache(self, num_gpu_blocks: int) -> None:
-        """Initialize the KV caches by invoking the underlying worker."""
-        self.run_on_workers('initialize_cache', num_gpu_blocks)
+    def initialize(self, num_gpu_blocks: int) -> None:
+        """
+        Initialize the KV caches and begin the model execution loop of the
+        underlying workers.
+        """
+        success_vals = self.run_on_workers('initialize', num_gpu_blocks)
+        if not all(success_vals):
+            raise RuntimeError("Worker initialization failed.")
+
+        self.scheduler_output_mq.wait_until_ready()
+        self.model_output_mq.wait_until_ready()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available KV blocks by invoking the
+        """
+        Determine the number of available KV blocks by invoking the
         underlying worker.
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
@@ -134,9 +135,6 @@ def execute_model(
         self,
         scheduler_output,
     ) -> ModelRunnerOutput:
-        if not self.workers_in_busy_loop:
-            self.start_workers()
-
         self.scheduler_output_mq.enqueue(
             ExecutorMsg(ExecutorMsgType.WORK, scheduler_output))
         model_output = self.model_output_mq.dequeue()
diff --git a/vllm/v1/executor/uniproc_executor.py b/vllm/v1/executor/uniproc_executor.py
index 4d612e7a62761..224620d527c25 100644
--- a/vllm/v1/executor/uniproc_executor.py
+++ b/vllm/v1/executor/uniproc_executor.py
@@ -54,7 +54,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         return self.worker.determine_num_available_blocks()
 
-    def initialize_cache(self, num_gpu_blocks: int) -> None:
+    def initialize(self, num_gpu_blocks: int) -> None:
         """Initialize the KV cache by invoking the underlying worker.
         """
         # NOTE: This is logged in the executor because there can be >1 worker
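The executor change above collapses the old initialize_cache / start_workers pair into a single initialize call that fans out to every worker and fails fast if any of them reports failure. Below is a minimal, self-contained sketch of that fan-out-and-check pattern; the Worker class is a hypothetical stand-in for the real ZMQ-backed worker handle, and only the control flow mirrors the diff:

    # Sketch of MultiprocExecutor.initialize()'s fan-out-and-check pattern.
    from concurrent.futures import ThreadPoolExecutor
    from typing import List


    class Worker:
        """Hypothetical stand-in for a worker handle."""

        def initialize(self, num_gpu_blocks: int) -> bool:
            # The real worker sends an INITIALIZE request over ZMQ and
            # returns the pickled success flag from the response.
            return num_gpu_blocks > 0


    def run_on_workers(workers: List[Worker], fn: str, *args) -> List:
        # Invoke the named method on every worker concurrently and gather
        # the results, mirroring run_on_workers() in the diff.
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(getattr(w, fn), *args) for w in workers]
            return [f.result() for f in futures]


    def initialize(workers: List[Worker], num_gpu_blocks: int) -> None:
        # Every worker must report success before the executor goes on to
        # wait on the shared message queues.
        success_vals = run_on_workers(workers, 'initialize', num_gpu_blocks)
        if not all(success_vals):
            raise RuntimeError("Worker initialization failed.")


    initialize([Worker(), Worker()], num_gpu_blocks=1024)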
""" DETERMINE_NUM_BLOCKS = b'\x00' - INIT_CACHE = b'\x01' - BEGIN_MODEL_EXECUTION = b'\x02' - + INITIALIZE = b'\x01' # Initialize cache and begin worker execution @dataclass -class WorkerInitOutputType: +class WorkerInitResponseType: """ Request types defined as hex byte strings, so it can be sent over sockets without separate encoding step. """ READY = b'\x00' NUM_BLOCKS = b'\x01' + INITIALIZE_SUCCESS = b'\x02' @dataclass @@ -261,30 +260,37 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: make_zmq_socket(self.initialization_input_path, zmq.constants.PULL) as recv_socket: + # Send message to determine the number of blocks send_socket.send_multipart( (WorkerInitRequestType.DETERMINE_NUM_BLOCKS, )) - type_frame, data_frame = recv_socket.recv_multipart(copy=False) - - request_type = type_frame.buffer - request_data = data_frame.buffer - - if request_type != WorkerInitOutputType.NUM_BLOCKS: - raise ValueError(f"Unknown RequestType: {request_type}") - return pickle.loads(request_data) - def initialize_cache(self, num_gpu_blocks: int) -> int: + # Receive response + type_frame, data_frame = recv_socket.recv_multipart(copy=False) + response_type = type_frame.buffer + response_data = data_frame.buffer + if response_type != WorkerInitResponseType.NUM_BLOCKS: + raise ValueError(f"Unknown RequestType: {response_type}") + return pickle.loads(response_data) + + def initialize(self, num_gpu_blocks: int) -> bool: + """ Initialize the KV cache and begin worker execution loop """ with make_zmq_socket(self.initialization_output_path, - zmq.constants.PUSH) as socket: + zmq.constants.PUSH) as send_socket, \ + make_zmq_socket(self.initialization_input_path, + zmq.constants.PULL) as recv_socket: + + # Send initialization message msg = pickle.dumps(num_gpu_blocks, protocol=pickle.HIGHEST_PROTOCOL) - socket.send_multipart((WorkerInitRequestType.INIT_CACHE, msg)) - - def start_busy_loop(self) -> None: - with make_zmq_socket(self.initialization_output_path, - zmq.constants.PUSH) as socket: - socket.send_multipart( - (WorkerInitRequestType.BEGIN_MODEL_EXECUTION, )) + send_socket.send_multipart((WorkerInitRequestType.INITIALIZE, msg)) + # Receive success or failure response + type_frame, data_frame = recv_socket.recv_multipart(copy=False) + response_type = type_frame.buffer + response_data = data_frame.buffer + if response_type != WorkerInitResponseType.INITIALIZE_SUCCESS: + raise ValueError(f"Unknown RequestType: {response_type}") + return pickle.loads(response_data) class WorkerProc: """Wrapper that runs one Worker in a separate process.""" @@ -320,7 +326,7 @@ def __init__( zmq.constants.PUSH) as ready_socket: payload = pickle.dumps(output_mq_handle, protocol=pickle.HIGHEST_PROTOCOL) - ready_socket.send_multipart((WorkerInitOutputType.READY, payload)) + ready_socket.send_multipart((WorkerInitResponseType.READY, payload)) self.worker.initialize() self.worker.load_model() @@ -404,7 +410,7 @@ def wait_for_startup( raise RuntimeError("WorkerProc failed to start.") type_frame, data_frame = socket.recv_multipart(copy=False) - assert type_frame.buffer == WorkerInitOutputType.READY + assert type_frame.buffer == WorkerInitResponseType.READY handle = pickle.loads(data_frame.buffer) return handle @@ -420,26 +426,45 @@ def model_initialization_loop(self, init_input_path, init_output_path): request_type = request[0].buffer # Deserialize the request data. 
@@ -420,26 +426,45 @@ def model_initialization_loop(self, init_input_path, init_output_path):
                 request_type = request[0].buffer
 
                 # Deserialize the request data.
-                if (request_type == WorkerInitRequestType.DETERMINE_NUM_BLOCKS
-                    ):
+                if request_type == WorkerInitRequestType.DETERMINE_NUM_BLOCKS:
                     num_blocks = self.worker.determine_num_available_blocks()
                     send_socket.send_multipart(
-                        (WorkerInitOutputType.NUM_BLOCKS,
+                        (WorkerInitResponseType.NUM_BLOCKS,
                          pickle.dumps(num_blocks)),
                         copy=False)
-                elif request_type == WorkerInitRequestType.INIT_CACHE:
-                    request_data = request[1].buffer
-                    num_gpu_blocks = pickle.loads(request_data)
-                    self.worker.initialize_cache(num_gpu_blocks)
-                    self.worker.compile_or_warm_up_model()
-                elif (request_type ==
-                      WorkerInitRequestType.BEGIN_MODEL_EXECUTION):
-                    # Make sure message queues are ready.
-                    self.scheduler_output_receiver.wait_until_ready()
+                elif request_type == WorkerInitRequestType.INITIALIZE:
+                    # Initialize cache with the requested number of GPU blocks
+                    try:
+                        request_data = request[1].buffer
+                        num_gpu_blocks = pickle.loads(request_data)
+                        self.worker.initialize_cache(num_gpu_blocks)
+                        self.worker.compile_or_warm_up_model()
+                    except BaseException as e:
+                        logger.exception(e)
+
+                        # Send a failure response
+                        send_socket.send_multipart(
+                            (WorkerInitResponseType.INITIALIZE_SUCCESS,
+                             pickle.dumps(False)),
+                            copy=False)
+
+                        raise e
+
+                    # Send a success response. Order is important:
+                    # The executor will call wait_until_ready() on its
+                    # message queues after receiving this message.
+                    send_socket.send_multipart(
+                        (WorkerInitResponseType.INITIALIZE_SUCCESS,
+                         pickle.dumps(True)),
+                        copy=False)
 
+                    # Ensure message queues are ready.
+                    # Must happen after sending the INITIALIZE_SUCCESS message.
+                    self.scheduler_output_receiver.wait_until_ready()
                     if self.model_output_mq is not None:
                         self.model_output_mq.wait_until_ready()
+
+                    # Exit initialization loop to begin model execution loop
                     return
                 else:
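The INITIALIZE branch is ordered deliberately: the worker reports success (or failure) before blocking in wait_until_ready(), because the executor only starts waiting on the shared message queues after receiving that response; reversing the order would risk both sides blocking on each other. A minimal sketch of the report-then-raise failure path, where send_response and do_initialization are hypothetical stand-ins for the socket send and the cache/warm-up calls:

    # Sketch of the report-then-raise ordering in the INITIALIZE branch.
    import logging
    import pickle

    logger = logging.getLogger("worker")
    INITIALIZE_SUCCESS = b'\x02'  # WorkerInitResponseType.INITIALIZE_SUCCESS


    def handle_initialize(send_response, do_initialization) -> None:
        try:
            do_initialization()
        except BaseException as e:
            # Report failure so the executor raises RuntimeError instead
            # of hanging, then re-raise to take the worker process down.
            logger.exception(e)
            send_response((INITIALIZE_SUCCESS, pickle.dumps(False)))
            raise
        # Success is reported *before* any blocking queue handshake: the
        # executor calls wait_until_ready() only after this response.
        send_response((INITIALIZE_SUCCESS, pickle.dumps(True)))


    handle_initialize(lambda frames: print(frames[0]), lambda: None)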