Merged
Changes from all commits
29 commits
7c61ac1
Updated LLMRayActor to run continuously
finbarrtimbers Oct 3, 2025
d9dd059
Set v1 = true
finbarrtimbers Oct 3, 2025
3379473
Added pause loop in update weights
finbarrtimbers Oct 3, 2025
c704b19
Set inflight updates false
finbarrtimbers Oct 3, 2025
9364f25
Fixed condition
finbarrtimbers Oct 3, 2025
63f6989
Update code
finbarrtimbers Oct 3, 2025
9943575
another attempt at fixing deadlocks
finbarrtimbers Oct 3, 2025
86bdb40
Added logs
finbarrtimbers Oct 3, 2025
91a25b1
More logging to diagnose hang
finbarrtimbers Oct 3, 2025
75d2866
Fixed busy-waiting.
finbarrtimbers Oct 3, 2025
5824515
Updated the code
finbarrtimbers Oct 3, 2025
bc912e0
Added logging
finbarrtimbers Oct 3, 2025
3743097
fixed prefetch
finbarrtimbers Oct 3, 2025
adc93da
Updated code
finbarrtimbers Oct 3, 2025
0280728
Cleaned up PR.
finbarrtimbers Oct 5, 2025
ab46044
Cleaned up code.
finbarrtimbers Oct 5, 2025
3d6f829
Fixed race condition
finbarrtimbers Oct 5, 2025
3c34957
Removed logging from update_weight
finbarrtimbers Oct 5, 2025
981dde2
Less logging
finbarrtimbers Oct 5, 2025
00f072e
Cleaned up PR.
finbarrtimbers Oct 6, 2025
8890d6b
Merge branch 'main' into remove-generate
finbarrtimbers Oct 6, 2025
020d008
Cleaned up PR.
finbarrtimbers Oct 6, 2025
10508dc
Cleaned up PR.
finbarrtimbers Oct 6, 2025
b02f9e0
Cleaned up PR.
finbarrtimbers Oct 6, 2025
78995a1
Rearranged timeout
finbarrtimbers Oct 6, 2025
2291ad6
Removes broken code.
finbarrtimbers Oct 6, 2025
7bb9481
Undid changes to inflight
finbarrtimbers Oct 6, 2025
fa430d6
Updated code with a health check on the actors.
finbarrtimbers Oct 6, 2025
d096136
Merge branch 'main' into remove-generate
finbarrtimbers Oct 6, 2025
41 changes: 10 additions & 31 deletions open_instruct/grpo_fast.py
@@ -2303,28 +2303,6 @@ def weight_sync_thread(
logger.info("[Weight Sync Thread] 🛑 Stopping weight sync thread")


def generate_thread(args, vllm_engines, resume_training_step, stop_event, generate_metrics_Q):
"""Thread function that repeatedly calls process_from_queue on vllm engines."""
logger.info("[Generate Thread] 🚀 Starting generation thread")
while not stop_event.is_set():
with Timer("🔥 Generation time") as timer:
processed_results, _ = ray_get_with_progress(
[engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
desc="[Generate Thread] Waiting for vLLM engines to process",
enable=args.verbose,
)
num_processed = sum(int(result) for result in processed_results)
# Suppress timing output if nothing was processed
if num_processed == 0:
timer.noop = True
if num_processed > 0:
try:
generate_metrics_Q.put_nowait({"time/generation": timer.duration})
except Full:
logger.warning("[Generate Thread] generate metrics queue full, skipping metric")
logger.info("[Generate Thread] 🛑 Stopping generation thread")


def one_training_step(
args: Args,
policy_group: ModelGroup,
@@ -2673,7 +2651,6 @@ def cleanup_training_resources(
actor_manager: ActorManager,
) -> None:
"""Clean up all training resources including threads and Ray queues."""
# Signal generate_thread to stop
stop_event.set()

logger.info("Signaling all actors to stop...")
Expand Down Expand Up @@ -2782,14 +2759,13 @@ def run_training(
model_dims,
)

logger.info("======== ✅ generation thread starts =========")
generation_future = executor.submit(
generate_thread, args, vllm_engines, resume_training_step, stop_event, generate_metrics_Q
)

# setup health check function to check that everything is still alive
def health_check_fn():
[f.result() for f in [packing_future, generation_future, weight_sync_thread_future] if f.done()]
[f.result() for f in [packing_future, weight_sync_thread_future] if f.done()]
ray_get_with_progress(
[engine.check_background_threads.remote() for engine in vllm_engines],
desc="Checking vLLM engine health",
enable=False,
)

# Send initial data to ensure we have a N-step offset.
for _ in range(args.async_steps):
Expand Down Expand Up @@ -2826,7 +2802,9 @@ def health_check_fn():
)

# Check if any of the threads have raised an exception.
health_check_start = time.perf_counter()
health_check_fn()
health_check_time = time.perf_counter() - health_check_start

logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
weight_sync_trigger_event.set()
@@ -2850,7 +2828,6 @@ def health_check_fn():
is_eval=True,
)

# The generate_thread is now handling vLLM processing asynchronously
collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths = (
load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn)
)
@@ -2863,6 +2840,8 @@ def health_check_fn():
except Empty:
logger.info("[Main Thread] didn't get train generation metrics")

data_thread_metrics["time/health_check"] = health_check_time

one_training_step(
args,
policy_group,
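Note on the pattern above: with the generation thread gone, the health check is the main place where background failures surface. Any future that has already finished is asked for its result (which re-raises a stored exception), and each engine actor is asked to check its own background threads. A minimal sketch of that pattern, assuming the future and engine names from the diff and substituting a plain ray.get for the repo's ray_get_with_progress helper:

from concurrent import futures

import ray


def make_health_check_fn(thread_futures: list[futures.Future], vllm_engines: list):
    """Build a callable that surfaces failures from background work."""

    def health_check_fn():
        # A future that is done either returned normally or raised;
        # .result() re-raises the exception in the main thread.
        for f in thread_futures:
            if f.done():
                f.result()
        # Each engine checks its own background threads; ray.get re-raises
        # any remote failure here in the caller.
        ray.get([engine.check_background_threads.remote() for engine in vllm_engines])

    return health_check_fn

In run_training this check runs once per training step, and the newly measured duration is reported as time/health_check.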
90 changes: 33 additions & 57 deletions open_instruct/vllm_utils3.py
@@ -18,6 +18,7 @@
import dataclasses
import os
import queue
import threading
import time
from collections import defaultdict
from concurrent import futures
@@ -395,10 +396,15 @@ def __init__(
self._should_stop_value = False
self._should_stop_timeout_s = 5

self._executor = futures.ThreadPoolExecutor(max_workers=1)
self._prefetch_future = self._executor.submit(self._prefetch_worker)
# Initialize instance variables before starting threads
self.tracking = _init_tool_tracking()
self.request_outputs = {}
self._threads_started = threading.Event()

# Start background threads
self._executor = futures.ThreadPoolExecutor(max_workers=2)
self._prefetch_future = self._executor.submit(self._prefetch_worker)
self._process_future = self._executor.submit(self._process_from_queue)

def get_model_dims_dict(self):
"""Get only the model dimensions as a simple dict without loading weights."""
@@ -431,8 +437,9 @@ def _should_stop(self) -> bool:

def _prefetch_worker(self, sleep_length_s: int = 1):
"""Background worker that prefetches requests until we have enough buffered."""
self._threads_started.set()
while True:
if self._should_stop():
if not self.inflight_updates and self._should_stop():
time.sleep(sleep_length_s)
continue
current_unfinished = self.llm_engine.get_num_unfinished_requests()
@@ -456,58 +463,18 @@ def _insert_result_to_queue(self, result, is_eval: bool):
results_queue = self.eval_results_queue if is_eval else self.results_queue
results_queue.put(result)

def _should_exit(self) -> bool:
"""Determine if the processing loop should exit.

Returns:
bool: True if the loop should exit, False otherwise.
"""
# Check stop condition first (cheapest check)
stop_requested = self._should_stop()

# Case 1: inflight_updates enabled and stop requested - exit immediately
if self.inflight_updates and stop_requested:
return True

# Now check for pending work (only if needed)
if stop_requested:
# Need to check if we have pending work
pending_tools = len(self.tracking["pending_tool_futures"])
unfinished = self.llm_engine.get_num_unfinished_requests()

# Case 2: stop requested and no pending work - exit
if pending_tools == 0 and unfinished == 0:
return True
# Otherwise, we have pending work and should continue
return False

# No stop requested - check if there's any work to do
pending_tools = len(self.tracking["pending_tool_futures"])
unfinished = self.llm_engine.get_num_unfinished_requests()

# Case 3: no work left at all - exit
if pending_tools == 0 and unfinished == 0:
return True

# Otherwise, continue processing
return False

def process_from_queue(self, timeout: float = 60.0):
def _process_from_queue(self, timeout: float = 60.0):
"""Run generation loop using LLMEngine directly, with optional tool support.

Runs continuously until should_stop is set, periodically adding new requests
and yielding control to allow weight synchronization.
Runs continuously in a background thread, processing requests from the engine.

Returns:
int: Number of requests processed
"""

# Use persistent instance variables for tracking and outputs
# This ensures state is maintained across multiple calls
total_processed = 0
iteration_count = 0

while not self._should_exit():
Collaborator: similar question to above, what happens if some other process dies and the vllm worker should stop?

Collaborator (Author): Then it gets killed when Ray shuts down:

while True:
iteration_count += 1

# Health check: ensure prefetch worker is alive. This will raise if it has crashed.
@@ -558,17 +525,7 @@ def process_from_queue(self, timeout: float = 60.0):
total_processed += self._finalize_sub_request(
output.request_id, output, complete_output, current_time
)

if self.verbose and iteration_count % 100 == 0:
Collaborator: i never really used this logging so fine with removing but was there any other impetus for doing so?

Collaborator (Author): Truthfully, I asked Claude to remove all the debug logging (I had a bunch of debug logging in an earlier version of this PR) and it also removed this. I thought it was fine so I kept it. Open to changing this if you prefer.

Collaborator: no, all good! Was just wondering why the change.

final_unfinished = self.llm_engine.get_num_unfinished_requests()
pending_tools = len(self.tracking["pending_tool_futures"])
self.logger.info(
f"process_from_queue iteration {iteration_count}: unfinished={final_unfinished}, pending_tools={pending_tools}"
)

# If we have only pending tools but no unfinished requests, sleep briefly
# to let pending tools complete before the next iteration
if self.llm_engine.get_num_unfinished_requests() == 0 and len(self.tracking["pending_tool_futures"]) > 0:
if self.llm_engine.get_num_unfinished_requests() == 0:
time.sleep(1)

return total_processed
@@ -870,10 +827,22 @@ def init_process_group(
args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes),
)

def _maybe_drain_requests(self, sleep_s: float = 0.1):
while not self.inflight_updates:
pending_tools = len(self.tracking["pending_tool_futures"])
unfinished = self.llm_engine.get_num_unfinished_requests()

if pending_tools == 0 and unfinished == 0:
break

time.sleep(sleep_s)

def update_weight(self, name, dtype, shape, empty_cache=False):
self._maybe_drain_requests()
return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))

def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
self._maybe_drain_requests()
return self.llm_engine.collective_rpc(
"update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
)
@@ -888,8 +857,15 @@ def wake_up(self, tags: Optional[list[str]] = None):
self.llm_engine.wake_up(tags)

def ready(self):
self._threads_started.wait(timeout=30)
return True

def check_background_threads(self):
if self._prefetch_future.done():
self._prefetch_future.result()
if self._process_future.done():
self._process_future.result()

def get_kv_cache_info(self):
"""Get KV cache max concurrency from the vLLM engine."""
kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs()
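The core of this file's change is that LLMRayActor now owns its generation loop: two background threads start in __init__, a threading.Event marks them as started, ready() waits on that event, and check_background_threads() re-raises any exception a loop died with. A minimal, framework-free sketch of that lifecycle, with placeholder loop bodies standing in for the PR's actual prefetch/process logic:

import threading
import time
from concurrent import futures


class BackgroundWorkerExample:
    """Illustrates the start/ready/check lifecycle used by the actor."""

    def __init__(self):
        self._threads_started = threading.Event()
        self._executor = futures.ThreadPoolExecutor(max_workers=2)
        # Both loops run for the life of the object; if one raises, the
        # exception is held by its future until someone asks for it.
        self._prefetch_future = self._executor.submit(self._prefetch_loop)
        self._process_future = self._executor.submit(self._process_loop)

    def _prefetch_loop(self):
        self._threads_started.set()  # signal that background work has begun
        while True:
            time.sleep(1)  # placeholder for pulling prompts off a queue

    def _process_loop(self):
        while True:
            time.sleep(1)  # placeholder for stepping the engine

    def ready(self) -> bool:
        # Callers block here until the background threads have actually started.
        return self._threads_started.wait(timeout=30)

    def check_background_threads(self) -> None:
        # A finished future means a loop exited; .result() re-raises its exception.
        for future in (self._prefetch_future, self._process_future):
            if future.done():
                future.result()

The point of the pattern is that a crashed loop is not silent: its exception is stored on the future and re-raised the next time the driver's health check calls check_background_threads().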