Removes the generate thread #1054
Changes from all commits
@@ -18,6 +18,7 @@
import dataclasses
import os
import queue
import threading
import time
from collections import defaultdict
from concurrent import futures
@@ -395,10 +396,15 @@ def __init__(
        self._should_stop_value = False
        self._should_stop_timeout_s = 5

        self._executor = futures.ThreadPoolExecutor(max_workers=1)
        self._prefetch_future = self._executor.submit(self._prefetch_worker)
        # Initialize instance variables before starting threads
        self.tracking = _init_tool_tracking()
        self.request_outputs = {}
        self._threads_started = threading.Event()

        # Start background threads
        self._executor = futures.ThreadPoolExecutor(max_workers=2)
        self._prefetch_future = self._executor.submit(self._prefetch_worker)
        self._process_future = self._executor.submit(self._process_from_queue)

    def get_model_dims_dict(self):
        """Get only the model dimensions as a simple dict without loading weights."""
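The constructor now builds the tool-tracking state and the readiness event before any thread starts, and widens the executor to two workers so the prefetch and processing loops can run side by side. Below is a minimal, self-contained sketch of that pattern; `BackgroundEngine`, `_process_worker`, and the `shutdown` helper are illustrative stand-ins, not the actual open-instruct code.

```python
import threading
import time
from concurrent import futures


class BackgroundEngine:
    """Minimal sketch (illustrative names, not the actual open-instruct class)."""

    def __init__(self):
        # State shared with the workers must exist before the threads start.
        self.results = []
        self._stop = threading.Event()
        self._threads_started = threading.Event()

        # max_workers must cover both long-running tasks; with max_workers=1
        # the second submit() would queue forever behind the first worker.
        self._executor = futures.ThreadPoolExecutor(max_workers=2)
        self._prefetch_future = self._executor.submit(self._prefetch_worker)
        self._process_future = self._executor.submit(self._process_worker)

    def _prefetch_worker(self):
        self._threads_started.set()  # lets ready() confirm the threads are live
        while not self._stop.is_set():
            time.sleep(0.05)  # stand-in for "pull new requests into the engine"

    def _process_worker(self):
        while not self._stop.is_set():
            time.sleep(0.05)  # stand-in for "step the engine, collect outputs"

    def ready(self, timeout: float = 30.0) -> bool:
        return self._threads_started.wait(timeout=timeout)

    def shutdown(self):
        self._stop.set()
        self._executor.shutdown(wait=True)


engine = BackgroundEngine()
assert engine.ready()
engine.shutdown()
```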
@@ -431,8 +437,9 @@ def _should_stop(self) -> bool:

    def _prefetch_worker(self, sleep_length_s: int = 1):
        """Background worker that prefetches requests until we have enough buffered."""
        self._threads_started.set()
        while True:
            if self._should_stop():
            if not self.inflight_updates and self._should_stop():
                time.sleep(sleep_length_s)
                continue
            current_unfinished = self.llm_engine.get_num_unfinished_requests()
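With the prefetch loop now checking `not self.inflight_updates` before honouring a stop request, prefetching continues through weight syncs whenever in-flight updates are allowed. A hedged sketch of that gate in isolation, using a hypothetical `prefetch_loop` helper with injected callables rather than the real worker:

```python
import time


def prefetch_loop(add_request, should_stop, inflight_updates: bool,
                  max_iters: int = 10, sleep_s: float = 0.1) -> int:
    """Illustrative prefetch loop (hypothetical helper, not the real class).

    When a stop is requested and in-flight updates are NOT allowed, the loop
    pauses instead of feeding new work; otherwise it keeps adding requests.
    max_iters bounds the loop so the sketch terminates.
    """
    added = 0
    for _ in range(max_iters):
        if not inflight_updates and should_stop():
            time.sleep(sleep_s)
            continue
        add_request()  # stand-in for "pull one prompt into the engine"
        added += 1
    return added


# Usage sketch: with inflight_updates=True the loop keeps feeding requests
# even while should_stop() returns True (i.e. during a weight sync).
print(prefetch_loop(add_request=lambda: None, should_stop=lambda: True, inflight_updates=True))
```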
@@ -456,58 +463,18 @@ def _insert_result_to_queue(self, result, is_eval: bool):
        results_queue = self.eval_results_queue if is_eval else self.results_queue
        results_queue.put(result)

    def _should_exit(self) -> bool:
        """Determine if the processing loop should exit.

        Returns:
            bool: True if the loop should exit, False otherwise.
        """
        # Check stop condition first (cheapest check)
        stop_requested = self._should_stop()

        # Case 1: inflight_updates enabled and stop requested - exit immediately
        if self.inflight_updates and stop_requested:
            return True

        # Now check for pending work (only if needed)
        if stop_requested:
            # Need to check if we have pending work
            pending_tools = len(self.tracking["pending_tool_futures"])
            unfinished = self.llm_engine.get_num_unfinished_requests()

            # Case 2: stop requested and no pending work - exit
            if pending_tools == 0 and unfinished == 0:
                return True
            # Otherwise, we have pending work and should continue
            return False

        # No stop requested - check if there's any work to do
        pending_tools = len(self.tracking["pending_tool_futures"])
        unfinished = self.llm_engine.get_num_unfinished_requests()

        # Case 3: no work left at all - exit
        if pending_tools == 0 and unfinished == 0:
            return True

        # Otherwise, continue processing
        return False

    def process_from_queue(self, timeout: float = 60.0):
    def _process_from_queue(self, timeout: float = 60.0):
        """Run generation loop using LLMEngine directly, with optional tool support.

        Runs continuously until should_stop is set, periodically adding new requests
        and yielding control to allow weight synchronization.
        Runs continuously in a background thread, processing requests from the engine.

        Returns:
            int: Number of requests processed
        """

        # Use persistent instance variables for tracking and outputs
        # This ensures state is maintained across multiple calls
        total_processed = 0
        iteration_count = 0

        while not self._should_exit():
Review comment: similar question to above, what happens if some other process dies and the vllm worker should stop?
Reply: Then it gets killed when Ray shuts down: open-instruct/open_instruct/grpo_fast.py, line 2708 in c3f79a3.
        while True:
            iteration_count += 1

            # Health check: ensure prefetch worker is alive. This will raise if it has crashed.
@@ -558,17 +525,7 @@ def process_from_queue(self, timeout: float = 60.0):
                total_processed += self._finalize_sub_request(
                    output.request_id, output, complete_output, current_time
                )

            if self.verbose and iteration_count % 100 == 0:
Review comment: i never really used this logging so fine with removing, but was there any other impetus for doing so?
Reply: Truthfully, I asked Claude to remove all the debug logging (I had a bunch of debug logging in an earlier version of this PR) and it also removed this. I thought it was fine so I kept it. Open to changing this if you prefer.
Reply: no, all good! Was just wondering why the change.
                final_unfinished = self.llm_engine.get_num_unfinished_requests()
                pending_tools = len(self.tracking["pending_tool_futures"])
                self.logger.info(
                    f"process_from_queue iteration {iteration_count}: unfinished={final_unfinished}, pending_tools={pending_tools}"
                )

            # If we have only pending tools but no unfinished requests, sleep briefly
            # to let pending tools complete before the next iteration
            if self.llm_engine.get_num_unfinished_requests() == 0 and len(self.tracking["pending_tool_futures"]) > 0:
            if self.llm_engine.get_num_unfinished_requests() == 0:
                time.sleep(1)

        return total_processed
@@ -870,10 +827,22 @@ def init_process_group(
            args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes),
        )

    def _maybe_drain_requests(self, sleep_s: float = 0.1):
        while not self.inflight_updates:
            pending_tools = len(self.tracking["pending_tool_futures"])
            unfinished = self.llm_engine.get_num_unfinished_requests()

            if pending_tools == 0 and unfinished == 0:
                break

            time.sleep(sleep_s)

    def update_weight(self, name, dtype, shape, empty_cache=False):
        self._maybe_drain_requests()
        return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))

    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
        self._maybe_drain_requests()
        return self.llm_engine.collective_rpc(
            "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
        )
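`_maybe_drain_requests` makes each weight-update RPC wait until the engine has no unfinished requests and no pending tool calls, unless `inflight_updates` is enabled. A minimal sketch of that guard, with a hypothetical `engine` object standing in for the vLLM engine and tool tracking:

```python
import time


def drain_before_update(engine, inflight_updates: bool, sleep_s: float = 0.1) -> None:
    """Block until outstanding work is finished, unless in-flight updates are allowed.

    `engine` is a hypothetical object exposing num_unfinished() and
    num_pending_tools(); the real code reads the vLLM engine and tool tracking.
    """
    while not inflight_updates:
        if engine.num_pending_tools() == 0 and engine.num_unfinished() == 0:
            break
        time.sleep(sleep_s)


# Usage sketch: call the guard before each collective weight-update RPC.
# drain_before_update(engine, inflight_updates=False)
# engine.update_weight(...)
```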
@@ -888,8 +857,15 @@ def wake_up(self, tags: Optional[list[str]] = None):
        self.llm_engine.wake_up(tags)

    def ready(self):
        self._threads_started.wait(timeout=30)
        return True

    def check_background_threads(self):
        if self._prefetch_future.done():
            self._prefetch_future.result()
        if self._process_future.done():
            self._process_future.result()

    def get_kv_cache_info(self):
        """Get KV cache max concurrency from the vLLM engine."""
        kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs()
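`check_background_threads` leans on the standard-library behaviour that `concurrent.futures.Future.result()` re-raises any exception that escaped the submitted callable, so a crashed worker surfaces in the caller rather than dying silently. A small standalone demonstration of that behaviour (not the engine code itself):

```python
from concurrent import futures


def worker() -> None:
    raise RuntimeError("background worker crashed")


executor = futures.ThreadPoolExecutor(max_workers=1)
future = executor.submit(worker)


def check_background_threads() -> None:
    # done() is non-blocking; result() re-raises the worker's exception here.
    if future.done():
        future.result()


try:
    # Give the worker a moment to fail, then poll it the way the engine does.
    futures.wait([future], timeout=5)
    check_background_threads()
except RuntimeError as exc:
    print(f"caught crash from background thread: {exc}")
finally:
    executor.shutdown(wait=False)
```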