Skip to content

Commit

Permalink
Review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth committed Nov 26, 2024
1 parent d4ea706 commit 2174a5b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 53 deletions.
8 changes: 6 additions & 2 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
# We prefer to use os.sched_yield as it results in tighter polling loops,
# measured to be around 3e-7 seconds. However, on earlier versions of Python
# (before 3.10.8, and on 3.11.0) os.sched_yield() does not release the GIL,
# so we fall back to time.sleep(0).
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
or (sys.version_info[:2] == (3, 10)
and sys.version_info[2] >= 8))


def sched_yield():
if ((sys.version_info[:3] >= (3, 11, 1)) or
(sys.version_info[:2] == (3, 10) and sys.version_info[2] >= 8)):
if USE_SCHED_YIELD:
os.sched_yield()
else:
time.sleep(0)
Expand Down
2 changes: 1 addition & 1 deletion vllm/v1/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def start_workers(self):
self.model_output_mq.wait_until_ready()
self.workers_in_busy_loop = True

def run_on_workers(self, fn: str, *args):
def run_on_workers(self, fn: str, *args) -> List:
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(getattr(type(w), fn), w, *args)
Expand Down
73 changes: 23 additions & 50 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ class WorkerInitOutputType:
Request types defined as hex byte strings, so they can be sent over sockets
without a separate encoding step.
"""
NUM_BLOCKS = b'\x00'
MODEL_OUTPUT_MSG_QUEUE = b'\x01'
READY = b'\x00'
NUM_BLOCKS = b'\x01'


@dataclass
Expand All @@ -268,11 +268,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
request_type = type_frame.buffer
request_data = data_frame.buffer

if request_type == WorkerInitOutputType.NUM_BLOCKS:
num_blocks = pickle.loads(request_data)
return num_blocks
else:
if request_type != WorkerInitOutputType.NUM_BLOCKS:
raise ValueError(f"Unknown RequestType: {request_type}")
return pickle.loads(request_data)

def initialize_cache(self, num_gpu_blocks: int) -> int:
with make_zmq_socket(self.initialization_output_path,
Expand All @@ -291,8 +289,6 @@ def start_busy_loop(self) -> None:
class WorkerProc:
"""Wrapper that runs one Worker in a separate process."""

READY_STR = "READY"

def __init__(
self,
vllm_config: VllmConfig,
Expand All @@ -302,7 +298,6 @@ def __init__(
input_shm_handle: Handle,
initialization_input_path: str,
initialization_output_path: str,
ready_path: str,
):
self.rank = rank
self.worker = Worker(vllm_config, local_rank, rank,
Expand All @@ -312,22 +307,20 @@ def __init__(
self.scheduler_output_receiver = MessageQueue.create_from_handle(
input_shm_handle, self.worker.rank)

# Send Readiness signal to EngineCore process.
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
ready_socket.send_string(WorkerProc.READY_STR)

# Worker 0 initializes a message queue for sending the model output
if self.rank == 0:
self.model_output_mq = MessageQueue(1, 1)
output_mq_handle = self.model_output_mq.export_handle()
with make_zmq_socket(initialization_output_path,
zmq.constants.PUSH) as socket:
msg = pickle.dumps(output_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
socket.send_multipart(
(WorkerInitOutputType.MODEL_OUTPUT_MSG_QUEUE, msg))
else:
self.model_output_mq = None
output_mq_handle = None

# Send Readiness signal to EngineCore process.
with make_zmq_socket(initialization_output_path,
zmq.constants.PUSH) as ready_socket:
payload = pickle.dumps(output_mq_handle,
protocol=pickle.HIGHEST_PROTOCOL)
ready_socket.send_multipart((WorkerInitOutputType.READY, payload))

self.worker.initialize()
self.worker.load_model()
Expand All @@ -352,15 +345,13 @@ def make_worker_process(
# Used for initialization.
initialization_input_path = get_open_zmq_ipc_path()
initialization_output_path = get_open_zmq_ipc_path()
ready_path = get_open_zmq_ipc_path()

process_kwargs = {
"vllm_config": vllm_config,
"local_rank": local_rank,
"rank": rank,
"distributed_init_method": distributed_init_method,
"input_shm_handle": input_shm_handle,
"ready_path": ready_path,
"initialization_input_path": initialization_output_path,
"initialization_output_path": initialization_input_path,
}
Expand All @@ -371,14 +362,8 @@ def make_worker_process(
proc.start()

# Wait for startup
WorkerProc.wait_for_startup(proc, ready_path)

# Read Shm MessageQueue from rank 0
if rank == 0:
model_output_mq_handle = WorkerProc.read_model_output_mq_handle(
initialization_input_path)
else:
model_output_mq_handle = None
model_output_mq_handle = WorkerProc.wait_for_startup(
proc, initialization_input_path)

return WorkerProcHandle(proc, initialization_input_path,
initialization_output_path,
Expand Down Expand Up @@ -406,34 +391,22 @@ def run_worker(*args, **kwargs):
@staticmethod
def wait_for_startup(
proc: BaseProcess,
ready_path: str,
) -> None:
path: str,
) -> Optional[Handle]:
"""Wait until the Worker is ready."""
with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
with make_zmq_socket(path, zmq.constants.PULL) as socket:

# Wait for Worker to send Worker.READY_STR.
# Wait for Worker to send READY.
while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for WorkerProc to startup.")

if not proc.is_alive():
raise RuntimeError("WorkerProc failed to start.")

message = socket.recv_string()
assert message == WorkerProc.READY_STR

@staticmethod
def read_model_output_mq_handle(init_input_path: str, ) -> Handle:
with make_zmq_socket(init_input_path,
zmq.constants.PULL) as recv_socket:
type_frame, data_frame = recv_socket.recv_multipart(copy=False)
request_type = type_frame.buffer
request_data = data_frame.buffer

if (request_type == WorkerInitOutputType.MODEL_OUTPUT_MSG_QUEUE):
handle = pickle.loads(request_data)
return handle
else:
raise ValueError(f"Unknown RequestType: {request_type}")
type_frame, data_frame = socket.recv_multipart(copy=False)
assert type_frame.buffer == WorkerInitOutputType.READY
handle = pickle.loads(data_frame.buffer)
return handle

# Busy loop used for initializing Multiprocessing Workers
def model_initialization_loop(self, init_input_path, init_output_path):
Expand Down Expand Up @@ -497,7 +470,7 @@ def execute_model_busy_loop(self):

if msg.message_type == ExecutorMsgType.TERMINATE:
return
elif msg.message_type == ExecutorMsgType.WORK:
if msg.message_type == ExecutorMsgType.WORK:
output = self.worker.execute_model(msg.payload)
if self.worker.rank == 0:
self.model_output_mq.enqueue(output)
Expand Down

0 comments on commit 2174a5b

Please sign in to comment.