From 86b7d860cfb055a54131cc687bdf2c4c1cd13034 Mon Sep 17 00:00:00 2001 From: "Victoria Terenina (torymur)" Date: Thu, 31 Oct 2024 16:53:43 +0000 Subject: [PATCH] get_index_dict -> get_transitions --- python/outlines_core/fsm/guide.py | 2 +- src/index.rs | 4 ++-- src/python_bindings/mod.rs | 4 ++-- tests/fsm/test_regex.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/outlines_core/fsm/guide.py b/python/outlines_core/fsm/guide.py index c7aa0012..18b4523c 100644 --- a/python/outlines_core/fsm/guide.py +++ b/python/outlines_core/fsm/guide.py @@ -306,4 +306,4 @@ def copy(self): def get_index_dict(self): """Returns the Index as a Python Dict object.""" - return self.states_to_token_maps.get_index_dict() + return self.states_to_token_maps.get_transitions() diff --git a/src/index.rs b/src/index.rs index 727061c1..587cd76a 100644 --- a/src/index.rs +++ b/src/index.rs @@ -35,7 +35,7 @@ impl FSMInfo { #[derive(Debug)] pub struct Index { - pub(crate) initial: u32, + initial: u32, finals: HashSet, states_to_token_subsets: HashMap>, eos_token_id: u32, @@ -126,7 +126,7 @@ impl Index { self.finals.contains(&state) } - pub(crate) fn index(&self) -> &HashMap> { + pub(crate) fn transitions(&self) -> &HashMap> { &self.states_to_token_subsets } } diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 368e8ced..046d7ce9 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -101,8 +101,8 @@ impl PyIndex { self.0.is_final(state) } - fn get_index_dict(&self) -> HashMap> { - self.0.index().clone() + fn get_transitions(&self) -> HashMap> { + self.0.transitions().clone() } fn get_initial_state(&self) -> u32 { diff --git a/tests/fsm/test_regex.py b/tests/fsm/test_regex.py index cdac64d4..40406f01 100644 --- a/tests/fsm/test_regex.py +++ b/tests/fsm/test_regex.py @@ -372,7 +372,7 @@ def test_create_fsm_index_tokenizer(hf_tokenizer_uri, revision): ) assert not empty_token_ids - assert len(states_to_token_subsets.get_index_dict()) / num_fsm_states > 0.94 + assert len(states_to_token_subsets.get_transitions()) / num_fsm_states > 0.94 @pytest.mark.parametrize(