
Commit

move test & benchmarks to bindings/python
ErikKaum committed Aug 20, 2024
1 parent dbe28ba commit 8dd9e91
Showing 15 changed files with 2,188 additions and 0 deletions.
8 files renamed without changes.
91 changes: 91 additions & 0 deletions bindings/python/tests/fsm/test_fsm.py
@@ -0,0 +1,91 @@
import pytest
from outlines_core.fsm.fsm import RegexFSM, StopAtEosFSM


def assert_expected_tensor_ids(tensor, ids):
assert len(tensor) == len(ids)
norm_tensor = sorted(map(int, tensor))
    norm_ids = sorted(map(int, ids))
assert norm_tensor == norm_ids, (norm_tensor, norm_ids)


def test_stop_at_eos():
class MockTokenizer:
vocabulary = {"a": 1, "eos": 2}
eos_token_id = 2

with pytest.warns(UserWarning):
fsm = StopAtEosFSM(MockTokenizer())

assert fsm.allowed_token_ids(fsm.start_state) is None
assert fsm.allowed_token_ids(fsm.final_state) == [2]
assert fsm.next_state(fsm.start_state, 2) == fsm.final_state
assert fsm.next_state(fsm.start_state, 1) == fsm.start_state
assert fsm.is_final_state(fsm.start_state) is False
assert fsm.is_final_state(fsm.final_state) is True


def test_regex_vocabulary_error():
class MockTokenizer:
vocabulary = {"a": 1}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
return token

regex_str = "[1-9]"

with pytest.raises(ValueError, match="The vocabulary"):
RegexFSM(regex_str, MockTokenizer())


def test_regex():
class MockTokenizer:
vocabulary = {"1": 1, "a": 2, "eos": 3}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
return token

regex_str = "[1-9]"
tokenizer = MockTokenizer()

with pytest.warns(UserWarning):
fsm = RegexFSM(regex_str, tokenizer)

assert fsm.states_to_token_maps == {0: {1: 1}}
assert_expected_tensor_ids(fsm.allowed_token_ids(state=0), [1])
assert fsm.next_state(state=0, token_id=1) == 1
assert fsm.next_state(state=0, token_id=tokenizer.eos_token_id) == -1

assert fsm.is_final_state(0) is False

for state in fsm.final_states:
assert fsm.is_final_state(state) is True


def test_regex_final_state():
"""Make sure that the FSM stays in the final state as we keep generating"""

class MockTokenizer:
vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
special_tokens = {"eos"}
eos_token_id = 104

def convert_token_to_string(self, token):
return token

regex_str = r"`\n(\.\n)?`\n"
tokenizer = MockTokenizer()

with pytest.warns(UserWarning):
fsm = RegexFSM(regex_str, tokenizer)

state = fsm.next_state(state=4, token_id=103)
assert state == 5
assert fsm.is_final_state(state)

state = fsm.next_state(state=5, token_id=103)
assert fsm.is_final_state(state)
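
For context, the transition table asserted in test_regex above is all the machinery the FSM needs at decode time. The following self-contained sketch (not part of the commit) reimplements allowed_token_ids/next_state by hand on a {state: {token_id: next_state}} dict of that shape; the EOS-to-sink behaviour mirrors what the tests assert, while routing unmatched tokens to the same -1 sink is our assumption.

# Sketch only: a hand-rolled walk over the transition map from test_regex.
states_to_token_maps = {0: {1: 1}}  # regex "[1-9]" over {"1": 1, "a": 2, "eos": 3}
final_states = {1}
eos_token_id = 3

def allowed_token_ids(state):
    # Token ids with an outgoing transition from `state`.
    return list(states_to_token_maps.get(state, {}))

def next_state(state, token_id):
    # EOS ends the walk in the sink state -1, as the tests assert;
    # treating any unmatched token the same way is an assumption.
    if token_id == eos_token_id:
        return -1
    return states_to_token_maps.get(state, {}).get(token_id, -1)

assert allowed_token_ids(0) == [1]
assert next_state(0, 1) == 1 and 1 in final_states
assert next_state(0, eos_token_id) == -1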
189 changes: 189 additions & 0 deletions bindings/python/tests/fsm/test_guide.py
@@ -0,0 +1,189 @@
import pytest
from outlines_core.fsm.guide import Generate, RegexGuide, StopAtEOSGuide, Write


def assert_expected_tensor_ids(tensor, ids):
assert len(tensor) == len(ids)
norm_tensor = sorted(map(int, tensor))
    norm_ids = sorted(map(int, ids))
assert norm_tensor == norm_ids, (norm_tensor, norm_ids)


def test_stop_at_eos():
class MockTokenizer:
vocabulary = {"a": 1, "eos": 2}
eos_token_id = 2

fsm = StopAtEOSGuide(MockTokenizer())

instruction = fsm.get_next_instruction(fsm.start_state)
assert isinstance(instruction, Generate)
assert instruction.tokens is None

instruction = fsm.get_next_instruction(fsm.final_state)
assert isinstance(instruction, Write)
assert instruction.tokens == [2]

assert fsm.get_next_state(fsm.start_state, 2) == fsm.final_state
assert fsm.get_next_state(fsm.start_state, 1) == fsm.start_state
assert fsm.is_final_state(fsm.start_state) is False
assert fsm.is_final_state(fsm.final_state) is True


def test_regex_vocabulary_error():
class MockTokenizer:
vocabulary = {"a": 1}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
return token

regex_str = "[1-9]"

with pytest.raises(ValueError, match="The vocabulary"):
RegexGuide(regex_str, MockTokenizer())


def test_regex():
class MockTokenizer:
vocabulary = {"1": 1, "a": 2, "eos": 3}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
return token

regex_str = "[1-9]"
tokenizer = MockTokenizer()
fsm = RegexGuide(regex_str, tokenizer)

assert fsm.states_to_token_maps == {0: {1: 1}}

instruction = fsm.get_next_instruction(0)
assert isinstance(instruction, Generate)
assert_expected_tensor_ids(instruction.tokens, [1])

assert fsm.get_next_state(state=0, token_id=1) == 1
assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1

assert fsm.is_final_state(0) is False

for state in fsm.final_states:
assert fsm.is_final_state(state) is True


def test_regex_multi_byte_llama_like():
class MockTokenizer:
vocabulary = {
"1": 1,
"a": 2,
"eos": 3,
"😍": 4,
"<0xF0>": 5,
"<0x9F>": 6,
"<0x98>": 7,
"<0x88>": 8, # 😈
"\ufffd": 9,
"\ufffd\ufffd": 10,
}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
if token[0] == "<":
return "\ufffd"
return token

regex_str = "[😁-😎]"
tokenizer = MockTokenizer()
fsm = RegexGuide(regex_str, tokenizer)

assert fsm.states_to_token_maps == {
0: {5: 1, 4: 2},
1: {6: 3},
3: {7: 4},
4: {8: 2},
}

instruction = fsm.get_next_instruction(0)
assert isinstance(instruction, Generate)
assert_expected_tensor_ids(instruction.tokens, [5, 4])

assert fsm.get_next_state(state=0, token_id=5) == 1
assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1

assert fsm.is_final_state(0) is False

for state in fsm.final_states:
assert fsm.is_final_state(state) is True


def test_regex_multi_byte_gpt2_like():
class MockTokenizer:
vocabulary = {
"1": 1,
"a": 2,
"eos": 3,
"😍": 4,
" ": 5,
"\ufffd": 6,
"\ufffd\ufffd": 7,
"ðŁĺ": 8,
"Ī": 9, # '😈'
"Ġð": 10,
"ŁĺĪ": 11, # ' 😈'
}
special_tokens = {"eos"}
eos_token_id = 3

def convert_token_to_string(self, token):
if self.vocabulary[token] >= 8:
return "\ufffd"
return token

regex_str = " [😁-😎]"
tokenizer = MockTokenizer()
fsm = RegexGuide(regex_str, tokenizer)

assert fsm.states_to_token_maps == {
0: {5: 1, 10: 2},
1: {8: 5, 4: 3},
2: {11: 3},
5: {9: 3},
}

instruction = fsm.get_next_instruction(0)
assert isinstance(instruction, Generate)
assert_expected_tensor_ids(instruction.tokens, [5, 10])

assert fsm.get_next_state(state=0, token_id=5) == 1
assert fsm.get_next_state(state=0, token_id=tokenizer.eos_token_id) == -1

assert fsm.is_final_state(0) is False

for state in fsm.final_states:
assert fsm.is_final_state(state) is True


def test_regex_final_state():
"""Make sure that the FSM stays in the final state as we keep generating"""

class MockTokenizer:
vocabulary = {"`": 101, ".": 102, "\n": 103, "eos": 104}
special_tokens = {"eos"}
eos_token_id = 104

def convert_token_to_string(self, token):
return token

regex_str = r"`\n(\.\n)?`\n"
tokenizer = MockTokenizer()
fsm = RegexGuide(regex_str, tokenizer)

state = fsm.get_next_state(state=4, token_id=103)
assert state == 5
assert fsm.is_final_state(state)

state = fsm.get_next_state(state=5, token_id=103)
assert fsm.is_final_state(state)
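
Taken together, the Guide tests above pin down enough of the API for a toy constrained-decoding loop. The sketch below (not part of the commit) reuses the MockTokenizer from test_regex and only the calls the tests exercise; the greedy stand-in that always takes the first allowed token id is ours, not library behaviour.

# Sketch only: constrained decoding driven by the Guide API.
from outlines_core.fsm.guide import Generate, RegexGuide, Write

class MockTokenizer:
    vocabulary = {"1": 1, "a": 2, "eos": 3}
    special_tokens = {"eos"}
    eos_token_id = 3

    def convert_token_to_string(self, token):
        return token

guide = RegexGuide("[1-9]", MockTokenizer())
state, generated = 0, []
while not guide.is_final_state(state):
    instruction = guide.get_next_instruction(state)
    # Generate proposes token ids to sample under (None means unconstrained);
    # Write forces exact tokens. A real loop would call the model here, but
    # this stand-in greedily takes the first allowed token id.
    assert isinstance(instruction, (Generate, Write))
    token_id = int(instruction.tokens[0])
    generated.append(token_id)
    state = guide.get_next_state(state, token_id)

print(generated)  # [1]: a single digit fully matches "[1-9]"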
