Add llama logits processor #556
`.gitignore`:

```diff
@@ -4,3 +4,4 @@ __pycache__
 docs/build
 .coverage
 .idea/
+*.gguf
```
New file, the llama.cpp example script (`@@ -0,0 +1,50 @@`):

```python
from enum import Enum

from llama_cpp import Llama, LogitsProcessorList
from pydantic import BaseModel, constr

from outlines.generate.processors import JSONLogitsProcessor
from outlines.models.llamacpp import LlamaCppTokenizer


class Weapon(str, Enum):
    sword = "sword"
    axe = "axe"
    mace = "mace"
    spear = "spear"
    bow = "bow"
    crossbow = "crossbow"


class Armor(str, Enum):
    leather = "leather"
    chainmail = "chainmail"
    plate = "plate"


class Character(BaseModel):
    name: constr(max_length=10)
    age: int
    armor: Armor
    weapon: Weapon
    strength: int


if __name__ == "__main__":
    llama = Llama("./phi-2.Q4_K_M.gguf")
    tokenizer = LlamaCppTokenizer(llama)

    prompt = "Instruct: You are a leading role play gamer. You have seen thousands of different characters and their attributes.\nPlease return a JSON object with common attributes of an RPG character. Give me a character description\nOutput:"

    logits_processor = JSONLogitsProcessor(Character, tokenizer)

    json_str = llama.create_completion(
        prompt,
        top_k=40,
        top_p=0.95,
        temperature=0.7,
        max_tokens=100,
        logits_processor=LogitsProcessorList([logits_processor]),
    )["choices"][0]["text"]

    print(json_str)
```
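Not part of the PR, but a sketch of how the output could be checked, reusing `Character` and `json_str` from the example above and assuming pydantic v2 (`parse_raw` would be the v1 spelling): since generation is constrained to the `Character` schema, the returned string should parse straight back into the model.

```python
# Hypothetical follow-up check, not in the PR: the schema-constrained
# completion should round-trip through the pydantic model it came from.
character = Character.model_validate_json(json_str)
print(character.name, character.weapon, character.armor)
```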
Second changed file (the `regex` generation dispatcher), first hunk:

```diff
@@ -3,6 +3,7 @@
 from outlines.fsm.fsm import RegexFSM
 from outlines.generate.api import SequenceGenerator
 from outlines.models import OpenAI
+from outlines.models.llamacpp import LlamaCpp, RegexLogitsProcessor
 from outlines.samplers import Sampler, multinomial
```

Review thread on the new import:

Comment: We should import within the function.

Reply: We register on the `LlamaCpp` class (`@regex.register(LlamaCpp)`), so the class has to be available at module import time.
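A minimal sketch of the dispatch mechanism under discussion, with stand-in names (`StubLlamaCpp` is hypothetical, not the library's class): `functools.singledispatch` needs the class object itself at registration time, which is why the import cannot simply move inside the function body.

```python
from functools import singledispatch


@singledispatch
def regex(model, regex_str: str):
    raise NotImplementedError(f"No regex generator for {type(model).__name__}")


class StubLlamaCpp:
    """Stand-in for outlines.models.llamacpp.LlamaCpp."""


# register() takes the class itself, so StubLlamaCpp must already be
# imported when this decorator runs, i.e. at module import time.
@regex.register(StubLlamaCpp)
def regex_llamacpp(model, regex_str: str):
    return f"llama.cpp generator for {regex_str!r}"


print(regex(StubLlamaCpp(), r"[0-9]+"))  # dispatches to the LlamaCpp overload
```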
Second hunk:

```diff
@@ -35,8 +36,30 @@ def regex(model, regex_str: str, sampler: Sampler = multinomial()):
     return generator


+@regex.register(LlamaCpp)
+def regex_llamacpp(
+    model: LlamaCpp,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
+    if not isinstance(sampler, multinomial):
+        raise NotImplementedError(
+            r"The llama.cpp integration does not currently support any other sampling algorithm "
+            + "than the multinomial sampler."
+        )
+
+    logits_processor = RegexLogitsProcessor(regex_str, model.tokenizer)
+    model.logits_processor = logits_processor
+
+    return model
+
+
 @regex.register(OpenAI)
-def regex_openai(model, regex_str: str, sampler: Sampler = multinomial()):
+def regex_openai(
+    model: OpenAI,
+    regex_str: str,
+    sampler: Sampler = multinomial(),
+):
     raise NotImplementedError(
         "Cannot use regex-structured generation with an OpenAI model"
         + "due to the limitations of the OpenAI API."
```

Review thread on `model.logits_processor = logits_processor`:

Comment: Does it make sense to bind the logits processor to the model? Won't this have side effects for other generations?

Reply: But a new generation would bind a different logits processor?

Comment: Except for generations that don't bind a new one.

Reply: I think it's fine for now given how […]
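A toy illustration of the side effect being debated, with hypothetical names (`StubModel` and `bind_regex_processor` stand in for the PR's model wrapper and dispatcher): because the processor lives on the model object, it survives into later generations unless something rebinds or clears it.

```python
class StubModel:
    """Stand-in for the LlamaCpp model wrapper; holds mutable state."""

    def __init__(self):
        self.logits_processor = None


def bind_regex_processor(model, processor):
    # Mirrors the pattern above: the "generator" is the mutated model itself.
    model.logits_processor = processor
    return model


model = StubModel()
bind_regex_processor(model, processor="regex-processor")

# A later generation that does not bind its own processor still sees the
# old one -- the side effect raised in the first comment:
assert model.logits_processor == "regex-processor"

# A new constrained generation rebinds, as the reply points out:
bind_regex_processor(model, processor="json-processor")
assert model.logits_processor == "json-processor"
```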
Review thread (on the generator's `format_sequence` assignment):

Comment: This line is unnecessary; it's what the default `format_sequence` already does.

Reply: That's what I thought as well. However, monkey patching has the downside that, in some contexts, the patched method persists unless you reinitialize. So without the trivial lambda, the integration tests fail because they pick up the `format_sequence` method from a prior test.

Reply: Makes sense. Considering this is minor and your change alleviates the problems of 5 known users, IMHO this should be tackled separately: #652
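A small sketch of the failure mode described in the reply, with hypothetical names (`StubGenerator` is not the library's class): a patched attribute shadows the class default and persists on the shared instance, so rebinding an explicit identity lambda is what restores the default behaviour.

```python
class StubGenerator:
    """Stand-in for the sequence generator; default formatting is identity."""

    def format_sequence(self, sequence):
        return sequence


generator = StubGenerator()

# A prior test (or generation) patches the instance:
generator.format_sequence = lambda s: s.upper()
assert generator.format_sequence("abc") == "ABC"

# Without reinitializing, the patch persists -- the next user of this
# instance silently inherits it. Rebinding the trivial identity lambda
# is the explicit reset the comment thread is discussing:
generator.format_sequence = lambda s: s
assert generator.format_sequence("abc") == "abc"
```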