diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs index 7c3a1d4d..cdb27ef3 100644 --- a/src/python_bindings/mod.rs +++ b/src/python_bindings/mod.rs @@ -486,6 +486,10 @@ pub struct InteregularFSMInfo { states: HashSet, #[pyo3(get)] map: HashMap>, + #[pyo3(get)] + symbol_mapping: HashMap, + #[pyo3(get)] + by_transition: HashMap>, } use crate::interegular::fsm::Alphabet; @@ -494,7 +498,7 @@ use crate::interegular::patterns::Flag; #[pyfunction(name = "parse_pattern_to_fsm")] #[pyo3(text_signature = "(pattern: &str)")] -pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult { +pub fn parse_pattern_to_fsm_internal(pattern: &str) -> PyResult { let regex_element = parse_pattern(pattern).map_err(|_| PyValueError::new_err("Invalid pattern"))?; @@ -512,14 +516,15 @@ pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult> = fsm_info @@ -535,11 +540,25 @@ pub fn parse_pattern_to_fsm_internal(py: Python, pattern: &str) -> PyResult = alphabet + .symbol_mapping + .iter() + .map(|(k, v)| (*k, (*v).into())) + .collect(); + + let python_by_transition: HashMap> = alphabet + .by_transition + .iter() + .map(|(k, v)| (usize::from(*k), v.iter().map(|&c| c).collect())) + .collect(); + Ok(InteregularFSMInfo { initial: fsm_info.initial.into(), finals: fsm_info.finals.iter().map(|f| (*f).into()).collect(), states: fsm_info.states.iter().map(|s| (*s).into()).collect(), map, + symbol_mapping: python_symbol_mapping, + by_transition: python_by_transition, }) } diff --git a/tests/interegular/test_parse_pattern_to_fsm.py b/tests/interegular/test_parse_pattern_to_fsm.py index afc24978..63d00893 100644 --- a/tests/interegular/test_parse_pattern_to_fsm.py +++ b/tests/interegular/test_parse_pattern_to_fsm.py @@ -1,9 +1,66 @@ # TODO: THIS IS A WORK IN PROGRESS AND WILL BE COMPLETELY REFACTORED BEFORE MERGING +from interegular.fsm import anything_else from outlines_core.fsm.regex import parse_pattern_to_fsm import interegular +class InteregularFSMInfo: + def __init__(self, initial, finals, states, map, symbol_mapping, by_transition): + self.initial = initial + self.finals = finals + self.states = states + self.map = map + self.symbol_mapping = symbol_mapping + self.by_transition = by_transition + + +def map_states_with_symbols(state_map, symbol_mapping): + inv_symbol_mapping = {v: k for k, v in symbol_mapping.items()} + + mapped_states = {} + for state, transitions in state_map.items(): + mapped_transitions = {} + for symbol, next_state in transitions.items(): + mapped_symbol = inv_symbol_mapping.get(symbol, symbol) + mapped_transitions[mapped_symbol] = next_state + mapped_states[state] = mapped_transitions + + return mapped_states + + +def make_fsm_comparable(fsm): + # Create a new symbol mapping + new_symbol_mapping = {} + for symbol, value in fsm.symbol_mapping.items(): + if symbol == "\x00": + new_symbol_mapping[anything_else] = value + else: + new_symbol_mapping[symbol] = value + + # Create a new map + new_map = {} + for state, transitions in fsm.map.items(): + new_transitions = {} + for symbol, next_state in transitions.items(): + if symbol == b"\x00": + new_transitions[anything_else] = next_state + else: + new_transitions[symbol] = next_state + new_map[state] = new_transitions + + new_fsm = InteregularFSMInfo( + states=fsm.states, + initial=fsm.initial, + finals=fsm.finals, + map=new_map, + symbol_mapping=new_symbol_mapping, + by_transition=fsm.by_transition, + ) + + return new_fsm + + def compare_sets(set1, set2): # ensure that the sets are equal return frozenset(set1) == frozenset(set2) @@ -18,6 +75,7 @@ def sort_map(map): def test_parse_pattern_to_fsm(pattern): fsm = parse_pattern_to_fsm(pattern) + fsm = make_fsm_comparable(fsm) ref_pattern = interegular.parse_pattern(pattern) @@ -47,19 +105,28 @@ def test_parse_pattern_to_fsm(pattern): # assert fsm.finals == ref_fsm.finals # assert fsm.map == ref_fsm.map + # make maps deterministic (sort by key) + fsm_map = sort_map(fsm.map) + ref_map = sort_map(ref_fsm.map) + equal_states = frozenset(fsm.states) == frozenset(ref_fsm.states) equal_initial = fsm.initial == ref_fsm.initial equal_finals = frozenset(fsm.finals) == frozenset(ref_fsm.finals) - # equal_map = fsm.map == ref_fsm.map + equal_map = map_states_with_symbols( + fsm.map, fsm.symbol_mapping + ) == map_states_with_symbols(ref_fsm.map, ref_fsm.alphabet._symbol_mapping) print() - if equal_states and equal_initial and equal_finals: # and equal_map: + if equal_states and equal_initial and equal_finals and equal_map: print(f"✅ Test passed for pattern: {pattern}") else: print(f"❌ Test failed for pattern: {pattern}") - print("_symbol_mapping\n", ref_fsm.alphabet._symbol_mapping) - print("by_transition\n", ref_fsm.alphabet.by_transition) + print("fsm: symbol_mapping\n", fsm.symbol_mapping) + print("fsm: by_transition\n", fsm.by_transition) + + print("ref: symbol_mapping\n", ref_fsm.alphabet._symbol_mapping) + print("ref: by_transition\n", ref_fsm.alphabet.by_transition) print("States") print(f" fsm: {frozenset(fsm.states)}") @@ -75,13 +142,18 @@ def test_parse_pattern_to_fsm(pattern): print("Map") - # make maps deterministic (sort by key) - fsm_map = sort_map(fsm.map) - ref_map = sort_map(ref_fsm.map) - print(f" fsm: {fsm_map}") print(f" ref: {ref_map}") + print("Map with symbols") + fsm_map_with_symbols = map_states_with_symbols(fsm_map, fsm.symbol_mapping) + print(f" fsm: {sort_map(fsm_map_with_symbols)}") + + ref_map_with_symbols = map_states_with_symbols( + ref_map, ref_fsm.alphabet._symbol_mapping + ) + print(f" ref: {sort_map(ref_map_with_symbols)}") + return True @@ -89,10 +161,10 @@ def test_parse_pattern_to_fsm(pattern): # tests copied so they can be run as a standalone script if __name__ == "__main__": test_cases = [ - "a", + # "a", # "ab", # "a|b", - # "[ab]", + "[ab]", # TODO: long simple patterns (should work) # "aaaaa", # "davidholtz",