[GEN AI] Improving Langfuse's traces
assouktim committed Dec 4, 2024
1 parent 8cd1c92 commit 6072ada
Showing 11 changed files with 134 additions and 56 deletions.
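
In short, the commit threads dialog details (dialog ID, user ID, chat history, connector tags) from the bot engine into the Gen AI orchestrator, and returns Langfuse trace information back alongside the RAG answer. A minimal sketch of the new payload, with field names taken from the diff below and purely hypothetical values:

# Observability info now returned by the orchestrator with a RAG answer (sketch)
observability_info = {
    'trace_id': 'a1b2c3',        # Langfuse trace id
    'trace_name': 'RAG',         # observability trace name
    'trace_url': 'https://langfuse.example.com/trace/a1b2c3',  # hypothetical URL
}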
4 changes: 4 additions & 0 deletions bot/engine/src/main/kotlin/engine/action/ActionMetadata.kt
@@ -16,6 +16,8 @@

package ai.tock.bot.engine.action

import ai.tock.genai.orchestratorclient.responses.ObservabilityInfo

data class ActionMetadata(
/** Is this the last answer of the bot? */
var lastAnswer: Boolean = false,
@@ -41,5 +43,7 @@ data class ActionMetadata(
var sourceWithContent: Boolean = false,
/** Is this answer generated by Gen AI RAG? **/
var isGenAiRagAnswer: Boolean = false,
/** Observability info (e.g. Langfuse trace id/URL) attached to this answer. **/
val observabilityInfo: ObservabilityInfo? = null,
)

26 changes: 17 additions & 9 deletions bot/engine/src/main/kotlin/engine/config/RAGAnswerHandler.kt
@@ -17,7 +17,6 @@
package ai.tock.bot.engine.config

import ai.tock.bot.admin.bot.rag.BotRAGConfiguration
import ai.tock.bot.admin.bot.vectorstore.BotVectorStoreConfiguration
import ai.tock.bot.admin.indicators.IndicatorValues
import ai.tock.bot.admin.indicators.Indicators
import ai.tock.bot.admin.indicators.metric.MetricType
@@ -31,15 +30,16 @@ import ai.tock.bot.engine.action.SendSentence
import ai.tock.bot.engine.action.SendSentenceWithFootnotes
import ai.tock.bot.engine.dialog.Dialog
import ai.tock.bot.engine.user.PlayerType
import ai.tock.genai.orchestratorclient.requests.*
import ai.tock.genai.orchestratorclient.requests.ChatMessage
import ai.tock.genai.orchestratorclient.requests.ChatMessageType
import ai.tock.genai.orchestratorclient.requests.DialogDetails
import ai.tock.genai.orchestratorclient.requests.RAGQuery
import ai.tock.genai.orchestratorclient.responses.ObservabilityInfo
import ai.tock.genai.orchestratorclient.responses.RAGResponse
import ai.tock.genai.orchestratorclient.responses.TextWithFootnotes
import ai.tock.genai.orchestratorclient.retrofit.GenAIOrchestratorBusinessError
import ai.tock.genai.orchestratorclient.retrofit.GenAIOrchestratorValidationError
import ai.tock.genai.orchestratorclient.services.RAGService
import ai.tock.genai.orchestratorcore.models.vectorstore.*
import ai.tock.genai.orchestratorcore.utils.OpenSearchUtils
import ai.tock.genai.orchestratorcore.utils.PGVectorUtils
import ai.tock.genai.orchestratorcore.utils.VectorStoreUtils
import ai.tock.shared.*
import engine.config.AbstractProactiveAnswerHandler
@@ -68,7 +68,7 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
BotRepository.saveMetric(createMetric(MetricType.STORY_HANDLED))

// Call RAG Api - Gen AI Orchestrator
val (answer, debug, noAnswerStory) = rag(this)
val (answer, debug, noAnswerStory, observabilityInfo) = rag(this)

// Add debug data if available and if debugging is enabled
if (debug != null && (action.metadata.debugEnabled || ragDebugEnabled)) {
@@ -87,7 +87,7 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
if(action.metadata.sourceWithContent) it.content else null
)
}.toMutableList(),
metadata = ActionMetadata(isGenAiRagAnswer = true)
metadata = ActionMetadata(isGenAiRagAnswer = true, observabilityInfo = observabilityInfo)
)
)
} else {
@@ -179,7 +179,14 @@ object RAGAnswerHandler : AbstractProactiveAnswerHandler {
try {
val response = ragService.rag(
query = RAGQuery(
history = getDialogHistory(dialog),
dialog = DialogDetails(
dialogId = dialog.id.toString(),
userId = dialog.playerIds.firstOrNull { PlayerType.user == it.type }?.id,
history = getDialogHistory(dialog),
tags = listOf(
"connector:${underlyingConnector.connectorType.id}"
)
),
questionAnsweringLlmSetting = ragConfiguration.llmSetting,
questionAnsweringPromptInputs = mapOf(
"question" to action.toString(),
@@ -195,7 +202,7 @@
)

// Handle RAG response
return RAGResult(response?.answer, response?.debug, ragStoryRedirection(this, response))
return RAGResult(response?.answer, response?.debug, ragStoryRedirection(this, response), response?.observabilityInfo)
} catch (exc: Exception) {
logger.error { exc }
// Save failure metric
@@ -259,6 +266,7 @@ data class RAGResult(
val answer: TextWithFootnotes? = null,
val debug: Any? = null,
val noAnswerStory: StoryDefinition? = null,
val observabilityInfo: ObservabilityInfo? = null,
)

/**
@@ -24,7 +24,7 @@ import ai.tock.genai.orchestratorcore.models.vectorstore.VectorStoreSetting
data class RAGQuery(
// val condenseQuestionLlmSetting: LLMSetting,
// val condenseQuestionPromptInputs: Map<String, String>,
val history: List<ChatMessage> = emptyList(),
val dialog: DialogDetails?,
val questionAnsweringLlmSetting: LLMSetting,
val questionAnsweringPromptInputs: Map<String, String>,
val embeddingQuestionEmSetting: EMSetting,
@@ -34,6 +34,13 @@ data class RAGQuery(
val observabilitySetting: ObservabilitySetting?
)

data class DialogDetails(
val dialogId: String? = null,
val userId: String? = null,
val history: List<ChatMessage> = emptyList(),
val tags: List<String> = emptyList(),
)

data class ChatMessage(
val text: String,
val type: ChatMessageType,
@@ -18,7 +18,8 @@ package ai.tock.genai.orchestratorclient.responses

data class RAGResponse(
val answer: TextWithFootnotes,
val debug: Any? = null
val debug: Any? = null,
val observabilityInfo: ObservabilityInfo? = null,
)

data class TextWithFootnotes(
@@ -31,4 +32,10 @@ data class Footnote(
val title: String,
val url: String? = null,
val content: String?,
)

data class ObservabilityInfo(
val traceId: String,
val traceName: String,
val traceUrl: String,
)
@@ -48,7 +48,7 @@ def __eq__(self, other):
)

def __hash__(self):
return hash((self.title, self.url, self.content))
return hash((self.title, str(self.url or ''), self.content))


class Footnote(Source):
@@ -133,16 +133,30 @@ class VectorStoreProviderSettingStatusQuery(BaseModel):
default=None,
)

class DialogDetails(BaseModel):
"""The dialog details model"""

dialog_id: Optional[str] = Field(
description="The dialog/session ID, attached to the observability traces if "
"the observability provider supports it.",
default=None, examples=["uuid-0123"])
user_id: Optional[str] = Field(
description="The user ID, attached to the observability traces if the observability provider supports it.",
default=None, examples=["[email protected]"])
history: list[ChatMessage] = Field(
description="Conversation history, used to reformulate the user's question.")
tags: list[str] = Field(
description='List of tags, attached to the observability trace, if the observability provider supports it.',
examples=[["my-Tag"]])
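
For illustration, a DialogDetails payload could be built like this (a sketch: the import path, the ChatMessage constructor, and the connector tag value are assumptions; the session/user fields are only attached to traces when the observability provider supports them):

from gen_ai_orchestrator.routers.requests.requests import ChatMessage, DialogDetails

# Hypothetical dialog details, mirroring what the bot engine sends (values illustrative)
details = DialogDetails(
    dialog_id='uuid-0123',
    user_id='user@example.com',
    history=[
        ChatMessage(text='Hello, how can I do this?', type='HUMAN'),
        ChatMessage(text='you can do this with the following method ....', type='AI'),
    ],
    tags=['connector:web'],
)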


class RagQuery(BaseQuery):
"""The RAG query model"""

history: list[ChatMessage] = Field(
description="Conversation history, used to reformulate the user's question."
)
dialog: Optional[DialogDetails] = Field(description='The user dialog details.')
question_answering_prompt_inputs: Any = Field(
description='Key-value inputs for the LLM prompt when used as a template. Please note that the '
'chat_history field must not be specified here, it will be overridden by the history field',
'chat_history field must not be specified here, it will be overridden by the dialog.history field',
)
# condense_question_llm_setting: LLMSetting =
# Field(description="LLM setting, used to condense the user's question.")
@@ -156,7 +170,7 @@ class RagQuery(BaseQuery):
)
question_answering_prompt_inputs: Any = Field(
description='Key-value inputs for the LLM prompt when used as a template. Please note that the '
'chat_history field must not be specified here, it will be overridden by the history field',
'chat_history field must not be specified here, it will be overridden by the dialog.history field',
)
embedding_question_em_setting: EMSetting = Field(
description="Embedding model setting, used to calculate the user's question vector."
@@ -182,13 +196,15 @@
'json_schema_extra': {
'examples': [
{
'history': [
{'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
{
'text': 'you can do this with the following method ....',
'type': 'AI',
},
],
'dialog': {
'history': [
{'text': 'Hello, how can I do this?', 'type': 'HUMAN'},
{
'text': 'you can do this with the following method ....',
'type': 'AI',
},
]
},
'question_answering_llm_setting': {
'provider': 'OpenAI',
'api_key': {
@@ -97,6 +97,19 @@ class EMProviderResponse(BaseModel):
)


class ObservabilityInfo(BaseModel):
"""The Observability Info model"""

trace_id: str = Field(
description='The observability trace id.'
)
trace_name: str = Field(
description='The observability trace name.'
)
trace_url: str = Field(
description='The observability trace url.'
)

class RagResponse(BaseModel):
"""The RAG response model"""

@@ -108,7 +121,10 @@ class RagResponse(BaseModel):
examples=[{'action': 'retrieve', 'result': 'OK', 'errors': []}],
default=None,
)

observability_info: Optional[ObservabilityInfo] = Field(
description='The observability info.',
default=None
)
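
A caller holding a RagResponse can then surface the trace link, for example (a sketch, assuming a module-level logger and an already-obtained response instance):

import logging

logger = logging.getLogger(__name__)

if response.observability_info is not None:
    # Log the Langfuse trace so the answer can be inspected later
    logger.info('RAG trace %s: %s', response.observability_info.trace_name,
                response.observability_info.trace_url)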

class QAResponse(BaseModel):
"""The QA response model"""
@@ -79,7 +79,7 @@ async def generate_and_split_sentences(
config = {"callbacks": [
create_observability_callback_handler(
observability_setting=query.observability_setting,
trace_name=ObservabilityTrace.SENTENCE_GENERATION
trace_name=ObservabilityTrace.SENTENCE_GENERATION.value
)]}

sentences = await chain.ainvoke(query.prompt.inputs, config=config)
@@ -23,7 +23,7 @@
"""

import logging
from typing import Optional
from typing import Optional, Any

from langchain_core.embeddings import Embeddings
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler
@@ -338,22 +338,21 @@ def get_callback_handler_factory(

def create_observability_callback_handler(
observability_setting: Optional[ObservabilitySetting],
trace_name: ObservabilityTrace,
**kwargs: Any
) -> Optional[LangfuseCallbackHandler]:
"""
Create the Observability Callback Handler
Args:
observability_setting: The Observability Settings
trace_name: The trace name
Returns:
The Observability Callback Handler
"""
if observability_setting is not None:
return get_callback_handler_factory(
setting=observability_setting
).get_callback_handler(trace_name=trace_name.value)
setting=observability_setting,
).get_callback_handler(**kwargs)

return None
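
With the explicit trace_name parameter replaced by **kwargs, callers now pass any trace attribute straight through to the Langfuse callback handler; a sketch of the new call shape (all values illustrative):

handler = create_observability_callback_handler(
    observability_setting=query.observability_setting,
    trace_name=ObservabilityTrace.RAG.value,  # now forwarded as a plain kwarg
    session_id='uuid-0123',                   # hypothetical dialog/session id
    user_id='user@example.com',               # hypothetical user id
    tags=['connector:web'],                   # hypothetical tags
)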

@@ -33,6 +33,7 @@
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.vectorstores import VectorStoreRetriever
from langfuse.callback import CallbackHandler as LangfuseCallbackHandler

from gen_ai_orchestrator.errors.exceptions.exceptions import (
GenAIGuardCheckException,
@@ -59,7 +60,7 @@
TextWithFootnotes,
)
from gen_ai_orchestrator.routers.requests.requests import RagQuery
from gen_ai_orchestrator.routers.responses.responses import RagResponse
from gen_ai_orchestrator.routers.responses.responses import RagResponse, ObservabilityInfo
from gen_ai_orchestrator.services.langchain.callbacks.retriever_json_callback_handler import (
RetrieverJsonCallbackHandler,
)
@@ -93,15 +94,23 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:

conversational_retrieval_chain = create_rag_chain(query=query)

message_history = ChatMessageHistory()
session_id = None
user_id = None
tags = []
if query.dialog:
for msg in query.dialog.history:
if ChatMessageType.HUMAN == msg.type:
message_history.add_user_message(msg.text)
else:
message_history.add_ai_message(msg.text)
session_id = query.dialog.dialog_id
user_id = query.dialog.user_id
tags = query.dialog.tags

logger.debug(
'RAG chain - Use chat history: %s', 'Yes' if len(query.history) > 0 else 'No'
'RAG chain - Use chat history: %s', 'Yes' if len(message_history.messages) > 0 else 'No'
)
message_history = ChatMessageHistory()
for msg in query.history:
if ChatMessageType.HUMAN == msg.type:
message_history.add_user_message(msg.text)
else:
message_history.add_ai_message(msg.text)

inputs = {
**query.question_answering_prompt_inputs,
Expand All @@ -115,17 +124,20 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:

callback_handlers = []
records_callback_handler = RetrieverJsonCallbackHandler()
observability_handler = None
if debug:
# Debug callback handler
callback_handlers.append(records_callback_handler)
if query.observability_setting is not None:
# Langfuse callback handler
callback_handlers.append(
create_observability_callback_handler(
observability_setting=query.observability_setting,
trace_name=ObservabilityTrace.RAG,
)
observability_handler = create_observability_callback_handler(
observability_setting=query.observability_setting,
trace_name=ObservabilityTrace.RAG.value,
session_id=session_id,
user_id=user_id,
tags=tags,
)
callback_handlers.append(observability_handler)

response = await conversational_retrieval_chain.ainvoke(
input=inputs,
Expand Down Expand Up @@ -161,6 +173,7 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
)
),
),
observability_info=get_observability_info(observability_handler),
debug=get_rag_debug_data(
query, response, records_callback_handler, rag_duration
)
Expand All @@ -169,6 +182,17 @@ async def execute_qa_chain(query: RagQuery, debug: bool) -> RagResponse:
)


def get_observability_info(observability_handler) -> Optional[ObservabilityInfo]:
"""Get the observability Information"""
if isinstance(observability_handler, LangfuseCallbackHandler):
return ObservabilityInfo(
trace_id=observability_handler.trace.id,
trace_name=observability_handler.trace_name,
trace_url=observability_handler.get_trace_url()
)
else:
return None
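
Note that the trace attributes are only populated once the chain has actually run with the handler attached; a sketch of the expected flow (an assumed helper, names illustrative):

async def invoke_with_tracing(chain, inputs, setting):
    # Run a chain with a Langfuse handler attached, then read the trace back
    handler = create_observability_callback_handler(
        observability_setting=setting, trace_name='RAG'
    )
    result = await chain.ainvoke(inputs, config={'callbacks': [handler]})
    return result, get_observability_info(handler)  # None without a Langfuse handler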

def get_source_content(doc: Document) -> str:
"""
Find and delete the title followed by two line breaks