diff --git a/daft/runners/progress_bar.py b/daft/runners/progress_bar.py index ed4c91c9a3..b53df4d3bf 100644 --- a/daft/runners/progress_bar.py +++ b/daft/runners/progress_bar.py @@ -114,27 +114,38 @@ def __init__(self) -> None: self._maxinterval = 5.0 self.tqdm_mod = get_tqdm(False) self.pbars: dict[int, Any] = dict() + self.bar_configs: dict[int, tuple[str, str]] = dict() + self.next_id = 0 def make_new_bar(self, bar_format: str, initial_message: str) -> int: - pbar_id = len(self.pbars) - self.pbars[pbar_id] = self.tqdm_mod( - bar_format=bar_format, - desc=initial_message, - position=pbar_id, - leave=False, - mininterval=1.0, - maxinterval=self._maxinterval, - ) + pbar_id = self.next_id + self.next_id += 1 + self.bar_configs[pbar_id] = (bar_format, initial_message) return pbar_id def update_bar(self, pbar_id: int, message: str) -> None: + if pbar_id not in self.pbars: + if pbar_id not in self.bar_configs: + raise ValueError(f"No bar configuration found for id {pbar_id}") + bar_format, initial_message = self.bar_configs[pbar_id] + self.pbars[pbar_id] = self.tqdm_mod( + bar_format=bar_format, + desc=initial_message, + position=pbar_id, + leave=False, + mininterval=1.0, + maxinterval=self._maxinterval, + ) + del self.bar_configs[pbar_id] self.pbars[pbar_id].set_description_str(message) def close_bar(self, pbar_id: int) -> None: - self.pbars[pbar_id].close() - del self.pbars[pbar_id] + if pbar_id in self.pbars: + self.pbars[pbar_id].close() + del self.pbars[pbar_id] def close(self) -> None: - for p in self.pbars.values(): + for p in list(self.pbars.values()): p.close() - del p + self.pbars.clear() + self.bar_configs.clear()