diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py
index bb427f391..279b81345 100644
--- a/open_instruct/vllm_utils3.py
+++ b/open_instruct/vllm_utils3.py
@@ -368,6 +368,22 @@ def __init__(
         self.eval_results_queue = eval_results_queue
         self.actor_manager = actor_manager
 
+        # For caching should_stop status.
+        self._last_should_stop_update = float("-inf")
+        self._should_stop_value = False
+        self._should_stop_timeout_s = 5
+
+    def _should_stop(self) -> bool:
+        if (time.perf_counter() - self._last_should_stop_update) > self._should_stop_timeout_s:
+            should_stop_ref = self.actor_manager.should_stop.remote()
+            ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
+            if ready_refs:
+                self._should_stop_value = ray.get(ready_refs[0])
+                self._last_should_stop_update = time.perf_counter()
+            else:
+                ray.cancel(should_stop_ref)
+        return self._should_stop_value
+
     def process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.
 
@@ -375,10 +391,7 @@ def process_from_queue(self, timeout: float = 60.0):
             int: Number of requests processed (0 or 1)
         """
         while True:
-            # Non-blocking check for should_stop using ray.wait
-            should_stop_ref = self.actor_manager.should_stop.remote()
-            ready_refs, _ = ray.wait([should_stop_ref], timeout=0.1)
-            if ready_refs and ray.get(ready_refs[0]):
+            if self._should_stop():
                 return 0
 
             try:
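
The new `_should_stop` helper is a small time-to-live cache around the remote flag: the previously fetched value is reused for `_should_stop_timeout_s` seconds, and a lookup that does not return within the `ray.wait` timeout is cancelled instead of blocking the generation loop. Below is a minimal, self-contained sketch of the same polling pattern, assuming a hypothetical `Manager` actor and a `CachedStopCheck` wrapper; these names are illustrative only and are not part of this patch.

    import time

    import ray


    @ray.remote
    class Manager:
        # Hypothetical stand-in for the real actor manager (illustrative only).
        def __init__(self):
            self._stop = False

        def set_should_stop(self, value: bool):
            self._stop = value

        def should_stop(self) -> bool:
            return self._stop


    class CachedStopCheck:
        # TTL cache over a remote boolean, mirroring the pattern in _should_stop.
        def __init__(self, manager, ttl_s: float = 5.0, wait_s: float = 0.1):
            self._manager = manager
            self._ttl_s = ttl_s
            self._wait_s = wait_s
            self._last_update = float("-inf")  # forces a refresh on first call
            self._value = False

        def __call__(self) -> bool:
            if time.perf_counter() - self._last_update > self._ttl_s:
                ref = self._manager.should_stop.remote()
                ready, _ = ray.wait([ref], timeout=self._wait_s)
                if ready:
                    self._value = ray.get(ready[0])
                    self._last_update = time.perf_counter()
                else:
                    # Lookup did not finish in time; keep the stale value and
                    # retry on the next expired call.
                    ray.cancel(ref)
            return self._value


    if __name__ == "__main__":
        ray.init()
        manager = Manager.remote()
        should_stop = CachedStopCheck(manager)
        print(should_stop())  # False; cached for the next 5 seconds
        ray.get(manager.set_should_stop.remote(True))
        time.sleep(5.1)       # let the cache expire
        print(should_stop())  # True after refresh

One consequence of this design is that a stop signal can be observed up to `_should_stop_timeout_s` seconds late; the trade-off is one remote call per TTL window instead of one per loop iteration.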