Commit 77303da
Improve token counting for messages with package (Azure-Samples#1577)
* Disable openai key access

* Use message token helper instead

* Update to latest package

* Revert launch change

* Improve typing
pamelafox authored May 21, 2024
1 parent e6fa39f commit 77303da
Showing 15 changed files with 109 additions and 650 deletions.
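
The heart of the change is visible in the first few files: the hand-rolled MessageBuilder and get_messages_from_history helpers are replaced by the openai-messages-token-helper package. As a rough sketch of how the two imported functions fit together, using only the keyword arguments visible in the diffs below (the model name and message values are illustrative, and get_token_limit presumably maps a model name to its context-window size, as the core.modelhelper helper it replaces did):

    from openai_messages_token_helper import build_messages, get_token_limit

    chatgpt_model = "gpt-35-turbo"  # illustrative model name
    chatgpt_token_limit = get_token_limit(chatgpt_model)  # context-window size in tokens

    response_token_limit = 1024  # reserve room for the model's reply
    prompt_messages = build_messages(
        model=chatgpt_model,
        system_prompt="Answer using only the provided sources.",
        past_messages=[{"role": "user", "content": "What are my health plans?"}],
        new_user_content="Which plan covers dental?",
        max_tokens=chatgpt_token_limit - response_token_limit,  # prompt token budget
    )

build_messages counts tokens per message and drops the oldest past messages once the budget is exceeded, the same job the deleted get_messages_from_history loop did by hand (see chatapproach.py below).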
7 changes: 6 additions & 1 deletion app/backend/approaches/approach.py
@@ -23,6 +23,7 @@
VectorQuery,
)
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from core.authentication import AuthenticationHelper
from text import nonewlines
@@ -254,6 +255,10 @@ async def compute_image_embedding(self, q: str):
return VectorizedQuery(vector=image_query_vector, k_nearest_neighbors=50, fields="imageEmbedding")

async def run(
self, messages: list[dict], stream: bool = False, session_state: Any = None, context: dict[str, Any] = {}
self,
messages: list[ChatCompletionMessageParam],
stream: bool = False,
session_state: Any = None,
context: dict[str, Any] = {},
) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
raise NotImplementedError
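
Note the run() signature change above: messages goes from list[dict] to the SDK's ChatCompletionMessageParam. That type is a union of TypedDicts, so plain dicts of the right shape still satisfy it and callers need no wrapper objects; the annotated few-shots list in chatapproach.py below relies on exactly this:

    from openai.types.chat import ChatCompletionMessageParam

    # Plain dicts type-check against the TypedDict-based union.
    messages: list[ChatCompletionMessageParam] = [
        {"role": "user", "content": "How did crypto do last year?"},
        {"role": "assistant", "content": "Summarize Cryptocurrency Market Dynamics from last year"},
    ]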
76 changes: 18 additions & 58 deletions app/backend/approaches/chatapproach.py
@@ -1,30 +1,19 @@
import json
import logging
import re
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Optional, Union

from openai.types.chat import (
ChatCompletion,
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
)
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam

from approaches.approach import Approach
from core.messagebuilder import MessageBuilder


class ChatApproach(Approach, ABC):
# Chat roles
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"

query_prompt_few_shots = [
{"role": USER, "content": "How did crypto do last year?"},
{"role": ASSISTANT, "content": "Summarize Cryptocurrency Market Dynamics from last year"},
{"role": USER, "content": "What are my health plans?"},
{"role": ASSISTANT, "content": "Show available health plans"},
query_prompt_few_shots: list[ChatCompletionMessageParam] = [
{"role": "user", "content": "How did crypto do last year?"},
{"role": "assistant", "content": "Summarize Cryptocurrency Market Dynamics from last year"},
{"role": "user", "content": "What are my health plans?"},
{"role": "assistant", "content": "Show available health plans"},
]
NO_RESPONSE = "0"

@@ -53,7 +42,7 @@ def system_message_chat_conversation(self) -> str:
pass

@abstractmethod
async def run_until_final_call(self, history, overrides, auth_claims, should_stream) -> tuple:
async def run_until_final_call(self, messages, overrides, auth_claims, should_stream) -> tuple:
pass

def get_system_prompt(self, override_prompt: Optional[str], follow_up_questions_prompt: str) -> str:
@@ -89,48 +78,15 @@ def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
def extract_followup_questions(self, content: str):
return content.split("<<")[0], re.findall(r"<<([^>>]+)>>", content)

def get_messages_from_history(
self,
system_prompt: str,
model_id: str,
history: list[dict[str, str]],
user_content: Union[str, list[ChatCompletionContentPartParam]],
max_tokens: int,
few_shots=[],
) -> list[ChatCompletionMessageParam]:
message_builder = MessageBuilder(system_prompt, model_id)

# Add examples to show the chat what responses we want. It will try to mimic any responses and make sure they match the rules laid out in the system message.
for shot in reversed(few_shots):
message_builder.insert_message(shot.get("role"), shot.get("content"))

append_index = len(few_shots) + 1

message_builder.insert_message(self.USER, user_content, index=append_index)

total_token_count = 0
for existing_message in message_builder.messages:
total_token_count += message_builder.count_tokens_for_message(existing_message)

newest_to_oldest = list(reversed(history[:-1]))
for message in newest_to_oldest:
potential_message_count = message_builder.count_tokens_for_message(message)
if (total_token_count + potential_message_count) > max_tokens:
logging.info("Reached max tokens of %d, history will be truncated", max_tokens)
break
message_builder.insert_message(message["role"], message["content"], index=append_index)
total_token_count += potential_message_count
return message_builder.messages

async def run_without_streaming(
self,
history: list[dict[str, str]],
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
session_state: Any = None,
) -> dict[str, Any]:
extra_info, chat_coroutine = await self.run_until_final_call(
history, overrides, auth_claims, should_stream=False
messages, overrides, auth_claims, should_stream=False
)
chat_completion_response: ChatCompletion = await chat_coroutine
chat_resp = chat_completion_response.model_dump() # Convert to dict to make it JSON serializable
@@ -144,18 +100,18 @@ async def run_without_streaming(

async def run_with_streaming(
self,
history: list[dict[str, str]],
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
session_state: Any = None,
) -> AsyncGenerator[dict, None]:
extra_info, chat_coroutine = await self.run_until_final_call(
history, overrides, auth_claims, should_stream=True
messages, overrides, auth_claims, should_stream=True
)
yield {
"choices": [
{
"delta": {"role": self.ASSISTANT},
"delta": {"role": "assistant"},
"context": extra_info,
"session_state": session_state,
"finish_reason": None,
@@ -190,7 +146,7 @@ async def run_with_streaming(
yield {
"choices": [
{
"delta": {"role": self.ASSISTANT},
"delta": {"role": "assistant"},
"context": {"followup_questions": followup_questions},
"finish_reason": None,
"index": 0,
@@ -200,7 +156,11 @@
}

async def run(
self, messages: list[dict], stream: bool = False, session_state: Any = None, context: dict[str, Any] = {}
self,
messages: list[ChatCompletionMessageParam],
stream: bool = False,
session_state: Any = None,
context: dict[str, Any] = {},
) -> Union[dict[str, Any], AsyncGenerator[dict[str, Any], None]]:
overrides = context.get("overrides", {})
auth_claims = context.get("auth_claims", {})
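Most of the deletions above come from get_messages_from_history, whose token-trimming behavior now lives inside the package: walk the history newest to oldest, count each message's tokens, and stop once the budget would be exceeded. A condensed sketch of the removed logic (count_tokens stands in for the old count_tokens_for_message):

    def truncate_history(history, count_tokens, max_tokens):
        """Keep the newest messages whose summed token cost fits max_tokens."""
        total, kept = 0, []
        for message in reversed(history):
            cost = count_tokens(message)
            if total + cost > max_tokens:
                break  # this message and everything older is dropped
            kept.append(message)
            total += cost
        return list(reversed(kept))
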
39 changes: 21 additions & 18 deletions app/backend/approaches/chatreadretrieveread.py
@@ -6,13 +6,14 @@
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai_messages_token_helper import build_messages, get_token_limit

from approaches.approach import ThoughtStep
from approaches.chatapproach import ChatApproach
from core.authentication import AuthenticationHelper
from core.modelhelper import get_token_limit


class ChatReadRetrieveReadApproach(ChatApproach):
@@ -65,7 +66,7 @@ def system_message_chat_conversation(self):
@overload
async def run_until_final_call(
self,
history: list[dict[str, str]],
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
should_stream: Literal[False],
@@ -74,15 +75,15 @@ async def run_until_final_call(
@overload
async def run_until_final_call(
self,
history: list[dict[str, str]],
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
should_stream: Literal[True],
) -> tuple[dict[str, Any], Coroutine[Any, Any, AsyncStream[ChatCompletionChunk]]]: ...

async def run_until_final_call(
self,
history: list[dict[str, str]],
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
should_stream: bool = False,
@@ -97,7 +98,9 @@ async def run_until_final_call(
filter = self.build_filter(overrides, auth_claims)
use_semantic_ranker = True if overrides.get("semantic_ranker") and has_text else False

original_user_query = history[-1]["content"]
original_user_query = messages[-1]["content"]
if not isinstance(original_user_query, str):
raise ValueError("The most recent message content must be a string.")
user_query_request = "Generate search query for: " + original_user_query

tools: List[ChatCompletionToolParam] = [
@@ -121,24 +124,25 @@
]

# STEP 1: Generate an optimized keyword search query based on the chat history and the last question
query_messages = self.get_messages_from_history(
query_response_token_limit = 100
query_messages = build_messages(
model=self.chatgpt_model,
system_prompt=self.query_prompt_template,
model_id=self.chatgpt_model,
history=history,
user_content=user_query_request,
max_tokens=self.chatgpt_token_limit - len(user_query_request),
tools=tools,
few_shots=self.query_prompt_few_shots,
past_messages=messages[:-1],
new_user_content=user_query_request,
max_tokens=self.chatgpt_token_limit - query_response_token_limit,
)

chat_completion: ChatCompletion = await self.openai_client.chat.completions.create(
messages=query_messages, # type: ignore
# Azure OpenAI takes the deployment name as the model name
model=self.chatgpt_deployment if self.chatgpt_deployment else self.chatgpt_model,
temperature=0.0, # Minimize creativity for search query generation
max_tokens=100, # Setting too low risks malformed JSON, setting too high may affect performance
max_tokens=query_response_token_limit, # Setting too low risks malformed JSON, setting too high may affect performance
n=1,
tools=tools,
tool_choice="auto",
)

query_text = self.get_search_query(chat_completion, original_user_query)
@@ -177,14 +181,13 @@
)

response_token_limit = 1024
messages_token_limit = self.chatgpt_token_limit - response_token_limit
messages = self.get_messages_from_history(
messages = build_messages(
model=self.chatgpt_model,
system_prompt=system_message,
model_id=self.chatgpt_model,
history=history,
past_messages=messages[:-1],
# Model does not handle lengthy system messages well. Moving sources to latest user conversation to solve follow up questions prompt.
user_content=original_user_query + "\n\nSources:\n" + content,
max_tokens=messages_token_limit,
new_user_content=original_user_query + "\n\nSources:\n" + content,
max_tokens=self.chatgpt_token_limit - response_token_limit,
)

data_points = {"text": sources_content}
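Two budget fixes stand out in this file. The old query budget, self.chatgpt_token_limit - len(user_query_request), subtracted a character count from a token limit, and the completion call hard-coded max_tokens=100 independently of that budget. The new code defines query_response_token_limit = 100 once and uses it on both sides. Roughly (the 4096 limit is illustrative):

    chatgpt_token_limit = 4096  # e.g. what get_token_limit might return
    user_query_request = "Generate search query for: How did crypto do last year?"
    query_response_token_limit = 100

    old_budget = chatgpt_token_limit - len(user_query_request)     # characters mixed into tokens
    new_budget = chatgpt_token_limit - query_response_token_limit  # tokens minus tokens
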
37 changes: 20 additions & 17 deletions app/backend/approaches/chatreadretrievereadvision.py
@@ -8,13 +8,14 @@
ChatCompletionChunk,
ChatCompletionContentPartImageParam,
ChatCompletionContentPartParam,
ChatCompletionMessageParam,
)
from openai_messages_token_helper import build_messages, get_token_limit

from approaches.approach import ThoughtStep
from approaches.chatapproach import ChatApproach
from core.authentication import AuthenticationHelper
from core.imageshelper import fetch_image
from core.modelhelper import get_token_limit


class ChatReadRetrieveReadVisionApproach(ChatApproach):
@@ -79,7 +80,7 @@ def system_message_chat_conversation(self):

async def run_until_final_call(
self,
history: list[dict[str, str]],
messages: list[ChatCompletionMessageParam],
overrides: dict[str, Any],
auth_claims: dict[str, Any],
should_stream: bool = False,
@@ -97,25 +98,29 @@ async def run_until_final_call(
include_gtpV_text = overrides.get("gpt4v_input") in ["textAndImages", "texts", None]
include_gtpV_images = overrides.get("gpt4v_input") in ["textAndImages", "images", None]

original_user_query = history[-1]["content"]
original_user_query = messages[-1]["content"]
if not isinstance(original_user_query, str):
raise ValueError("The most recent message content must be a string.")
past_messages: list[ChatCompletionMessageParam] = messages[:-1]

# STEP 1: Generate an optimized keyword search query based on the chat history and the last question
user_query_request = "Generate search query for: " + original_user_query

query_messages = self.get_messages_from_history(
query_response_token_limit = 100
query_messages = build_messages(
model=self.gpt4v_model,
system_prompt=self.query_prompt_template,
model_id=self.gpt4v_model,
history=history,
user_content=user_query_request,
max_tokens=self.chatgpt_token_limit - len(" ".join(user_query_request)),
few_shots=self.query_prompt_few_shots,
past_messages=past_messages,
new_user_content=user_query_request,
max_tokens=self.chatgpt_token_limit - query_response_token_limit,
)

chat_completion: ChatCompletion = await self.openai_client.chat.completions.create(
model=self.gpt4v_deployment if self.gpt4v_deployment else self.gpt4v_model,
messages=query_messages,
temperature=0.0, # Minimize creativity for search query generation
max_tokens=100,
max_tokens=query_response_token_limit,
n=1,
)

@@ -159,9 +164,6 @@ async def run_until_final_call(
self.follow_up_questions_prompt_content if overrides.get("suggest_followup_questions") else "",
)

response_token_limit = 1024
messages_token_limit = self.chatgpt_token_limit - response_token_limit

user_content: list[ChatCompletionContentPartParam] = [{"text": original_user_query, "type": "text"}]
image_list: list[ChatCompletionContentPartImageParam] = []

@@ -174,12 +176,13 @@
image_list.append({"image_url": url, "type": "image_url"})
user_content.extend(image_list)

messages = self.get_messages_from_history(
response_token_limit = 1024
messages = build_messages(
model=self.gpt4v_model,
system_prompt=system_message,
model_id=self.gpt4v_model,
history=history,
user_content=user_content,
max_tokens=messages_token_limit,
past_messages=messages[:-1],
new_user_content=user_content,
max_tokens=self.chatgpt_token_limit - response_token_limit,
)

data_points = {
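The vision approach had the same character-versus-token mix-up in a worse form: " ".join(user_query_request) over a string interleaves its characters with spaces, so the subtracted length was nearly twice the character count and still not a token count. A quick demonstration of that join behavior:

    s = "abc"
    assert " ".join(s) == "a b c"             # join iterates the string's characters
    assert len(" ".join(s)) == 2 * len(s) - 1

As in chatreadretrieveread.py, the replacement reserves a fixed query_response_token_limit of 100 tokens instead.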
(diffs for the remaining 11 changed files not shown)
