Ensure no implicit max_tokens in models.llamacpp
lapp0 committed Jun 21, 2024
1 parent be39aaa commit e66e68b
Showing 2 changed files with 22 additions and 1 deletion.
4 changes: 3 additions & 1 deletion outlines/models/llamacpp.py
@@ -166,7 +166,9 @@ def prepare_generation_parameters(

         # Somehow `llama-cpp-python` generates `max_tokens + 1` tokens
         if "max_tokens" not in llama_cpp_params:
-            if max_tokens is not None:
+            if max_tokens is None:
+                llama_cpp_params["max_tokens"] = int(2**30)
+            else:
                 llama_cpp_params["max_tokens"] = max_tokens - 1
         else:
             llama_cpp_params["max_tokens"] = llama_cpp_params["max_tokens"] - 1
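For quick reference, the sketch below restates the fixed branch outside the diff, with assertions covering the three cases. The helper name `resolve_max_tokens` is illustrative only and not part of the outlines API; the `llamacpp.py` hunk above is authoritative.

def resolve_max_tokens(max_tokens, llama_cpp_params):
    # Hypothetical standalone restatement of the fixed logic above.
    if "max_tokens" not in llama_cpp_params:
        if max_tokens is None:
            # No limit requested anywhere: set an effectively unbounded cap so
            # llama-cpp-python's implicit default cannot truncate the output.
            llama_cpp_params["max_tokens"] = int(2**30)
        else:
            # llama-cpp-python generates `max_tokens + 1` tokens, so pass one less.
            llama_cpp_params["max_tokens"] = max_tokens - 1
    else:
        # An explicitly passed llama_cpp parameter gets the same off-by-one fix.
        llama_cpp_params["max_tokens"] = llama_cpp_params["max_tokens"] - 1
    return llama_cpp_params

assert resolve_max_tokens(None, {}) == {"max_tokens": 2**30}
assert resolve_max_tokens(10, {}) == {"max_tokens": 9}
assert resolve_max_tokens(None, {"max_tokens": 10}) == {"max_tokens": 9}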
19 changes: 19 additions & 0 deletions tests/generate/test_integration_llamacpp.py
@@ -356,3 +356,22 @@ def test_tokenizer_vocabulary_decode_sanity():
         ]
     )
     assert decoded_nl_token == vocab_nl_token
+
+
+def test_no_length_constraint_when_unset():
+    """Assert that models.llamacpp doesn't have an implicit max_tokens preventing full sequence generation"""
+    import llama_cpp
+
+    model = models.llamacpp(
+        repo_id="M4-ai/TinyMistral-248M-v2-Instruct-GGUF",
+        filename="TinyMistral-248M-v2-Instruct.Q4_K_M.gguf",
+        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
+            "Locutusque/TinyMistral-248M-Instruct"
+        ),
+    )
+
+    long_pattern = "abcdefg" * 10
+    generator = generate.regex(model, long_pattern)
+
+    output = generator("a")
+    assert re.match(long_pattern, output)
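Why this test catches the regression: the match target is a 70-character literal, long enough to exceed a small implicit token cap (llama-cpp-python's completion default has historically been 16 tokens) and fail the `re.match` when generation is truncated. A back-of-the-envelope check, assuming that default:

pattern = "abcdefg" * 10
assert len(pattern) == 70
# Even at roughly 4 characters per token, the full match needs ~18 tokens,
# more than a 16-token implicit default would permit.
assert len(pattern) / 4 > 16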
