Max tokens per question and context #54
base: main
Conversation
Nice, you're on the right track here. I looked into the models supported by tiktoken out of the box, and it doesn't support the LLM we've been using so far (mistral-7b-instruct). However, it looks like Mistral has Python libraries that can calculate the number of tokens: https://docs.mistral.ai/guides/tokenization/
The only catch is that it looks like we'll have to convert the langchain message history into the "mistral-common" equivalent classes. So you'd have to copy the msg_list but use the corresponding class for each message:
- `HumanMessage` becomes `UserMessage`
- `SystemMessage` is also named `SystemMessage` - I think
- `AIMessage` becomes `AssistantMessage`

Or I think you could use `ChatMessage` for all of them and set the `role` to be the appropriate role (see the sketch below the links).
See:
- https://github.com/mistralai/mistral-common/blob/main/src/mistral_common/protocol/instruct/messages.py#L98
- https://github.com/mistralai/mistral-common/blob/main/src/mistral_common/protocol/instruct/messages.py#L58-L61
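A minimal sketch of that conversion, assuming the mapping above holds (untested; the helper name `to_mistral_messages` is made up for illustration):

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from mistral_common.protocol.instruct.messages import (
    AssistantMessage,
    SystemMessage as MistralSystemMessage,
    UserMessage,
)


def to_mistral_messages(msg_list):
    """Copy the langchain history using the mistral-common equivalent classes."""
    converted = []
    for msg in msg_list:
        if isinstance(msg, HumanMessage):
            converted.append(UserMessage(content=msg.content))
        elif isinstance(msg, SystemMessage):
            converted.append(MistralSystemMessage(content=msg.content))
        elif isinstance(msg, AIMessage):
            converted.append(AssistantMessage(content=msg.content))
    return converted
```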
Lastly, I think that when we trim the messages we need to make sure we keep the initial SystemMessage in the list because that is instructing the model on how to behave
connectors/llm/interface.py (Outdated)

```diff
@@ -44,15 +45,31 @@ def ask(self, system_prompt, previous_messages, question, agent_id, stream):
         prompt_params = {"context": context_text, "question": question}
         log.debug("search result: %s", context_text)

         # If tiktoken doesn't support our model, default to gpt2
         try:
             text_splitter = tiktoken.encoding_for_model(cfg.LLM_MODEL_NAME)
```
The LLM_MODEL_NAME can actually be an arbitrary name. It's an identifier used on the server side. What I mean is, the model might be accessed using the name `mistral-7b-instruct` on the requests made to the hosting server, but the model is actually Mistral-7B-Instruct-v0.3. So I think you'll want another config variable like `cfg.TOKENIZER_MODEL_NAME`.
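For illustration only, a possible shape for that (I'm assuming the config values are read from environment variables, which may not match how cfg is actually populated; the default strings are just examples):

```python
import os

# The serving alias can stay arbitrary...
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "mistral-7b-instruct")
# ...while the tokenizer is resolved from the actual underlying model name.
TOKENIZER_MODEL_NAME = os.getenv("TOKENIZER_MODEL_NAME", "Mistral-7B-Instruct-v0.3")
```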
> when we trim the messages we need to make sure we keep the initial SystemMessage in the list because that is instructing the model on how to behave

We're not really "trimming" so much as "selectively not adding". This nuance matters because the selection (the token length calculation) happens after we add the default prompt here: https://github.com/RedHatInsights/tangerine-backend/pull/54/files#diff-abbf9cb2997932bbf240cd1e9f186f47e5c4c6cc15305f52aece62baa3e0fed1R56 . So if there's anything else we want to make sure we keep, we can add it before the previous_messages loop as well (see the sketch below).
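A short sketch of the ordering I mean (system_prompt comes from ask()'s arguments; other_required_prompt is a hypothetical placeholder just to show where such things would go):

```python
from langchain_core.messages import SystemMessage

# Everything that must always be sent is appended before the previous_messages
# loop, so the token-budget check can only skip history, never these messages.
msg_list = [SystemMessage(content=system_prompt)]
msg_list.append(SystemMessage(content=other_required_prompt))  # hypothetical extra
# ... then iterate over previous_messages and append only what fits ...
```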
Signed-off-by: Stephen Reaves <[email protected]>
Signed-off-by: Stephen Reaves <[email protected]>
Getting closer here I think :) But I have several questions/comments
```python
msg_list.append(AIMessage(content=f"{msg['text']}</s>"))
# The tokenizer requires that every request begins with a
# SystemMessage or a UserMessage, so we tokenize the AI
# response as a UserMessage, but append to the list as an
```
Would the AssistantMessage be the right one to use here?
if msg["sender"] == "human": | ||
msg_list.append(HumanMessage(content=f"[INST] {msg['text']} [/INST]")) | ||
token_list = len(text_splitter.encode_chat_completion(ChatCompletionRequest(messages=[UserMessage(content=msg["text"])])).tokens) |
Would it be OK if we change the name of this variable? I was expecting `token_list` to be a list type ... but it is just an integer, right?
```python
# SystemMessage or a UserMessage, so we tokenize the AI
# response as a UserMessage, but append to the list as an
# AIMessage.
token_list = len(text_splitter.encode_chat_completion(ChatCompletionRequest(messages=[UserMessage(content=f"{msg['text']}</s>")])).tokens)
```
Same note about the name of `token_list` as my prior comment.
```python
prompt = ChatPromptTemplate.from_template(cfg.USER_PROMPT_TEMPLATE)
prompt_params = {"context": context_text, "question": question}
log.debug("search result: %s", context_text)

text_splitter = MistralTokenizer.v3(is_tekken=True)
```
Should this variable be named `tokenizer` so that it doesn't get confused with other text splitters in the code base?
```python
# Tokenizer doesn't like including the first two tokens when
# decoding...
if len(tokens_question) > MAX_TOKENS_QUESTION + 2:
    log.debug("Question too big, truncating...")
```
I'm not sure we ever want to truncate the asked question. Or at least, hopefully we would never need to unless the user asked a ridiculously long question. I think we should only look at the total of the current question plus the message history, and start to drop previous_messages if the number of tokens becomes too large. I think all the token counting and truncation could happen within the ask function. Let me know if this idea is off base.
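A rough sketch of that shape inside ask() (untested; cfg, question, and previous_messages are assumed from the existing function, and the newest-first ordering plus the _count_tokens helper are my own assumptions, not something this PR prescribes):

```python
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

tokenizer = MistralTokenizer.v3(is_tekken=True)


def _count_tokens(text: str) -> int:
    """Count tokens for one piece of text via a single-message request."""
    request = ChatCompletionRequest(messages=[UserMessage(content=text)])
    return len(tokenizer.encode_chat_completion(request).tokens)


# The current question is always kept, so it seeds the running total.
total_tokens = _count_tokens(question)

# Walk the history newest-first and keep only what still fits the budget.
kept = []
for msg in reversed(previous_messages):
    msg_tokens = _count_tokens(msg["text"])
    if total_tokens + msg_tokens >= cfg.MAX_TOKENS_CONTEXT:
        break
    kept.append(msg)
    total_tokens += msg_tokens
previous_messages = list(reversed(kept))  # restore chronological order
```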
```python
total_tokens += token_list
if token_list + total_tokens >= cfg.MAX_TOKENS_CONTEXT:
    print()
```
Was this `print` in there for debugging? Should it be removed?
```diff
@@ -38,28 +43,55 @@ def ask(self, system_prompt, previous_messages, question, agent_id, stream):
         if "title" in metadata:
             title = metadata["title"]
             context_text += f", document title: '{title}'"
-        context_text += ">>\n\n" f"{page_content}\n\n" f"<<Search result {i+1} END>>\n"
+        context_text += (">>\n\n" f"{page_content}\n\n" f"<<Search result {i+1} END>>\n")

         prompt = ChatPromptTemplate.from_template(cfg.USER_PROMPT_TEMPLATE)
         prompt_params = {"context": context_text, "question": question}
```
Here's where we become aware of the question content. We need to somehow count the number of tokens for the question somewhere around here and add it to the running total, right?
Closes: RHCLOUD-34879