diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index f7bbd1555ba2..ffe72feed445 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -948,11 +948,25 @@ def print_logs(self): # probably will not be able to consume the log messages as rapidly # as they are coming in. # This is meaningful only for GCS subscriber. - last_polling_batch_size = 0 + last_polling_batch_size = -1 + consecutive_zero_batch_cnt = 0 job_id_hex = self.current_job_id.hex() while True: - # Exit if we received a signal that we should stop. - if self.threads_stopped.is_set(): + if last_polling_batch_size == 0: + consecutive_zero_batch_cnt += 1 + # Sleep for a while to avoid consecutive_zero_batch_cnt + # increasing too fast. + time.sleep(0.5) + else: + consecutive_zero_batch_cnt = 0 + # Only returns when main thread is dead. + # But wait until the polling batch size decreases to 0 and remains so + # for a while before returning; this ensures that there are no + # pending logs. + if ( + not threading.main_thread().is_alive() + and consecutive_zero_batch_cnt >= 3 + ): return data = subscriber.poll() @@ -961,11 +975,7 @@ def print_logs(self): last_polling_batch_size = 0 continue - if ( - self._filter_logs_by_job - and data["job"] - and data["job"] != job_id_hex - ): + if self._filter_logs_by_job and data["job"] != job_id_hex: last_polling_batch_size = 0 continue @@ -2529,7 +2539,6 @@ def connect( worker.logger_thread = threading.Thread( target=worker.print_logs, name="ray_print_logs" ) - worker.logger_thread.daemon = True worker.logger_thread.start() # Setup tracing here