Track token usage of iris requests (#165)
alexjoham authored Oct 23, 2024
1 parent 7b901fe commit 3bf8510
Showing 31 changed files with 246 additions and 49 deletions.
16 changes: 16 additions & 0 deletions app/common/PipelineEnum.py
@@ -0,0 +1,16 @@
from enum import Enum


class PipelineEnum(str, Enum):
    IRIS_CODE_FEEDBACK = "IRIS_CODE_FEEDBACK"
    IRIS_CHAT_COURSE_MESSAGE = "IRIS_CHAT_COURSE_MESSAGE"
    IRIS_CHAT_EXERCISE_MESSAGE = "IRIS_CHAT_EXERCISE_MESSAGE"
    IRIS_INTERACTION_SUGGESTION = "IRIS_INTERACTION_SUGGESTION"
    IRIS_CHAT_LECTURE_MESSAGE = "IRIS_CHAT_LECTURE_MESSAGE"
    IRIS_COMPETENCY_GENERATION = "IRIS_COMPETENCY_GENERATION"
    IRIS_CITATION_PIPELINE = "IRIS_CITATION_PIPELINE"
    IRIS_RERANKER_PIPELINE = "IRIS_RERANKER_PIPELINE"
    IRIS_SUMMARY_PIPELINE = "IRIS_SUMMARY_PIPELINE"
    IRIS_LECTURE_RETRIEVAL_PIPELINE = "IRIS_LECTURE_RETRIEVAL_PIPELINE"
    IRIS_LECTURE_INGESTION = "IRIS_LECTURE_INGESTION"
    NOT_SET = "NOT_SET"
2 changes: 1 addition & 1 deletion app/common/message_converters.py
@@ -4,7 +4,7 @@
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage

from app.domain.data.text_message_content_dto import TextMessageContentDTO
from app.domain.pyris_message import PyrisMessage, IrisMessageRole
from app.common.pyris_message import PyrisMessage, IrisMessageRole


def convert_iris_message_to_langchain_message(
3 changes: 3 additions & 0 deletions app/domain/pyris_message.py → app/common/pyris_message.py
@@ -5,6 +5,7 @@
from pydantic import BaseModel, ConfigDict, Field

from app.domain.data.message_content_dto import MessageContentDTO
from app.common.token_usage_dto import TokenUsageDTO


class IrisMessageRole(str, Enum):
@@ -16,6 +17,8 @@ class IrisMessageRole(str, Enum):
class PyrisMessage(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    token_usage: TokenUsageDTO = Field(default_factory=TokenUsageDTO)

    sent_at: datetime | None = Field(alias="sentAt", default=None)
    sender: IrisMessageRole
    contents: List[MessageContentDTO] = []
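
A minimal sketch (not part of this commit) of what the new field means for callers: every PyrisMessage now carries a TokenUsageDTO via default_factory, so token_usage can always be read, even before any LLM client fills it in. The IrisMessageRole.USER member name is assumed from the surrounding code.

from app.common.pyris_message import PyrisMessage, IrisMessageRole

# token_usage is populated by default_factory, so it is never None
msg = PyrisMessage(sender=IrisMessageRole.USER, contents=[])
print(msg.token_usage.num_input_tokens)  # 0 until a client fills it in
print(msg.token_usage.pipeline)          # PipelineEnum.NOT_SET
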
18 changes: 18 additions & 0 deletions app/common/token_usage_dto.py
@@ -0,0 +1,18 @@
from pydantic import BaseModel, Field

from app.common.PipelineEnum import PipelineEnum


class TokenUsageDTO(BaseModel):
    model_info: str = Field(alias="model", default="")
    num_input_tokens: int = Field(alias="numInputTokens", default=0)
    cost_per_input_token: float = Field(alias="costPerMillionInputToken", default=0)
    num_output_tokens: int = Field(alias="numOutputTokens", default=0)
    cost_per_output_token: float = Field(alias="costPerMillionOutputToken", default=0)
    pipeline: PipelineEnum = Field(alias="pipelineId", default=PipelineEnum.NOT_SET)

    def __str__(self):
        return (
            f"{self.model_info}: {self.num_input_tokens} input cost: {self.cost_per_input_token},"
            f" {self.num_output_tokens} output cost: {self.cost_per_output_token}, pipeline: {self.pipeline} "
        )
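
A quick sketch (not part of the diff) of how the DTO is meant to be populated: fields are set through their aliases, as the LLM clients below do, and the pipeline is tagged afterwards. The model name and counts here are made up.

from app.common.PipelineEnum import PipelineEnum
from app.common.token_usage_dto import TokenUsageDTO

usage = TokenUsageDTO(
    model="gpt-4o",        # hypothetical model name
    numInputTokens=1200,
    numOutputTokens=350,
)
usage.pipeline = PipelineEnum.IRIS_CHAT_COURSE_MESSAGE
print(usage)  # e.g. "gpt-4o: 1200 input cost: 0.0, 350 output cost: 0.0, pipeline: ..."
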
1 change: 0 additions & 1 deletion app/domain/__init__.py
@@ -12,6 +12,5 @@
from app.domain.chat.course_chat.course_chat_pipeline_execution_dto import (
    CourseChatPipelineExecutionDTO,
)
from .pyris_message import PyrisMessage, IrisMessageRole
from app.domain.data import image_message_content_dto
from app.domain.feature_dto import FeatureDTO
2 changes: 1 addition & 1 deletion app/domain/chat/chat_pipeline_execution_base_data_dto.py
@@ -3,7 +3,7 @@
from pydantic import Field, BaseModel

from app.domain import PipelineExecutionSettingsDTO
from app.domain.pyris_message import PyrisMessage
from app.common.pyris_message import PyrisMessage
from app.domain.data.user_dto import UserDTO
from app.domain.status.stage_dto import StageDTO

2 changes: 1 addition & 1 deletion app/domain/chat/chat_pipeline_execution_dto.py
@@ -3,7 +3,7 @@
from pydantic import Field

from app.domain import PipelineExecutionDTO
from app.domain.pyris_message import PyrisMessage
from app.common.pyris_message import PyrisMessage
from app.domain.data.user_dto import UserDTO


2 changes: 1 addition & 1 deletion app/domain/chat/interaction_suggestion_dto.py
@@ -2,7 +2,7 @@

from pydantic import Field, BaseModel

from app.domain import PyrisMessage
from app.common.pyris_message import PyrisMessage


class InteractionSuggestionPipelineExecutionDTO(BaseModel):
2 changes: 2 additions & 0 deletions app/domain/status/status_update_dto.py
@@ -2,8 +2,10 @@

from pydantic import BaseModel

from app.common.token_usage_dto import TokenUsageDTO
from ...domain.status.stage_dto import StageDTO


class StatusUpdateDTO(BaseModel):
    stages: List[StageDTO]
    tokens: List[TokenUsageDTO] = []
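
As a rough sketch (not in this commit), a pipeline callback can now attach the usage it collected to the status update it reports back to the client; the StageDTO entries are omitted here for brevity.

from app.common.token_usage_dto import TokenUsageDTO
from app.domain.status.status_update_dto import StatusUpdateDTO

usage = TokenUsageDTO(model="gpt-4o", numInputTokens=1200, numOutputTokens=350)
update = StatusUpdateDTO(stages=[], tokens=[usage])
# serialize with aliases (numInputTokens, costPerMillionInputToken, ...) for the client
payload = update.model_dump(by_alias=True)
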
3 changes: 2 additions & 1 deletion app/domain/text_exercise_chat_pipeline_execution_dto.py
@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field

from app.domain import PipelineExecutionDTO, PyrisMessage
from app.common.pyris_message import PyrisMessage
from app.domain import PipelineExecutionDTO
from app.domain.data.text_exercise_dto import TextExerciseDTO


2 changes: 1 addition & 1 deletion app/llm/external/model.py
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from pydantic import BaseModel

from ...domain import PyrisMessage
from ...common.pyris_message import PyrisMessage
from ...llm import CompletionArguments
from ...llm.capability import CapabilityList

22 changes: 18 additions & 4 deletions app/llm/external/ollama.py
@@ -6,10 +6,11 @@
from ollama import Client, Message

from ...common.message_converters import map_role_to_str, map_str_to_role
from ...common.pyris_message import PyrisMessage
from ...common.token_usage_dto import TokenUsageDTO
from ...domain.data.json_message_content_dto import JsonMessageContentDTO
from ...domain.data.text_message_content_dto import TextMessageContentDTO
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...domain import PyrisMessage
from ...llm import CompletionArguments
from ...llm.external.model import ChatModel, CompletionModel, EmbeddingModel

@@ -57,15 +58,23 @@ def convert_to_ollama_messages(messages: list[PyrisMessage]) -> list[Message]:
    return messages_to_return


def convert_to_iris_message(message: Message) -> PyrisMessage:
def convert_to_iris_message(
    message: Message, num_input_tokens: int, num_output_tokens: int, model: str
) -> PyrisMessage:
    """
    Convert a Message to a PyrisMessage
    """
    contents = [TextMessageContentDTO(text_content=message["content"])]
    tokens = TokenUsageDTO(
        numInputTokens=num_input_tokens,
        numOutputTokens=num_output_tokens,
        model=model,
    )
    return PyrisMessage(
        sender=map_str_to_role(message["role"]),
        contents=contents,
        send_at=datetime.now(),
        sentAt=datetime.now(),
        token_usage=tokens,
    )


@@ -108,7 +117,12 @@ def chat(
format="json" if arguments.response_format == "JSON" else "",
options=self.options,
)
return convert_to_iris_message(response["message"])
return convert_to_iris_message(
response.get("message"),
response.get("prompt_eval_count", 0),
response.get("eval_count", 0),
response.get("model", self.model),
)

    def embed(self, text: str) -> list[float]:
        response = self._client.embeddings(
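
A sketch of the new conversion path with a made-up Ollama response: prompt_eval_count and eval_count are forwarded as input and output token counts, and the model name travels along in the TokenUsageDTO.

from app.llm.external.ollama import convert_to_iris_message

response = {
    "model": "llama3.1",  # hypothetical response values
    "message": {"role": "assistant", "content": "Hello!"},
    "prompt_eval_count": 42,
    "eval_count": 17,
}
iris_message = convert_to_iris_message(
    response.get("message"),
    response.get("prompt_eval_count", 0),
    response.get("eval_count", 0),
    response.get("model", ""),
)
print(iris_message.token_usage)  # llama3.1: 42 input cost: 0.0, 17 output cost: 0.0, ...
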
29 changes: 23 additions & 6 deletions app/llm/external/openai_chat.py
@@ -1,7 +1,7 @@
import logging
import time
from datetime import datetime
from typing import Literal, Any
from typing import Literal, Any, Optional

from openai import (
    OpenAI,
@@ -12,12 +12,14 @@
    ContentFilterFinishReasonError,
)
from openai.lib.azure import AzureOpenAI
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.shared_params import ResponseFormatJSONObject

from ...common.message_converters import map_str_to_role, map_role_to_str
from app.domain.data.text_message_content_dto import TextMessageContentDTO
from ...domain import PyrisMessage
from ...common.pyris_message import PyrisMessage
from ...common.token_usage_dto import TokenUsageDTO
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...domain.data.json_message_content_dto import JsonMessageContentDTO
from ...llm import CompletionArguments
@@ -67,15 +69,28 @@ def convert_to_open_ai_messages(
    return openai_messages


def convert_to_iris_message(message: ChatCompletionMessage) -> PyrisMessage:
def convert_to_iris_message(
    message: ChatCompletionMessage, usage: Optional[CompletionUsage], model: str
) -> PyrisMessage:
    """
    Convert a ChatCompletionMessage to a PyrisMessage
    """
    return PyrisMessage(
    num_input_tokens = getattr(usage, "prompt_tokens", 0)
    num_output_tokens = getattr(usage, "completion_tokens", 0)

    tokens = TokenUsageDTO(
        model=model,
        numInputTokens=num_input_tokens,
        numOutputTokens=num_output_tokens,
    )

    message = PyrisMessage(
        sender=map_str_to_role(message.role),
        contents=[TextMessageContentDTO(textContent=message.content)],
        send_at=datetime.now(),
        sentAt=datetime.now(),
        token_usage=tokens,
    )
    return message


class OpenAIChatModel(ChatModel):
@@ -113,13 +128,15 @@ def chat(
                max_tokens=arguments.max_tokens,
            )
            choice = response.choices[0]
            usage = response.usage
            model = response.model
            if choice.finish_reason == "content_filter":
                # I figured that an openai error would be automatically raised if the content filter activated,
                # but it seems that that is not the case.
                # We don't want to retry because the same message will likely be rejected again.
                # Raise an exception to trigger the global error handler and report a fatal error to the client.
                raise ContentFilterFinishReasonError()
            return convert_to_iris_message(choice.message)
            return convert_to_iris_message(choice.message, usage, model)
        except (
            APIError,
            APITimeoutError,
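
A small sketch (not in the diff) of the updated converter's behaviour: usage may be None when the API omits it, in which case the getattr fallbacks leave both counts at 0. The message and usage objects here are constructed by hand.

from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionMessage

from app.llm.external.openai_chat import convert_to_iris_message

message = ChatCompletionMessage(role="assistant", content="Hi there")
usage = CompletionUsage(prompt_tokens=15, completion_tokens=4, total_tokens=19)

with_usage = convert_to_iris_message(message, usage, "gpt-4o")
without_usage = convert_to_iris_message(message, None, "gpt-4o")
print(with_usage.token_usage.num_input_tokens)     # 15
print(without_usage.token_usage.num_input_tokens)  # 0
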
16 changes: 14 additions & 2 deletions app/llm/langchain/iris_langchain_chat_model.py
@@ -1,17 +1,19 @@
import logging
from typing import List, Optional, Any

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import (
    BaseChatModel,
)
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.outputs.chat_generation import ChatGeneration
from langchain_core.outputs import ChatResult, ChatGeneration

from app.common.PipelineEnum import PipelineEnum
from ...common import (
    convert_iris_message_to_langchain_message,
    convert_langchain_message_to_iris_message,
)
from app.common.token_usage_dto import TokenUsageDTO
from ...llm import RequestHandler, CompletionArguments


@@ -20,6 +22,8 @@ class IrisLangchainChatModel(BaseChatModel):

    request_handler: RequestHandler
    completion_args: CompletionArguments
    tokens: TokenUsageDTO = None
    logger = logging.getLogger(__name__)

    def __init__(
        self,
@@ -43,6 +47,14 @@ def _generate(
        iris_message = self.request_handler.chat(iris_messages, self.completion_args)
        base_message = convert_iris_message_to_langchain_message(iris_message)
        chat_generation = ChatGeneration(message=base_message)
        self.tokens = TokenUsageDTO(
            model=iris_message.token_usage.model_info,
            numInputTokens=iris_message.token_usage.num_input_tokens,
            costPerMillionInputToken=iris_message.token_usage.cost_per_input_token,
            numOutputTokens=iris_message.token_usage.num_output_tokens,
            costPerMillionOutputToken=iris_message.token_usage.cost_per_output_token,
            pipeline=PipelineEnum.NOT_SET,
        )
        return ChatResult(generations=[chat_generation])

    @property
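
How a pipeline is expected to pick these numbers up, sketched as a hypothetical helper (the real pattern appears in code_feedback_pipeline.py further down):

from app.common.PipelineEnum import PipelineEnum
from app.common.token_usage_dto import TokenUsageDTO
from app.llm.langchain import IrisLangchainChatModel


def collect_usage(llm: IrisLangchainChatModel, pipeline: PipelineEnum) -> TokenUsageDTO:
    """Read the usage recorded by the last _generate call and tag it with a pipeline."""
    token_usage = llm.tokens  # set in _generate above; None before any call
    if token_usage is not None:
        token_usage.pipeline = pipeline
    return token_usage
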
2 changes: 1 addition & 1 deletion app/llm/request_handler/basic_request_handler.py
@@ -1,6 +1,6 @@
from typing import Optional

from app.domain import PyrisMessage
from app.common.pyris_message import PyrisMessage
from app.domain.data.image_message_content_dto import ImageMessageContentDTO
from app.llm.request_handler import RequestHandler
from app.llm.completion_arguments import CompletionArguments
7 changes: 5 additions & 2 deletions app/llm/request_handler/capability_request_handler.py
@@ -1,6 +1,6 @@
from enum import Enum

from app.domain import PyrisMessage
from app.common.pyris_message import PyrisMessage
from app.llm.capability import RequirementList
from app.llm.external.model import (
    ChatModel,
@@ -44,7 +44,10 @@ def chat(
        self, messages: list[PyrisMessage], arguments: CompletionArguments
    ) -> PyrisMessage:
        llm = self._select_model(ChatModel)
        return llm.chat(messages, arguments)
        message = llm.chat(messages, arguments)
        message.token_usage.cost_per_input_token = llm.capabilities.input_cost.value
        message.token_usage.cost_per_output_token = llm.capabilities.output_cost.value
        return message

    def embed(self, text: str) -> list[float]:
        llm = self._select_model(EmbeddingModel)
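
Since the handler now stamps the model's prices onto the returned message, and the DTO aliases name them per million tokens, an approximate request cost could be derived downstream. This helper is only an illustration, not part of the commit.

from app.common.token_usage_dto import TokenUsageDTO


def estimated_cost(usage: TokenUsageDTO) -> float:
    """Rough cost of one request, assuming prices are given per million tokens."""
    return (
        usage.num_input_tokens * usage.cost_per_input_token
        + usage.num_output_tokens * usage.cost_per_output_token
    ) / 1_000_000
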
2 changes: 1 addition & 1 deletion app/llm/request_handler/request_handler_interface.py
@@ -1,7 +1,7 @@
from abc import ABCMeta, abstractmethod
from typing import Optional

from ...domain import PyrisMessage
from ...common.pyris_message import PyrisMessage
from ...domain.data.image_message_content_dto import ImageMessageContentDTO
from ...llm import CompletionArguments

8 changes: 7 additions & 1 deletion app/pipeline/chat/code_feedback_pipeline.py
@@ -8,11 +8,13 @@
from langsmith import traceable
from pydantic import BaseModel

from ...domain import PyrisMessage
from ...common.pyris_message import PyrisMessage
from ...domain.data.build_log_entry import BuildLogEntryDTO
from ...domain.data.feedback_dto import FeedbackDTO
from app.common.token_usage_dto import TokenUsageDTO
from ...llm import CapabilityRequestHandler, RequirementList
from ...llm import CompletionArguments
from app.common.PipelineEnum import PipelineEnum
from ...llm.langchain import IrisLangchainChatModel
from ...pipeline import Pipeline
from ...web.status.status_update import StatusCallback
@@ -40,6 +42,7 @@ class CodeFeedbackPipeline(Pipeline):
    callback: StatusCallback
    default_prompt: PromptTemplate
    output_parser: StrOutputParser
    tokens: TokenUsageDTO

    def __init__(self, callback: Optional[StatusCallback] = None):
        super().__init__(implementation_id="code_feedback_pipeline_reference_impl")
@@ -141,4 +144,7 @@ def __call__(
                }
            )
        )
        token_usage = self.llm.tokens
        token_usage.pipeline = PipelineEnum.IRIS_CODE_FEEDBACK
        self.tokens = token_usage
        return response.replace("{", "{{").replace("}", "}}")