Track token usage of iris requests #165

Merged
20 commits merged on Oct 23, 2024

Changes from 11 commits

Commits (20)
fc44738
Add token usage monitoring in exercise chat and send to Artemis
alexjoham Oct 1, 2024
26e3873
Add Pipeline enum for better tracking
alexjoham Oct 11, 2024
aa50faf
Update tokens location, add token tracking to competency and chat pipe
alexjoham Oct 11, 2024
9905460
added first versions for tracking for smaller pipelines
alexjoham Oct 11, 2024
e241d45
Fix lint errors
alexjoham Oct 11, 2024
4502e30
Fix last lint error
alexjoham Oct 11, 2024
3b81a30
Fix lint errors
alexjoham Oct 11, 2024
74b1239
Merge remote-tracking branch 'origin/feature/track-usage-of-iris-requ…
alexjoham Oct 11, 2024
6bcb002
Merge branch 'main' into track-token-usage
alexjoham Oct 11, 2024
4324180
Add token cost tracking for input and output tokens
alexjoham Oct 12, 2024
c9e89be
Update token handling as proposed by CodeRabbit
alexjoham Oct 12, 2024
4c92900
Update PyrisMessage to use only TokenUsageDTO, add token count for error
alexjoham Oct 12, 2024
6bd4b33
Fix competency extraction did not save Enum
alexjoham Oct 12, 2024
c79837d
Merge branch 'main' into track-token-usage
alexjoham Oct 15, 2024
4d61c85
Update code after merge
alexjoham Oct 15, 2024
3253c46
Make -1 default value if no tokens have been received
alexjoham Oct 16, 2024
9fe9e0a
Update DTO for new Artemis table
alexjoham Oct 19, 2024
13c5db1
Change number of tokens if error to 0, as is standard by OpenAI & Ollama
alexjoham Oct 23, 2024
dd504fc
Fix token usage list append bug
bassner Oct 23, 2024
043264a
Fix formatting
bassner Oct 23, 2024
12 changes: 12 additions & 0 deletions app/domain/data/token_usage_dto.py
@@ -0,0 +1,12 @@
from pydantic import BaseModel

from app.llm.external.PipelineEnum import PipelineEnum


class TokenUsageDTO(BaseModel):
    model_info: str
    num_input_tokens: int
    cost_per_input_token: float
    num_output_tokens: int
    cost_per_output_token: float
    pipeline: PipelineEnum
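
For orientation only (not part of the diff): a minimal sketch of how this DTO might be constructed and serialized before being reported to Artemis, assuming pydantic v2 semantics; the model name and numbers are made up.

# Hypothetical usage sketch; values are illustrative only.
usage = TokenUsageDTO(
    model_info="gpt-4o",
    num_input_tokens=1200,
    cost_per_input_token=0.0000025,
    num_output_tokens=300,
    cost_per_output_token=0.00001,
    pipeline=PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE,
)
print(usage.model_dump_json())  # the str-based PipelineEnum serializes to its string value
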
6 changes: 6 additions & 0 deletions app/domain/pyris_message.py
@@ -16,6 +16,12 @@ class IrisMessageRole(str, Enum):
class PyrisMessage(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    num_input_tokens: int = Field(alias="numInputTokens", default=0)
    cost_per_input_token: float = Field(alias="costPerInputToken", default=0)
    num_output_tokens: int = Field(alias="numOutputTokens", default=0)
    cost_per_output_token: float = Field(alias="costPerOutputToken", default=0)
    model_info: str = Field(alias="modelInfo", default="")

    sent_at: datetime | None = Field(alias="sentAt", default=None)
    sender: IrisMessageRole
    contents: List[MessageContentDTO] = []
2 changes: 2 additions & 0 deletions app/domain/status/status_update_dto.py
@@ -2,8 +2,10 @@

from pydantic import BaseModel

from ..data.token_usage_dto import TokenUsageDTO
from ...domain.status.stage_dto import StageDTO


class StatusUpdateDTO(BaseModel):
    stages: List[StageDTO]
    tokens: List[TokenUsageDTO] = []
33 changes: 33 additions & 0 deletions app/llm/external/LLMTokenCount.py
@@ -0,0 +1,33 @@
from app.llm.external.PipelineEnum import PipelineEnum


class LLMTokenCount:

    model_info: str
    num_input_tokens: int
    cost_per_input_token: float
    num_output_tokens: int
    cost_per_output_token: float
    pipeline: PipelineEnum

    def __init__(
        self,
        model_info: str,
        num_input_tokens: int,
        cost_per_input_token: float,
        num_output_tokens: int,
        cost_per_output_token: float,
        pipeline: PipelineEnum,
    ):
        self.model_info = model_info
        self.num_input_tokens = num_input_tokens
        self.cost_per_input_token = cost_per_input_token
        self.num_output_tokens = num_output_tokens
        self.cost_per_output_token = cost_per_output_token
        self.pipeline = pipeline

    def __str__(self):
        return (
            f"{self.model_info}: {self.num_input_tokens} in, {self.cost_per_input_token} cost in,"
            f" {self.num_output_tokens} out, {self.cost_per_output_token} cost out, {self.pipeline} pipeline"
        )
15 changes: 15 additions & 0 deletions app/llm/external/PipelineEnum.py
@@ -0,0 +1,15 @@
from enum import Enum


class PipelineEnum(str, Enum):
    IRIS_CODE_FEEDBACK = "IRIS_CODE_FEEDBACK"
    IRIS_CHAT_COURSE_MESSAGE = "IRIS_CHAT_COURSE_MESSAGE"
    IRIS_CHAT_EXERCISE_MESSAGE = "IRIS_CHAT_EXERCISE_MESSAGE"
    IRIS_INTERACTION_SUGGESTION = "IRIS_INTERACTION_SUGGESTION"
    IRIS_CHAT_LECTURE_MESSAGE = "IRIS_CHAT_LECTURE_MESSAGE"
    IRIS_COMPETENCY_GENERATION = "IRIS_COMPETENCY_GENERATION"
    IRIS_CITATION_PIPELINE = "IRIS_CITATION_PIPELINE"
    IRIS_RERANKER_PIPELINE = "IRIS_RERANKER_PIPELINE"
    IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE"
    IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE"
    NOT_SET = "NOT_SET"
14 changes: 12 additions & 2 deletions app/llm/external/ollama.py
@@ -57,7 +57,9 @@ def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
    return messages_to_return


def convert_to_iris_message(message: Message) -> PyrisMessage:
def convert_to_iris_message(
    message: Message, num_input_tokens: int, num_output_tokens: int, model: str
) -> PyrisMessage:
    """
    Convert a Message to a PyrisMessage
    """
@@ -66,6 +68,9 @@ def convert_to_iris_message(message: Message) -> PyrisMessage:
        sender=map_str_to_role(message["role"]),
        contents=contents,
        send_at=datetime.now(),
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        model_info=model,
    )


@@ -108,7 +113,12 @@ def chat(
            format="json" if arguments.response_format == "JSON" else "",
            options=self.options,
        )
        return convert_to_iris_message(response["message"])
        return convert_to_iris_message(
            response.get("message"),
            response.get("prompt_eval_count", 0),
            response.get("eval_count", 0),
            response.get("model", self.model),
        )
Comment on lines +120 to +125
Contributor
⚠️ Potential issue

Update default token counts to -1 for consistency.

As discussed in previous comments and confirmed by you, the default token count values should be -1 to maintain consistency with the OpenAI implementation.

     return convert_to_iris_message(
         response.get("message"),
-        response.get("prompt_eval_count", 0),
-        response.get("eval_count", 0),
+        response.get("prompt_eval_count", -1),
+        response.get("eval_count", -1),
         response.get("model", self.model),
     )
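
As an aside (not from the review): whichever sentinel is chosen, a consumer can stay robust when aggregating cost by clamping negative counts. A hedged sketch with a hypothetical helper name:

def estimate_cost(num_tokens: int, cost_per_token: float) -> float:
    # Hypothetical helper: treat a negative sentinel ("tokens not reported")
    # as zero so it cannot contribute a negative cost.
    return max(num_tokens, 0) * cost_per_token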


    def embed(self, text: str) -> list[float]:
        response = self._client.embeddings(
20 changes: 16 additions & 4 deletions app/llm/external/openai_chat.py
@@ -2,10 +2,11 @@
import time
import traceback
from datetime import datetime
from typing import Literal, Any
from typing import Literal, Any, Optional

from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.shared_params import ResponseFormatJSONObject

@@ -61,15 +62,24 @@ def convert_to_open_ai_messages(
    return openai_messages


def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage:
def convert_to_iris_message(
    message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str
) -> PyrisMessage:
Comment on lines +72 to +74
Contributor
💡 Codebase verification

Calls to convert_to_iris_message in app/llm/external/ollama.py do not match the updated function signature. Please update these calls to include all required parameters: message, num_input_tokens, num_output_tokens, and model.

  • app/llm/external/ollama.py:
    • Line where convert_to_iris_message is called with insufficient arguments.
🔗 Analysis chain

Ensure all calls to convert_to_iris_message are updated with new parameters

Since the function convert_to_iris_message now includes additional parameters usage: Optional[CompletionUsage] and model: str, please verify that all calls to this function across the codebase have been updated to match the new signature to prevent any potential errors.

You can run the following script to identify any calls that may need to be updated:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Find all usages of `convert_to_iris_message` and check if they match the new signature.

# Expected: All calls should include three parameters: message, usage, and model.

rg --type python $'convert_to_iris_message\s*\(' -A 1

Length of output: 85


Script:

#!/bin/bash
# Description: Find all usages of `convert_to_iris_message` in Python files and check if they match the new signature.

# Expected: All calls should include three parameters: message, usage, and model.

rg --glob "*.py" 'convert_to_iris_message\s*\(' -A 1

Length of output: 723

"""
Convert a ChatCompletionMessage to a PyrisMessage
"""
return PyrisMessage(
num_input_tokens = getattr(usage, "prompt_tokens", -1)
alexjoham marked this conversation as resolved.
Show resolved Hide resolved
num_output_tokens = getattr(usage, "completion_tokens", -1)
alexjoham marked this conversation as resolved.
Show resolved Hide resolved
Contributor
⚠️ Potential issue

Consider using 0 as the default value for token counts

The current implementation uses -1 as the default value for num_input_tokens and num_output_tokens. However, this might be confusing as it's not a valid token count. For consistency with other parts of the codebase (e.g., ollama) and improved clarity, consider using 0 as the default value.

Suggested change:

- num_input_tokens = getattr(usage, "prompt_tokens", -1)
- num_output_tokens = getattr(usage, "completion_tokens", -1)
+ num_input_tokens = getattr(usage, "prompt_tokens", 0)
+ num_output_tokens = getattr(usage, "completion_tokens", 0)

This change would make the default values more intuitive and consistent with other parts of the codebase.


    message = PyrisMessage(
        sender=map_str_to_role(message.role),
        contents=[TextMessageContentDTO(textContent=message.content)],
        send_at=datetime.now(),
        num_input_tokens=num_input_tokens,
        num_output_tokens=num_output_tokens,
        model_info=model,
    )
    return message


class OpenAIChatModel(ChatModel):
@@ -103,7 +113,9 @@ def chat(
                    temperature=arguments.temperature,
                    max_tokens=arguments.max_tokens,
                )
                return convert_to_iris_message(response.choices[0].message)
                return convert_to_iris_message(
                    response.choices[0].message, response.usage, response.model
                )
            except Exception as e:
                wait_time = initial_delay * (backoff_factor**attempt)
                logging.warning(f"Exception on attempt {attempt + 1}: {e}")
14 changes: 12 additions & 2 deletions app/llm/langchain/iris_langchain_chat_model.py
@@ -5,9 +5,10 @@
    BaseChatModel,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs import ChatResult, ChatGeneration

from ..external.LLMTokenCount import LLMTokenCount
from ..external.PipelineEnum import PipelineEnum
from ...common import (
    convert_iris_message_to_langchain_message,
    convert_langchain_message_to_iris_message,
@@ -20,6 +21,7 @@ class IrisLangchainChatModel(BaseChatModel):

    request_handler: RequestHandler
    completion_args: CompletionArguments
    tokens: LLMTokenCount = None

    def __init__(
        self,
@@ -43,6 +45,14 @@ def _generate(
        iris_message = self.request_handler.chat(iris_messages, self.completion_args)
        base_message = convert_iris_message_to_langchain_message(iris_message)
        chat_generation = ChatGeneration(message=base_message)
        self.tokens = LLMTokenCount(
            model_info=iris_message.model_info,
            num_input_tokens=iris_message.num_input_tokens,
            cost_per_input_token=iris_message.cost_per_input_token,
            num_output_tokens=iris_message.num_output_tokens,
            cost_per_output_token=iris_message.cost_per_output_token,
            pipeline=PipelineEnum.NOT_SET,
        )
Member
Wouldn't this overwrite any existing token counts when this wrapper is used a second time?

Member Author
The token information is saved in every pipeline as soon as the LLM call is done, so the token counts can be overwritten without any loss
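
A minimal sketch (not part of this diff) of the per-pipeline pattern described above: each pipeline reads the wrapper's tokens attribute immediately after its LLM call, tags it with its own PipelineEnum value, and stores it before the wrapper is invoked again.

# Hypothetical pipeline snippet illustrating the pattern; the enum value is an example.
response = (self.prompt | self.pipeline).invoke({})
token_usage = self.llm.tokens                                # snapshot right after the call
token_usage.pipeline = PipelineEnum.IRIS_CITATION_PIPELINE   # tag with the calling pipeline
self.tokens = token_usage                                    # store before the next call overwrites it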

        return ChatResult(generations=[chat_generation])

    @property
5 changes: 4 additions & 1 deletion app/llm/request_handler/capability_request_handler.py
@@ -44,7 +44,10 @@
        self, messages: list[PyrisMessage], arguments: CompletionArguments
    ) -> PyrisMessage:
        llm = self._select_model(ChatModel)
        return llm.chat(messages, arguments)
        message = llm.chat(messages, arguments)
        message.cost_per_input_token = llm.capabilities.input_cost.value
        message.cost_per_output_token = llm.capabilities.output_cost.value
        return message

    def embed(self, text: str) -> list[float]:
        llm = self._select_model(EmbeddingModel)
6 changes: 6 additions & 0 deletions app/pipeline/chat/code_feedback_pipeline.py
@@ -13,6 +13,8 @@
from ...domain.data.feedback_dto import FeedbackDTO
from ...llm import CapabilityRequestHandler, RequirementList
from ...llm import CompletionArguments
from ...llm.external.LLMTokenCount import LLMTokenCount
from ...llm.external.PipelineEnum import PipelineEnum
from ...llm.langchain import IrisLangchainChatModel
from ...pipeline import Pipeline
from ...web.status.status_update import StatusCallback
@@ -40,6 +42,7 @@ class CodeFeedbackPipeline(Pipeline):
    callback: StatusCallback
    default_prompt: PromptTemplate
    output_parser: StrOutputParser
    tokens: LLMTokenCount

    def __init__(self, callback: Optional[StatusCallback] = None):
        super().__init__(implementation_id="code_feedback_pipeline_reference_impl")
@@ -141,4 +144,7 @@ def __call__(
                }
            )
        )
        token_usage = self.llm.tokens
        token_usage.pipeline = PipelineEnum.IRIS_CODE_FEEDBACK
        self.tokens = token_usage
        return response.replace("{", "{{").replace("}", "}}")
8 changes: 7 additions & 1 deletion app/pipeline/chat/course_chat_pipeline.py
@@ -41,6 +41,7 @@
    elicit_begin_agent_jol_prompt,
)
from ...domain import CourseChatPipelineExecutionDTO
from ...llm.external.PipelineEnum import PipelineEnum
from ...retrieval.lecture_retrieval import LectureRetrieval
from ...vector_database.database import VectorDatabase
from ...vector_database.lecture_schema import LectureSchema
@@ -107,6 +108,7 @@ def __init__(self, callback: CourseChatStatusCallback, variant: str = "default")

        # Create the pipeline
        self.pipeline = self.llm | StrOutputParser()
        self.tokens = []

    def __repr__(self):
        return f"{self.__class__.__name__}(llm={self.llm})"
@@ -406,14 +408,18 @@ def lecture_content_retrieval() -> str:
        self.callback.in_progress()
        for step in agent_executor.iter(params):
            print("STEP:", step)
            token_count = self.llm.tokens
            token_count.pipeline = PipelineEnum.IRIS_CHAT_COURSE_MESSAGE
            self.tokens.append(token_count)
            if step.get("output", None):
                out = step["output"]

        if self.retrieved_paragraphs:
            self.callback.in_progress("Augmenting response ...")
            out = self.citation_pipeline(self.retrieved_paragraphs, out)
            self.tokens.extend(self.citation_pipeline.tokens)

        self.callback.done("Response created", final_result=out)
        self.callback.done("Response created", final_result=out, tokens=self.tokens)

# try:
# # if out:
27 changes: 25 additions & 2 deletions app/pipeline/chat/exercise_chat_pipeline.py
@@ -34,6 +34,7 @@
from ...domain.data.programming_submission_dto import ProgrammingSubmissionDTO
from ...llm import CapabilityRequestHandler, RequirementList
from ...llm import CompletionArguments
from ...llm.external.PipelineEnum import PipelineEnum
from ...llm.langchain import IrisLangchainChatModel
from ...retrieval.lecture_retrieval import LectureRetrieval
from ...vector_database.database import VectorDatabase
@@ -78,6 +79,7 @@ def __init__(self, callback: ExerciseChatStatusCallback):
        self.code_feedback_pipeline = CodeFeedbackPipeline()
        self.pipeline = self.llm | StrOutputParser()
        self.citation_pipeline = CitationPipeline()
        self.tokens = []

    def __repr__(self):
        return f"{self.__class__.__name__}(llm={self.llm})"
@@ -98,7 +100,9 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
            )
            self._run_exercise_chat_pipeline(dto, should_execute_lecture_pipeline),
            self.callback.done(
                "Generated response", final_result=self.exercise_chat_response
                "Generated response",
                final_result=self.exercise_chat_response,
                tokens=self.tokens,
            )

        try:
@@ -112,7 +116,15 @@ def __call__(self, dto: ExerciseChatPipelineExecutionDTO):
                suggestion_dto.last_message = self.exercise_chat_response
                suggestion_dto.problem_statement = dto.exercise.problem_statement
                suggestions = self.suggestion_pipeline(suggestion_dto)
                self.callback.done(final_result=None, suggestions=suggestions)
                if self.suggestion_pipeline.tokens is not None:
                    tokens = [self.suggestion_pipeline.tokens]
                else:
                    tokens = []
                self.callback.done(
                    final_result=None,
                    suggestions=suggestions,
                    tokens=tokens,
                )
            else:
                # This should never happen but whatever
                self.callback.skip(
@@ -200,6 +212,8 @@ def _run_exercise_chat_pipeline(
        if submission:
            try:
                feedback = future_feedback.result()
                if self.code_feedback_pipeline.tokens is not None:
                    self.tokens.append(self.code_feedback_pipeline.tokens)
                self.prompt += SystemMessagePromptTemplate.from_template(
                    "Another AI has checked the code of the student and has found the following issues. "
                    "Use this information to help the student. "
@@ -220,6 +234,8 @@ def _run_exercise_chat_pipeline(
        if should_execute_lecture_pipeline:
            try:
                self.retrieved_lecture_chunks = future_lecture.result()
                if self.retriever.tokens is not None:
                    self.tokens.append(self.retriever.tokens)
                if len(self.retrieved_lecture_chunks) > 0:
                    self._add_relevant_chunks_to_prompt(
                        self.retrieved_lecture_chunks
@@ -252,6 +268,7 @@ def _run_exercise_chat_pipeline(
            .with_config({"run_name": "Response Drafting"})
            .invoke({})
        )
        self._collect_llm_tokens()
        self.callback.done()
        self.prompt = ChatPromptTemplate.from_messages(
            [
@@ -266,6 +283,7 @@ def _run_exercise_chat_pipeline(
            .with_config({"run_name": "Response Refining"})
            .invoke({})
        )
        self._collect_llm_tokens()

        if "!ok!" in guide_response:
            print("Response is ok and not rewritten!!!")
@@ -367,3 +385,8 @@ def should_execute_lecture_pipeline(self, course_id: int) -> bool:
            )
            return len(result.objects) > 0
        return False

    def _collect_llm_tokens(self):
        if self.llm.tokens is not None:
            self.llm.tokens.pipeline = PipelineEnum.IRIS_CHAT_EXERCISE_MESSAGE
            self.tokens.append(self.llm.tokens)
5 changes: 5 additions & 0 deletions app/pipeline/chat/interaction_suggestion_pipeline.py
@@ -34,6 +34,8 @@
)

from ...llm import CompletionArguments
from ...llm.external.LLMTokenCount import LLMTokenCount
from ...llm.external.PipelineEnum import PipelineEnum
from ...llm.langchain import IrisLangchainChatModel

from ..pipeline import Pipeline
@@ -52,6 +54,7 @@ class InteractionSuggestionPipeline(Pipeline):
    pipeline: Runnable
    prompt: ChatPromptTemplate
    variant: str
    tokens: LLMTokenCount

    def __init__(self, variant: str = "default"):
        super().__init__(implementation_id="interaction_suggestion_pipeline")
@@ -164,6 +167,8 @@ def __call__(
            self.prompt = ChatPromptTemplate.from_messages(prompt_val)

            response: dict = (self.prompt | self.pipeline).invoke({})
            self.tokens = self.llm.tokens
            self.tokens.pipeline = PipelineEnum.IRIS_INTERACTION_SUGGESTION
            return response["questions"]
        except Exception as e:
            logger.error(