Skip to content

Commit

Permalink
Test llamacpp with successive regex-guided generations
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Mar 7, 2024
1 parent c1851df commit 05c1d56
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions tests/generate/test_integration_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,28 @@ class Spam(BaseModel):

def test_llamacpp_json_function(model):
    """Generate JSON-structured arguments for a Python function signature.

    Checks that `generate.json` produces a dict whose keys/values can be
    splatted directly into the target function.
    """
    model.model.reset()
    prompt = "<|im_start|>user\nOutput arguments for the function, array with 2 elements<|im_end|>\n<|im_start|>assistant\n"

    def function(foo: int, bar: List[int]):
        # Summing proves the generated values have the expected types.
        return foo + sum(bar)

    rng = torch.Generator(device="cpu")
    # Seed chosen so the model reliably emits valid arguments at temperature 0.
    rng.manual_seed(10)
    sequence = generate.json(model, function)(
        prompt, max_tokens=100, temperature=0.0, rng=rng
    )
    assert isinstance(sequence, dict)
    assert isinstance(function(**sequence), int)


def test_llamacpp_successive_choices(model):
    """Regression test for successive guided generations on one model.

    Runs a regex-guided generator, then a choice-guided generator, then
    re-uses the first generator: all three calls must still produce
    outputs matching their respective constraints (i.e. earlier
    generations must not corrupt the generator/model state).
    """
    model.model.reset()

    choose = generate.regex(model, r"(one|two|three)")
    assert choose("pick a number") in ["one", "two", "three"]

    cities = ["New York", "Paris", "San Francisco"]
    city = generate.choice(model, cities)
    assert city("pick a city") in cities

    # Re-use the first generator to ensure its state survived the second one.
    assert choose("a number") in ["one", "two", "three"]

0 comments on commit 05c1d56

Please sign in to comment.