diff --git a/adaptive/runner.py b/adaptive/runner.py index 8ba3e0e7f..84fcecd9f 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -776,6 +776,55 @@ def live_info(self, *, update_interval: float = 0.1) -> None: """ return live_info(self, update_interval=update_interval) + def live_info_terminal( + self, *, update_interval: float = 0.5, overwrite_previous: bool = True + ) -> asyncio.Task: + """ + Display live information about the runner in the terminal. + + This function provides a live update of the runner's status in the terminal. + The update can either overwrite the previous status or be printed on a new line. + + Parameters + ---------- + update_interval : float, optional + The time interval (in seconds) at which the runner's status is updated in the terminal. + Default is 0.5 seconds. + overwrite_previous : bool, optional + If True, each update will overwrite the previous status in the terminal. + If False, each update will be printed on a new line. + Default is True. + + Returns + ------- + asyncio.Task + The asynchronous task responsible for updating the runner's status in the terminal. + + Examples + -------- + >>> runner = AsyncRunner(...) + >>> runner.live_info_terminal(update_interval=1.0, overwrite_previous=False) + + Notes + ----- + This function uses ANSI escape sequences to control the terminal's cursor position. + It might not work as expected on all terminal emulators. + """ + + async def _update(runner: AsyncRunner) -> None: + try: + while not runner.task.done(): + if overwrite_previous: + # Clear the terminal + print("\033[H\033[J", end="") + print(_info_text(runner, separator="\t")) + await asyncio.sleep(update_interval) + + except asyncio.CancelledError: + print("Live info display cancelled.") + + return self.ioloop.create_task(_update(self)) + async def _run(self) -> None: first_completed = asyncio.FIRST_COMPLETED @@ -855,6 +904,43 @@ async def _saver(): return self.saving_task +def _info_text(runner, separator: str = "\n"): + status = runner.status() + + color_map = { + "cancelled": "\033[33m", # Yellow + "failed": "\033[31m", # Red + "running": "\033[34m", # Blue + "finished": "\033[32m", # Green + } + + overhead = runner.overhead() + if overhead < 50: + overhead_color = "\033[32m" # Green + else: + overhead_color = "\033[31m" # Red + + info = [ + ("time", str(datetime.now())), + ("status", f"{color_map[status]}{status}\033[0m"), + ("elapsed time", str(timedelta(seconds=runner.elapsed_time()))), + ("overhead", f"{overhead_color}{overhead:.2f}%\033[0m"), + ] + + with suppress(Exception): + info.append(("# of points", runner.learner.npoints)) + + with suppress(Exception): + info.append(("# of samples", runner.learner.nsamples)) + + with suppress(Exception): + info.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}')) + + width = 30 + formatted_info = [f"{k}: {v}".ljust(width) for i, (k, v) in enumerate(info)] + return separator.join(formatted_info) + + # Default runner Runner = AsyncRunner