From c0152989605e9880eb2f10719b2479cb340d58b1 Mon Sep 17 00:00:00 2001 From: Jacob Ogden Date: Wed, 23 Jul 2025 10:29:12 -0400 Subject: [PATCH 1/4] Added support for context managers to TaskID to automatically remove a Task when context is exited. --- rich/progress.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/rich/progress.py b/rich/progress.py index ef6ad60f0..259c65940 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -7,6 +7,7 @@ from collections import deque from dataclasses import dataclass, field from datetime import timedelta +from functools import partial from io import RawIOBase, UnsupportedOperation from math import ceil from mmap import mmap @@ -51,8 +52,6 @@ from .table import Column, Table from .text import Text, TextType -TaskID = NewType("TaskID", int) - ProgressType = TypeVar("ProgressType") GetTimeCallable = Callable[[], float] @@ -61,6 +60,25 @@ _I = typing.TypeVar("_I", TextIO, BinaryIO) +class TaskID(int): + def __new__(cls, task_id: int, prog_instance: Progress | type[Progress]): + return super().__new__(cls, task_id) + + def __init__(self, task_id: int, prog_instance: Progress | type[Progress]): + self.remove = partial(prog_instance.remove_task, self) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.remove() + + class _TrackThread(Thread): """A thread to periodically update progress.""" @@ -1096,7 +1114,7 @@ def __init__( self.disable = disable self.expand = expand self._tasks: Dict[TaskID, Task] = {} - self._task_index: TaskID = TaskID(0) + self._task_index: TaskID = TaskID(0, self) self.live = Live( console=console or get_console(), auto_refresh=auto_refresh, @@ -1635,7 +1653,7 @@ def add_task( if start: self.start_task(self._task_index) new_task_index = self._task_index - self._task_index = TaskID(int(self._task_index) + 1) + self._task_index = TaskID(int(self._task_index) + 1, self) self.refresh() return new_task_index From ab786c840cc15d65237597f2c4327d159637c494 Mon Sep 17 00:00:00 2001 From: Jacob Ogden Date: Wed, 23 Jul 2025 10:29:56 -0400 Subject: [PATCH 2/4] Forgot to remove NewType import after last change. --- rich/progress.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rich/progress.py b/rich/progress.py index 259c65940..d2e9e6138 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -28,7 +28,6 @@ List, Literal, NamedTuple, - NewType, Optional, TextIO, Tuple, From 7cf2fb8feadbdaea63f25b0e4551a6c9693383e7 Mon Sep 17 00:00:00 2001 From: Jacob Ogden Date: Thu, 24 Jul 2025 12:00:20 -0400 Subject: [PATCH 3/4] Adjusted TaskID to not need the use of functools.partial --- CHANGELOG.md | 1 + CONTRIBUTORS.md | 1 + rich/progress.py | 5 ++--- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f24a4ea7..04cbe9ae4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TTY_INTERACTIVE` environment variable to force interactive mode off or on https://github.com/Textualize/rich/pull/3777 +- Added Context Manager support for TaskID objects returned by Progress.add_task. Allowing for `with progress.add_task(...) as taskid: ...` which automatically removes the progress bar for that task upon exiting the current context. ## [14.0.0] - 2025-03-30 diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4b04786b9..3b5cfb101 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -94,3 +94,4 @@ The following people have contributed to the development of Rich: - [Jonathan Helmus](https://github.com/jjhelmus) - [Brandon Capener](https://github.com/bcapener) - [Alex Zheng](https://github.com/alexzheng111) +- [Jacob Ogden](https://github.com/AetherBreaker) diff --git a/rich/progress.py b/rich/progress.py index d2e9e6138..e30d6f931 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -7,7 +7,6 @@ from collections import deque from dataclasses import dataclass, field from datetime import timedelta -from functools import partial from io import RawIOBase, UnsupportedOperation from math import ceil from mmap import mmap @@ -64,7 +63,7 @@ def __new__(cls, task_id: int, prog_instance: Progress | type[Progress]): return super().__new__(cls, task_id) def __init__(self, task_id: int, prog_instance: Progress | type[Progress]): - self.remove = partial(prog_instance.remove_task, self) + self.prog = prog_instance def __enter__(self) -> Self: return self @@ -75,7 +74,7 @@ def __exit__( exc_value: BaseException | None, traceback: TracebackType | None, ): - self.remove() + self.prog.remove_task(self) class _TrackThread(Thread): From 4ef0cef8b0e3d96204527632ee77c254019458ea Mon Sep 17 00:00:00 2001 From: Jacob Ogden Date: Thu, 24 Jul 2025 12:21:54 -0400 Subject: [PATCH 4/4] Fixed typechecking errors found by mypy --- rich/progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rich/progress.py b/rich/progress.py index e30d6f931..d0a3aef58 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -59,10 +59,10 @@ class TaskID(int): - def __new__(cls, task_id: int, prog_instance: Progress | type[Progress]): + def __new__(cls, task_id: int, prog_instance: Progress) -> Self: return super().__new__(cls, task_id) - def __init__(self, task_id: int, prog_instance: Progress | type[Progress]): + def __init__(self, task_id: int, prog_instance: Progress) -> None: self.prog = prog_instance def __enter__(self) -> Self: @@ -73,7 +73,7 @@ def __exit__( exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, - ): + ) -> None: self.prog.remove_task(self)