Refactor/tokenizers #840
Conversation
```python
def prompt_stack_input_to_message(self, prompt_input: PromptStack.Input) -> dict:
    """Converts a PromptStack Input to a ChatML-style message dictionary for token counting or model input.

    Args:
        prompt_input: The PromptStack Input to convert.

    Returns:
        A dictionary with the role and content of the input.
    """
    content = prompt_input.content

    if prompt_input.is_system():
        return {"role": "system", "content": content}
    elif prompt_input.is_assistant():
        return {"role": "assistant", "content": content}
    else:
        return {"role": "user", "content": content}
```
Should this be a method on `PromptInput`?
I don't think so; each Tokenizer will have a slightly different implementation of this. In an upcoming PR I'm changing this to an abstract method.
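To make that concrete, here is a small self-contained sketch (plain Python, not griptape's actual classes; `Input` below is just a stand-in for `PromptStack.Input`) of why two tokenizers might convert the same input into slightly different message dictionaries:

```python
from dataclasses import dataclass


@dataclass
class Input:
    """Stand-in for PromptStack.Input: a role plus some content."""
    role: str
    content: str

    def is_system(self) -> bool:
        return self.role == "system"

    def is_assistant(self) -> bool:
        return self.role == "assistant"


def chatml_style_message(prompt_input: Input) -> dict:
    """Mirrors the base conversion in this PR: system/assistant/user roles pass through."""
    if prompt_input.is_system():
        role = "system"
    elif prompt_input.is_assistant():
        role = "assistant"
    else:
        role = "user"

    return {"role": role, "content": prompt_input.content}


def hypothetical_other_model_message(prompt_input: Input) -> dict:
    """A hypothetical model family that names the assistant role 'model' would override the conversion."""
    message = chatml_style_message(prompt_input)

    if message["role"] == "assistant":
        message["role"] = "model"

    return message


print(hypothetical_other_model_message(Input(role="assistant", content="Hi!")))  # {'role': 'model', 'content': 'Hi!'}
```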
```python
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Converts a Prompt Stack to a string for token counting or model input.

    This base implementation will not be very accurate and should be overridden by subclasses with model-specific tokens.
    """
```
Originally I had the same question here, but I sense that the model-specific behavior means this should be owned by the tokenizer.
Honestly we could probably get rid of this method. All of our Prompt Drivers now support message-style APIs.
Ah, but now I remember why I didn't remove it in this PR: I didn't want to remove the token counts in the events, which rely on this method. My next PR changes this functionality, so I will remove this method there.
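As a side note on why the base implementation is inaccurate, here is a rough standalone sketch (not griptape code) of what a generic string rendering of chat messages looks like; a model-specific subclass would have to wrap each turn in the model's own special tokens to count accurately:

```python
def messages_to_string(messages: list[dict]) -> str:
    """Naive rendering of ChatML-style messages into a single prompt string."""
    rendered = "\n\n".join(f"{message['role'].capitalize()}: {message['content']}" for message in messages)

    # Real chat templates add model-specific markers (e.g. begin/end-of-turn tokens) here,
    # which is why a generic rendering can only approximate the true token count.
    return rendered + "\n\nAssistant:"


print(messages_to_string([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]))
```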
```diff
 def _default_max_input_tokens(self) -> int:
     tokens = next((v for k, v in self.MODEL_PREFIXES_TO_MAX_INPUT_TOKENS.items() if self.model.startswith(k)), None)

     if tokens is None:
-        raise ValueError(f"Unknown model default max input tokens: {self.model}")
+        return self.DEFAULT_MAX_INPUT_TOKENS
```
Seems reasonable to fall back to the default here. Do you think it'd be worth it to emit a warning when this happens? There's a big enough difference in expectations between the known values defined in the mapping and a default 'best guess' value, and silently guessing is a behavior we seem to have avoided until now.
Yeah, a warning is a good idea.
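Following up on that suggestion, here is a minimal sketch of what the warning-plus-fallback could look like. It reuses the attribute names from the snippet above; the `logging.warning` call and the message text are assumptions, not necessarily what was merged:

```python
import logging


def _default_max_input_tokens(self) -> int:
    tokens = next((v for k, v in self.MODEL_PREFIXES_TO_MAX_INPUT_TOKENS.items() if self.model.startswith(k)), None)

    if tokens is None:
        # Unknown model prefix: warn loudly that we are falling back to a best-guess default.
        logging.warning("Model %s not found in MODEL_PREFIXES_TO_MAX_INPUT_TOKENS, using default value.", self.model)
        return self.DEFAULT_MAX_INPUT_TOKENS

    return tokens
```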
Just a couple of questions
````diff
 ### Simple
 Not all LLM providers have a public tokenizer API. In this case, you can use the `SimpleTokenizer` to count tokens based on a simple heuristic.

 ```python
 from griptape.tokenizers import SimpleTokenizer

-tokenizer = SimpleTokenizer(max_input_tokens=1024, max_output_tokens=1024, characters_per_token=6)
+tokenizer = SimpleTokenizer(model="any-model", max_input_tokens=1024, max_output_tokens=1024, characters_per_token=6)
 ```
````
Just curious, why can't `model` be optional?
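As a quick illustration of the documented heuristic, a hedged usage sketch (it assumes `SimpleTokenizer.count_tokens()` accepts a plain string; `"any-model"` is just the placeholder name from the docs):

```python
from griptape.tokenizers import SimpleTokenizer

tokenizer = SimpleTokenizer(model="any-model", max_input_tokens=1024, max_output_tokens=1024, characters_per_token=6)

text = "Not all LLM providers have a public tokenizer API."
# With characters_per_token=6, the count is roughly len(text) / 6 (exact rounding is up to SimpleTokenizer).
print(len(text), tokenizer.count_tokens(text))
```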
Nice work!
Added
- `BaseTokenizer.prompt_stack_to_string()` to convert a Prompt Stack to a string.
- `BaseTokenizer.prompt_stack_input_to_string()` to convert a Prompt Stack Input to a ChatML-style message dictionary.

Changed
- Removed `BasePromptDriver.count_tokens()`.
- Removed `BasePromptDriver.max_output_tokens()`.
- Moved `BasePromptDriver.prompt_stack_to_string()` to `BaseTokenizer`.
- Moved `PromptStack.add_to_conversation_memory` to `BaseConversationMemory.add_to_prompt_stack`.
- `BaseTokenizer.count_tokens()` can now approximate token counts given a Prompt Stack (see the sketch after this list).
- Prompt Drivers now use `BasePromptDriver.max_tokens` instead of `BasePromptDriver.max_output_tokens()`.
- Moved `griptape.constants.RESPONSE_STOP_SEQUENCE` to `ToolkitTask`.
- `ToolkitTask.RESPONSE_STOP_SEQUENCE` is now only added when using `ToolkitTask`.
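To show what the tokenizer-owned counting looks like in practice, a hedged usage sketch (it assumes the pre-existing `PromptStack.add_system_input`/`add_user_input` helpers and that `count_tokens()` accepts a Prompt Stack, as described above):

```python
from griptape.tokenizers import SimpleTokenizer
from griptape.utils import PromptStack

# Build a small Prompt Stack and let the tokenizer approximate its size,
# instead of going through BasePromptDriver.count_tokens() as before this PR.
prompt_stack = PromptStack()
prompt_stack.add_system_input("You are a helpful assistant.")
prompt_stack.add_user_input("Summarize the tokenizer refactor for me.")

tokenizer = SimpleTokenizer(model="any-model", max_input_tokens=1024, max_output_tokens=1024, characters_per_token=6)
print(tokenizer.count_tokens(prompt_stack))
```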