diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 1372a3211..5cc655847 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -2303,28 +2303,6 @@ def weight_sync_thread(
     logger.info("[Weight Sync Thread] 🛑 Stopping weight sync thread")
 
 
-def generate_thread(args, vllm_engines, resume_training_step, stop_event, generate_metrics_Q):
-    """Thread function that repeatedly calls process_from_queue on vllm engines."""
-    logger.info("[Generate Thread] 🚀 Starting generation thread")
-    while not stop_event.is_set():
-        with Timer("🔥 Generation time") as timer:
-            processed_results, _ = ray_get_with_progress(
-                [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
-                desc="[Generate Thread] Waiting for vLLM engines to process",
-                enable=args.verbose,
-            )
-            num_processed = sum(int(result) for result in processed_results)
-            # Suppress timing output if nothing was processed
-            if num_processed == 0:
-                timer.noop = True
-        if num_processed > 0:
-            try:
-                generate_metrics_Q.put_nowait({"time/generation": timer.duration})
-            except Full:
-                logger.warning("[Generate Thread] generate metrics queue full, skipping metric")
-    logger.info("[Generate Thread] 🛑 Stopping generation thread")
-
-
 def one_training_step(
     args: Args,
     policy_group: ModelGroup,
@@ -2673,7 +2651,6 @@ def cleanup_training_resources(
     actor_manager: ActorManager,
 ) -> None:
     """Clean up all training resources including threads and Ray queues."""
-    # Signal generate_thread to stop
     stop_event.set()
 
     logger.info("Signaling all actors to stop...")
@@ -2782,14 +2759,13 @@ def run_training(
         model_dims,
     )
 
-    logger.info("======== ✅ generation thread starts =========")
-    generation_future = executor.submit(
-        generate_thread, args, vllm_engines, resume_training_step, stop_event, generate_metrics_Q
-    )
-
-    # setup health check function to check that everything is still alive
     def health_check_fn():
-        [f.result() for f in [packing_future, generation_future, weight_sync_thread_future] if f.done()]
+        [f.result() for f in [packing_future, weight_sync_thread_future] if f.done()]
+        ray_get_with_progress(
+            [engine.check_background_threads.remote() for engine in vllm_engines],
+            desc="Checking vLLM engine health",
+            enable=False,
+        )
 
     # Send initial data to ensure we have a N-step offset.
     for _ in range(args.async_steps):
@@ -2826,7 +2802,9 @@ def health_check_fn():
         )
 
         # Check if any of the threads have raised an exception.
+        health_check_start = time.perf_counter()
         health_check_fn()
+        health_check_time = time.perf_counter() - health_check_start
 
         logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
         weight_sync_trigger_event.set()
@@ -2850,7 +2828,6 @@ def health_check_fn():
                 is_eval=True,
             )
 
-        # The generate_thread is now handling vLLM processing asynchronously
         collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths = (
             load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn)
         )
@@ -2863,6 +2840,8 @@ def health_check_fn():
         except Empty:
             logger.info("[Main Thread] didn't get train generation metrics")
 
