From ffab2ac4fa273b1b63702d63d3e42035be023017 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Thu, 23 May 2024 01:43:46 -0500 Subject: [PATCH] Use TQDM to track index compilation progress --- .pre-commit-config.yaml | 1 + outlines/fsm/regex.py | 13 ++++++++++++- pyproject.toml | 1 + 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8fe83f89a..b528f0e8e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,3 +30,4 @@ repos: - id: mypy args: [--allow-redefinition] exclude: ^examples/ + additional_dependencies: [types-tqdm] diff --git a/outlines/fsm/regex.py b/outlines/fsm/regex.py index 0941bbb9f..b68e31897 100644 --- a/outlines/fsm/regex.py +++ b/outlines/fsm/regex.py @@ -26,6 +26,7 @@ anything_else, ) from numba.typed.typedobjectutils import _nonoptional +from tqdm import tqdm if TYPE_CHECKING: from outlines.models.tokenizer import Tokenizer @@ -692,6 +693,12 @@ def create_fsm_index_end_to_end( seen: Set[int] = set() next_states = {fsm_info.initial} + pbar = tqdm( + total=len(set(fsm_info.transitions.values())) + + 1, # all transitions plus initial + desc="Compiling FSM index for all state transitions", + ) + while next_states: start_state = next_states.pop() @@ -713,7 +720,11 @@ def create_fsm_index_end_to_end( if end_state not in seen: next_states.add(end_state) - seen.add(start_state) + if start_state not in seen: + pbar.update(1) + seen.add(start_state) + + pbar.close() return states_to_token_subsets diff --git a/pyproject.toml b/pyproject.toml index 41c306b14..b18036ffc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "referencing", "jsonschema", "requests", + "tqdm" ] dynamic = ["version"]