diff --git a/tests/fsm/test_statistical.py b/tests/fsm/test_statistical.py
index 20ef28c..942796a 100644
--- a/tests/fsm/test_statistical.py
+++ b/tests/fsm/test_statistical.py
@@ -1,27 +1,12 @@
 from typing import Callable, List, Optional
 
 import numpy as np
-from outlines_core.fsm.guide import RegexGuide
+from outlines_core.fsm import Guide, Index, Vocabulary
 from pytest import approx
 from scipy.stats import ks_2samp
 
 
 def test_generate_length():
-    class MockTokenizer:
-        vocabulary = {"0": 1, "1": 2, "eos": 3}
-        inverse_vocabulary = {1: "0", 2: "1", 3: ""}
-        special_tokens = {"eos"}
-        eos_token_id = 3
-
-        def length(self):
-            return len(self.vocabulary)
-
-        def convert_token_to_string(self, token):
-            return token
-
-        def decode(self, token):
-            return self.inverse_vocabulary[token]
-
     class NextToken:
         def __init__(
             self,
@@ -43,17 +28,18 @@ def __call__(
             next_t = [self.rng.choice(self.states, p=prob / np.sum(prob))]
             return tokens + next_t if tokens is not None else next_t
 
-    def generate(model, tokenizer, regex_str) -> Optional[List[int]]:
-        n_tokens = tokenizer.length()
+    def generate(model, regex_str) -> Optional[List[int]]:
+        vocabulary = Vocabulary(3, {"0": [1], "1": [2], "2": [4]})
+        index = Index(regex_str, vocabulary)
+        guide = Guide(index)
 
-        fsm = RegexGuide.from_regex(regex_str, tokenizer)
-        state: int = fsm.initial_state
+        n_tokens = len(vocabulary)
         tokens = None
-        while state != -1:
-            allowed = fsm.get_next_instruction(state).tokens
+        allowed = guide.get_start_tokens()
+        while not guide.is_finished():
             mask: List[int] = [1 if s in allowed else 0 for s in range(1, n_tokens + 1)]
             tokens = model(tokens, mask=mask)
-            state = fsm.get_next_state(state, tokens[-1])
+            allowed = guide.read_next_token(tokens[-1])
         return tokens
 
     def prob_non_markov(tokens: List[int]) -> np.array:
@@ -75,16 +61,15 @@ def prob_markov(token: List[int]) -> np.array:
 
     n_samples: int = 250
     regex_str: str = r"11[01]+|0[01]*"
-    tokenizer = MockTokenizer()
     model1 = NextToken(prob_markov, p0, states, 30127)
     model2 = NextToken(prob_non_markov, p0, states, 24601)
 
     lengths1: np.array = np.zeros((n_samples,))
     lengths2: np.array = np.zeros((n_samples,))
     for i in range(n_samples):
-        out1: List[int] = generate(model1, tokenizer, regex_str)
+        out1: List[int] = generate(model1, regex_str)
         lengths1[i] = len(out1) - 1  # take off the eos token
-        out2: List[int] = generate(model2, tokenizer, regex_str)
+        out2: List[int] = generate(model2, regex_str)
         lengths2[i] = len(out2) - 1  # take off the eos token
 
     # 2 sample KS test to check that lengths has the same distribution as
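Note for reviewers: below is a minimal sketch of the Guide/Index/Vocabulary flow this diff migrates to, built only from calls that appear in the diff itself. The eos-first sampling policy and the `token_id` bookkeeping are illustrative stand-ins for the test's model; the first argument to `Vocabulary` is taken to mirror `eos_token_id = 3` from the removed `MockTokenizer`.

```python
from outlines_core.fsm import Guide, Index, Vocabulary

# First argument mirrors eos_token_id = 3 from the removed MockTokenizer;
# the map goes token string -> list of token ids, as in the diff above.
vocabulary = Vocabulary(3, {"0": [1], "1": [2]})
index = Index(r"11[01]+|0[01]*", vocabulary)
guide = Guide(index)

tokens = []
# list() is defensive in case the allowed-token container is not indexable.
allowed = list(guide.get_start_tokens())  # ids allowed in the initial state
while not guide.is_finished():
    # Illustrative policy only: take eos (3) as soon as it is legal,
    # otherwise the first allowed id; the test samples from a model instead.
    token_id = 3 if 3 in allowed else allowed[0]
    tokens.append(token_id)
    # Advance the guide; returns the ids allowed in the new state.
    allowed = list(guide.read_next_token(token_id))

print(tokens)  # a token-id sequence ending in the eos id 3
```

Compared with `RegexGuide`, the guide object tracks its own state: callers keep only the allowed-token list between steps instead of threading an explicit FSM state integer through the loop.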