diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f24a4ea7..04cbe9ae4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `TTY_INTERACTIVE` environment variable to force interactive mode off or on https://github.com/Textualize/rich/pull/3777 +- Added Context Manager support for TaskID objects returned by Progress.add_task. Allowing for `with progress.add_task(...) as taskid: ...` which automatically removes the progress bar for that task upon exiting the current context. ## [14.0.0] - 2025-03-30 diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4b04786b9..3b5cfb101 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -94,3 +94,4 @@ The following people have contributed to the development of Rich: - [Jonathan Helmus](https://github.com/jjhelmus) - [Brandon Capener](https://github.com/bcapener) - [Alex Zheng](https://github.com/alexzheng111) +- [Jacob Ogden](https://github.com/AetherBreaker) diff --git a/rich/progress.py b/rich/progress.py index ef6ad60f0..d0a3aef58 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -27,7 +27,6 @@ List, Literal, NamedTuple, - NewType, Optional, TextIO, Tuple, @@ -51,8 +50,6 @@ from .table import Column, Table from .text import Text, TextType -TaskID = NewType("TaskID", int) - ProgressType = TypeVar("ProgressType") GetTimeCallable = Callable[[], float] @@ -61,6 +58,25 @@ _I = typing.TypeVar("_I", TextIO, BinaryIO) +class TaskID(int): + def __new__(cls, task_id: int, prog_instance: Progress) -> Self: + return super().__new__(cls, task_id) + + def __init__(self, task_id: int, prog_instance: Progress) -> None: + self.prog = prog_instance + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.prog.remove_task(self) + + class _TrackThread(Thread): """A thread to periodically update progress.""" @@ -1096,7 +1112,7 @@ def __init__( self.disable = disable self.expand = expand self._tasks: Dict[TaskID, Task] = {} - self._task_index: TaskID = TaskID(0) + self._task_index: TaskID = TaskID(0, self) self.live = Live( console=console or get_console(), auto_refresh=auto_refresh, @@ -1635,7 +1651,7 @@ def add_task( if start: self.start_task(self._task_index) new_task_index = self._task_index - self._task_index = TaskID(int(self._task_index) + 1) + self._task_index = TaskID(int(self._task_index) + 1, self) self.refresh() return new_task_index