Max tokens per question and context #54

Open · wants to merge 2 commits into main

Conversation

reavessm

Closes: RHCLOUD-34879

Collaborator

@bsquizz bsquizz left a comment


Nice, you're on the right track here. I looked into the models supported by tiktoken out of the box, and it doesn't support the LLM we've been using so far (mistral-7b-instruct). However, it looks like Mistral has some Python libraries that can calculate the number of tokens: https://docs.mistral.ai/guides/tokenization/

The only catch is that it looks like we'll have to convert the langchain message history into the equivalent "mistral-common" classes. So you'd have to copy the msg_list but use a different class for each message:

  • HumanMessage becomes UserMessage
  • SystemMessage is also named SystemMessage
  • I think AIMessage becomes AssistantMessage

Or I think you could use ChatMessage for all of them and set the role field appropriately. A rough sketch of the conversion is below.
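
For illustration, here is a rough, untested sketch of that conversion based on the mistral-common tokenization guide. The import paths, the `to_mistral()` / `count_tokens()` helper names, and the tokenizer choice (`MistralTokenizer.v3()`) are assumptions rather than code from this PR; it also assumes the history starts with a system or user message and ends with a user message, which the request validation appears to expect.

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from mistral_common.protocol.instruct.messages import (
    AssistantMessage,
    SystemMessage as MistralSystemMessage,
    UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer


def to_mistral(msg):
    """Map a langchain message to its mistral-common equivalent."""
    if isinstance(msg, HumanMessage):
        return UserMessage(content=msg.content)
    if isinstance(msg, SystemMessage):
        return MistralSystemMessage(content=msg.content)
    if isinstance(msg, AIMessage):
        return AssistantMessage(content=msg.content)
    raise ValueError(f"unexpected message type: {type(msg)}")


def count_tokens(msg_list, tokenizer):
    """Count tokens for a whole conversation of langchain messages."""
    request = ChatCompletionRequest(messages=[to_mistral(m) for m in msg_list])
    return len(tokenizer.encode_chat_completion(request).tokens)


tokenizer = MistralTokenizer.v3()
```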


Lastly, I think that when we trim the messages we need to make sure we keep the initial SystemMessage in the list, because that is what instructs the model on how to behave.

@@ -44,15 +45,31 @@ def ask(self, system_prompt, previous_messages, question, agent_id, stream):
prompt_params = {"context": context_text, "question": question}
log.debug("search result: %s", context_text)

# If tiktoken doesn't support our model, default to gpt2
try:
text_splitter = tiktoken.encoding_for_model(cfg.LLM_MODEL_NAME)
Collaborator


The LLM_MODEL_NAME can actually be an arbitrary name; it's an identifier used on the server side. What I mean is, the model might be accessed using the name mistral-7b-instruct on requests made to the hosting server, but the actual model is Mistral-7B-Instruct-v0.3.

So I think you'll want another config variable, something like cfg.TOKENIZER_MODEL_NAME.
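
For illustration, a minimal config sketch of that split; the env-var handling and default values here are assumptions, not the project's actual config:

```python
import os

# Name used on requests to the hosting server (arbitrary server-side identifier)
LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "mistral-7b-instruct")

# Name of the actual model, used only to pick the right tokenizer for token counting
TOKENIZER_MODEL_NAME = os.getenv("TOKENIZER_MODEL_NAME", "Mistral-7B-Instruct-v0.3")
```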

Author

@reavessm reavessm Sep 19, 2024


> when we trim the messages we need to make sure we keep the initial SystemMessage in the list because that is instructing the model on how to behave

We're not really "trimming" as much as we're "selectively not adding". This nuance matters because the selection (token length calculation) happens after we add the default prompt here: https://github.com/RedHatInsights/tangerine-backend/pull/54/files#diff-abbf9cb2997932bbf240cd1e9f186f47e5c4c6cc15305f52aece62baa3e0fed1R56.

So if there's anything else we want to make sure we keep, we can add it before the previous_messages loop as well.
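
To illustrate that ordering, a rough, untested sketch (not the PR's actual code): `num_tokens_for()` is a hypothetical helper that counts tokens for a single message (e.g. via mistral-common, as in the diffs below), and `system_prompt`, `previous_messages`, and `cfg` are assumed to come from the `ask()` function in this PR.

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

# Anything that must always be kept is added (and counted) before the loop,
# so the budget check below only ever skips older history.
msg_list = [SystemMessage(content=system_prompt)]
total_tokens = num_tokens_for(msg_list[0])

for msg in previous_messages:
    if msg["sender"] == "human":
        candidate = HumanMessage(content=f"[INST] {msg['text']} [/INST]")
    else:
        candidate = AIMessage(content=f"{msg['text']}</s>")
    candidate_tokens = num_tokens_for(candidate)
    if total_tokens + candidate_tokens >= cfg.MAX_TOKENS_CONTEXT:
        break  # "selectively not adding" the rest of the history
    msg_list.append(candidate)
    total_tokens += candidate_tokens
```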

Stephen Reaves added 2 commits September 25, 2024 14:21
Signed-off-by: Stephen Reaves <[email protected]>
Signed-off-by: Stephen Reaves <[email protected]>
Collaborator

@bsquizz bsquizz left a comment


Getting closer here I think :) But I have several questions/comments

msg_list.append(AIMessage(content=f"{msg['text']}</s>"))
# The tokenizer requires that every request begins with a
# SystemMessage or a UserMessage, so we tokenize the AI
# response as a UserMessage, but append to the list as an
Collaborator


Would the AssistantMessage be the right one to use here?
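
For example, a possible (untested) variant; `text_splitter` is the MistralTokenizer instance from this diff, and the leading UserMessage placeholder is only there because the request has to begin with a system or user message. If the validator rejects an assistant-final conversation, `AssistantMessage(..., prefix=True)` may be needed.

```python
from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

num_tokens = len(
    text_splitter.encode_chat_completion(
        ChatCompletionRequest(
            messages=[
                UserMessage(content="..."),  # placeholder for the preceding turn
                AssistantMessage(content=msg["text"]),
            ]
        )
    ).tokens
)
```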

if msg["sender"] == "human":
msg_list.append(HumanMessage(content=f"[INST] {msg['text']} [/INST]"))
token_list = len(text_splitter.encode_chat_completion(ChatCompletionRequest(messages=[UserMessage(content=msg["text"])])).tokens)
Collaborator


Would it be OK if we change the name of this variable? I was expecting token_list to be a list type... but it is just an integer, right?

# SystemMessage or a UserMessage, so we tokenize the AI
# response as a UserMessage, but append to the list as an
# AIMessage.
token_list = len(text_splitter.encode_chat_completion(ChatCompletionRequest(messages=[UserMessage(content=f"{msg['text']}</s>")])).tokens)
Collaborator


Same note about the name of token_list as my prior comment.


prompt = ChatPromptTemplate.from_template(cfg.USER_PROMPT_TEMPLATE)
prompt_params = {"context": context_text, "question": question}
log.debug("search result: %s", context_text)

text_splitter = MistralTokenizer.v3(is_tekken=True)
Collaborator


Should this variable be named tokenizer, so that it doesn't get confused with other text splitters in the codebase?

# Tokenizer doesn't like including the first two tokens when
# decoding...
if len(tokens_question) > MAX_TOKENS_QUESTION+2:
log.debug("Question too big, truncating...")
Collaborator


I'm not sure we ever want to truncate the question that was asked. Or at least, hopefully we'd never need to unless the user asked a ridiculously long question. I think we should only look at the total of the current question + message history and start to drop previous_messages if the number of tokens becomes too large. I think all the token counting and truncation could happen within the ask function; let me know if this idea is off base.
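
For illustration, a rough, untested sketch of that idea (not this PR's code): `num_tokens_for()` is the same hypothetical per-message counter as in the earlier sketch, and `msg_list[0]` is assumed to be the initial SystemMessage that should never be dropped.

```python
from langchain_core.messages import HumanMessage

question_tokens = num_tokens_for(HumanMessage(content=question))
per_message_tokens = [num_tokens_for(m) for m in msg_list]

# Drop the oldest history (index 1 onward) until question + history fits;
# the question itself is never truncated.
while len(msg_list) > 1 and question_tokens + sum(per_message_tokens) >= cfg.MAX_TOKENS_CONTEXT:
    msg_list.pop(1)
    per_message_tokens.pop(1)
```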


total_tokens += token_list
if token_list + total_tokens >= cfg.MAX_TOKENS_CONTEXT:
print()
Collaborator


Was this print in there for debugging? Should it be removed?

@@ -38,28 +43,55 @@ def ask(self, system_prompt, previous_messages, question, agent_id, stream):
if "title" in metadata:
title = metadata["title"]
context_text += f", document title: '{title}'"
context_text += ">>\n\n" f"{page_content}\n\n" f"<<Search result {i+1} END>>\n"
context_text += (">>\n\n" f"{page_content}\n\n" f"<<Search result {i+1} END>>\n")

prompt = ChatPromptTemplate.from_template(cfg.USER_PROMPT_TEMPLATE)
prompt_params = {"context": context_text, "question": question}
Collaborator


Here's where we become aware of the question content. We need to somehow count the number of tokens for the question somewhere around here and add it to the running total, right?
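
For example, something along these lines (untested sketch; `tokenizer` is the MistralTokenizer instance, named text_splitter elsewhere in this diff, and `total_tokens` is the running counter):

```python
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest

question_tokens = len(
    tokenizer.encode_chat_completion(
        ChatCompletionRequest(messages=[UserMessage(content=question)])
    ).tokens
)
total_tokens += question_tokens
```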