+        data_thread_metrics["time/health_check"] = health_check_time
+
         one_training_step(
             args,
             policy_group,
diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index cf1a05b99..1b44974bd 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -18,6 +18,7 @@
 import dataclasses
 import os
 import queue
+import threading
 import time
 from collections import defaultdict
 from concurrent import futures
@@ -395,10 +396,15 @@ def __init__(
         self._should_stop_value = False
         self._should_stop_timeout_s = 5
 
-        self._executor = futures.ThreadPoolExecutor(max_workers=1)
-        self._prefetch_future = self._executor.submit(self._prefetch_worker)
+        # Initialize instance variables before starting threads
         self.tracking = _init_tool_tracking()
         self.request_outputs = {}
+        self._threads_started = threading.Event()
+
+        # Start background threads
+        self._executor = futures.ThreadPoolExecutor(max_workers=2)
+        self._prefetch_future = self._executor.submit(self._prefetch_worker)
+        self._process_future = self._executor.submit(self._process_from_queue)
 
     def get_model_dims_dict(self):
         """Get only the model dimensions as a simple dict without loading weights."""
@@ -431,8 +437,9 @@ def _should_stop(self) -> bool:
 
     def _prefetch_worker(self, sleep_length_s: int = 1):
         """Background worker that prefetches requests until we have enough buffered."""
+        self._threads_started.set()
         while True:
-            if self._should_stop():
+            if not self.inflight_updates and self._should_stop():
                 time.sleep(sleep_length_s)
                 continue
             current_unfinished = self.llm_engine.get_num_unfinished_requests()
@@ -456,58 +463,18 @@ def _insert_result_to_queue(self, result, is_eval: bool):
         results_queue = self.eval_results_queue if is_eval else self.results_queue
         results_queue.put(result)
 
-    def _should_exit(self) -> bool:
-        """Determine if the processing loop should exit.
-
-        Returns:
-            bool: True if the loop should exit, False otherwise.
-        """
-        # Check stop condition first (cheapest check)
-        stop_requested = self._should_stop()
-
-        # Case 1: inflight_updates enabled and stop requested - exit immediately
-        if self.inflight_updates and stop_requested:
-            return True
-
-        # Now check for pending work (only if needed)
-        if stop_requested:
-            # Need to check if we have pending work
-            pending_tools = len(self.tracking["pending_tool_futures"])
-            unfinished = self.llm_engine.get_num_unfinished_requests()
-
-            # Case 2: stop requested and no pending work - exit
-            if pending_tools == 0 and unfinished == 0:
-                return True
-            # Otherwise, we have pending work and should continue
-            return False
-
-        # No stop requested - check if there's any work to do
-        pending_tools = len(self.tracking["pending_tool_futures"])
-        unfinished = self.llm_engine.get_num_unfinished_requests()
-
-        # Case 3: no work left at all - exit
-        if pending_tools == 0 and unfinished == 0:
-            return True
-
-        # Otherwise, continue processing
-        return False
-
-    def process_from_queue(self, timeout: float = 60.0):
+    def _process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.
 
-        Runs continuously until should_stop is set, periodically adding new requests
-        and yielding control to allow weight synchronization.
+        Runs continuously in a background thread, processing requests from the engine.
 
         Returns:
            int: Number of requests processed
         """
-
-        # Use persistent instance variables for tracking and outputs
-        # This ensures state is maintained across multiple calls
         total_processed = 0
         iteration_count = 0
 
-        while not self._should_exit():
+        while True:
            iteration_count += 1
 
             # Health check: ensure prefetch worker is alive. This will raise if it has crashed.
@@ -558,17 +525,7 @@ def process_from_queue(self, timeout: float = 60.0):
                     total_processed += self._finalize_sub_request(
                         output.request_id, output, complete_output, current_time
                     )
-
-            if self.verbose and iteration_count % 100 == 0:
-                final_unfinished = self.llm_engine.get_num_unfinished_requests()
-                pending_tools = len(self.tracking["pending_tool_futures"])
-                self.logger.info(
-                    f"process_from_queue iteration {iteration_count}: unfinished={final_unfinished}, pending_tools={pending_tools}"
-                )
-
-            # If we have only pending tools but no unfinished requests, sleep briefly
-            # to let pending tools complete before the next iteration
-            if self.llm_engine.get_num_unfinished_requests() == 0 and len(self.tracking["pending_tool_futures"]) > 0:
+            if self.llm_engine.get_num_unfinished_requests() == 0:
                 time.sleep(1)
 
         return total_processed
@@ -870,10 +827,22 @@ def init_process_group(
         args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes),
     )
 
+    def _maybe_drain_requests(self, sleep_s: float = 0.1):
+        while not self.inflight_updates:
+            pending_tools = len(self.tracking["pending_tool_futures"])
+            unfinished = self.llm_engine.get_num_unfinished_requests()
+
+            if pending_tools == 0 and unfinished == 0:
+                break
+
+            time.sleep(sleep_s)
+
     def update_weight(self, name, dtype, shape, empty_cache=False):
+        self._maybe_drain_requests()
         return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
 
     def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
+        self._maybe_drain_requests()
         return self.llm_engine.collective_rpc(
             "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
         )
@@ -888,8 +857,15 @@ def wake_up(self, tags: Optional[list[str]] = None):
         self.llm_engine.wake_up(tags)
 
     def ready(self):
+        self._threads_started.wait(timeout=30)
         return True
 
+    def check_background_threads(self):
+        if self._prefetch_future.done():
+            self._prefetch_future.result()
+        if self._process_future.done():
+            self._process_future.result()
+
     def get_kv_cache_info(self):
         """Get KV cache max concurrency from the vLLM engine."""
         kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs()