Push MLflow experiment tracking to a reconciler to avoid race conditions #135

Merged · 20 commits · Feb 21, 2025
4 changes: 3 additions & 1 deletion llm-service/.gitignore
@@ -166,4 +166,6 @@ cython_debug/
 
 # Mlflow
 
-mlruns/
+mlruns/
+
+reconciler/data/
42 changes: 15 additions & 27 deletions llm-service/app/routers/index/data_source/__init__.py
@@ -32,12 +32,10 @@
 from http import HTTPStatus
 from typing import Any, Dict, Optional
 
-import mlflow
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi_utils.cbv import cbv
 from llama_index.core.llms import LLM
 from llama_index.core.node_parser import SentenceSplitter
-from mlflow.entities import Experiment
 from pydantic import BaseModel
 
 from .... import exceptions
@@ -49,6 +47,7 @@
 from ....services import document_storage, models
 from ....services.metadata_apis import data_sources_metadata_api
 from ....services.metadata_apis.data_sources_metadata_api import RagDataSource
+from ....services.mlflow import write_mlflow_run_json
 
 logger = logging.getLogger(__name__)
 
@@ -142,16 +141,8 @@ def download_and_index(
         request: RagIndexDocumentRequest,
     ) -> None:
         datasource = data_sources_metadata_api.get_metadata(data_source_id)
-        mlflow.llama_index.autolog()
-        experiment: Experiment = mlflow.set_experiment(
-            experiment_name=f"datasource_{datasource.name}_{data_source_id}"
-        )
-        with mlflow.start_run(
-            experiment_id=experiment.experiment_id, run_name=f"doc_{doc_id}"
-        ):
-            self._download_and_index(datasource, doc_id, request)
+        self._download_and_index(datasource, doc_id, request)
 
-    @mlflow.trace(name="download_and_index")
     def _download_and_index(
         self, datasource: RagDataSource, doc_id: str, request: RagIndexDocumentRequest
     ) -> None:
@@ -181,23 +172,20 @@ def _download_and_index(
             llm=llm,
             chunks_vector_store=self.chunks_vector_store,
         )
-
-        mlflow.log_metrics(
-            {
-                "file_size_bytes": file_path.stat().st_size,
-            }
-        )
-
-        mlflow.log_params(
+        write_mlflow_run_json(
+            f"datasource_{datasource.name}_{datasource.id}",
+            f"doc_{doc_id}",
             {
-                "data_source_id": datasource.id,
-                "embedding_model": datasource.embedding_model,
-                "summarization_model": datasource.summarization_model,
-                "chunk_size": request.configuration.chunk_size,
-                "chunk_overlap": request.configuration.chunk_overlap,
-                "file_name": request.original_filename,
-                "file_size_bytes": file_path.stat().st_size,
-            }
+                "params": {
+                    "data_source_id": str(datasource.id),
+                    "embedding_model": datasource.embedding_model,
+                    "summarization_model": datasource.summarization_model,
+                    "chunk_size": str(request.configuration.chunk_size),
+                    "chunk_overlap": str(request.configuration.chunk_overlap),
+                    "file_name": request.original_filename,
+                    "file_size_bytes": str(file_path.stat().st_size),
+                }
+            },
         )
 
         # Delete to avoid duplicates
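The body of write_mlflow_run_json lives in llm-service/app/services/mlflow.py, which is outside the visible diff. A minimal sketch of the producer side, assuming each call spools one JSON record into the newly gitignored reconciler/data/ directory for the reconciler to pick up (the file layout, field names, and atomic-rename detail are illustrative assumptions, not confirmed by the PR):

import json
import os
import uuid
from pathlib import Path
from typing import Any, Dict

# Assumed spool location, matching the new .gitignore entry.
RECONCILER_DATA_DIR = Path("reconciler/data")


def write_mlflow_run_json(
    experiment_name: str, run_name: str, run_data: Dict[str, Any]
) -> None:
    """Spool a run record to disk instead of calling MLflow in-request."""
    RECONCILER_DATA_DIR.mkdir(parents=True, exist_ok=True)
    record = {
        "experiment_name": experiment_name,
        "run_name": run_name,
        **run_data,
    }
    # Write to a temp name, then rename: the rename is atomic on POSIX,
    # so the reconciler never sees a half-written file.
    final_path = RECONCILER_DATA_DIR / f"{uuid.uuid4()}.json"
    tmp_path = final_path.with_suffix(".tmp")
    tmp_path.write_text(json.dumps(record))
    os.rename(tmp_path, final_path)

Deferring the MLflow calls this way means the request handler only touches the local filesystem, so two concurrent document uploads can no longer race on mlflow.set_experiment / mlflow.start_run.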
34 changes: 3 additions & 31 deletions llm-service/app/routers/index/sessions/__init__.py
@@ -40,16 +40,14 @@
 import logging
 from typing import Annotated
 
-import mlflow
 from fastapi import APIRouter, Cookie
-from mlflow.entities import Experiment, Run
 from pydantic import BaseModel
 
 from .... import exceptions
 from ....rag_types import RagPredictConfiguration
 from ....services.chat import generate_suggested_questions, v2_chat, direct_llm_chat
 from ....services.chat_store import ChatHistoryManager, RagStudioChatMessage
 from ....services.metadata_apis import session_metadata_api
+from ....services.mlflow import rating_mlflow_log_metric, feedback_mlflow_log_table
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"])
@@ -93,18 +91,7 @@ def rating(
     response_id: str,
     request: ChatResponseRating,
 ) -> ChatResponseRating:
-    session = session_metadata_api.get_session(session_id)
-    experiment: Experiment = mlflow.set_experiment(
-        experiment_name=f"session_{session.name}_{session.id}"
-    )
-    runs: list[Run] = mlflow.search_runs(
-        [experiment.experiment_id],
-        filter_string=f"tags.response_id='{response_id}'",
-        output_format="list",
-    )
-    for run in runs:
-        value: int = 1 if request.rating else -1
-        mlflow.log_metric("rating", value, run_id=run.info.run_id)
+    rating_mlflow_log_metric(request.rating, response_id, session_id)
     return ChatResponseRating(rating=request.rating)
 
 
@@ -121,21 +108,7 @@ def feedback(
     response_id: str,
     request: ChatResponseFeedback,
 ) -> ChatResponseFeedback:
-    session = session_metadata_api.get_session(session_id)
-    experiment: Experiment = mlflow.set_experiment(
-        experiment_name=f"session_{session.name}_{session.id}"
-    )
-    runs: list[Run] = mlflow.search_runs(
-        [experiment.experiment_id],
-        filter_string=f"tags.response_id='{response_id}'",
-        output_format="list",
-    )
-    for run in runs:
-        mlflow.log_table(
-            data={"feedback": request.feedback},
-            artifact_file="feedback.json",
-            run_id=run.info.run_id,
-        )
+    feedback_mlflow_log_table(request.feedback, response_id, session_id)
     return ChatResponseFeedback(feedback=request.feedback)
 
 
@@ -168,7 +141,6 @@ def chat(
     _basusertoken: Annotated[str | None, Cookie()] = None,
 ) -> RagStudioChatMessage:
     user_name = parse_jwt_cookie(_basusertoken)
-    mlflow.llama_index.autolog()
 
     configuration = request.configuration or RagPredictConfiguration()
     if configuration.exclude_knowledge_base:
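rating_mlflow_log_metric and feedback_mlflow_log_table are also defined in app/services/mlflow.py, outside this diff. A plausible shape, assuming they simply encapsulate the search-and-log logic deleted from the two route handlers above (bodies reconstructed from the removed code; whether they also route through the reconciler queue is not visible here):

import mlflow
from mlflow.entities import Experiment, Run

from .metadata_apis import session_metadata_api


def rating_mlflow_log_metric(rating: bool, response_id: str, session_id: int) -> None:
    # Same lookup the route used to do inline: find the run tagged with
    # this response_id inside the session's experiment.
    session = session_metadata_api.get_session(session_id)
    experiment: Experiment = mlflow.set_experiment(
        experiment_name=f"session_{session.name}_{session.id}"
    )
    runs: list[Run] = mlflow.search_runs(
        [experiment.experiment_id],
        filter_string=f"tags.response_id='{response_id}'",
        output_format="list",
    )
    for run in runs:
        value: int = 1 if rating else -1
        mlflow.log_metric("rating", value, run_id=run.info.run_id)


def feedback_mlflow_log_table(feedback: str, response_id: str, session_id: int) -> None:
    session = session_metadata_api.get_session(session_id)
    experiment: Experiment = mlflow.set_experiment(
        experiment_name=f"session_{session.name}_{session.id}"
    )
    runs: list[Run] = mlflow.search_runs(
        [experiment.experiment_id],
        filter_string=f"tags.response_id='{response_id}'",
        output_format="list",
    )
    for run in runs:
        mlflow.log_table(
            data={"feedback": feedback},
            artifact_file="feedback.json",
            run_id=run.info.run_id,
        )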
136 changes: 26 additions & 110 deletions llm-service/app/services/chat.py
@@ -35,17 +35,13 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 # ##############################################################################
-import asyncio
-import re
 import time
 import uuid
 from typing import List, Iterable
 
-import mlflow
 from fastapi import HTTPException
 from llama_index.core.base.llms.types import MessageRole
 from llama_index.core.chat_engine.types import AgentChatResponse
-from mlflow.entities import Experiment
 
 from . import evaluators, llm_completion
 from .chat_store import (
@@ -56,8 +52,9 @@
     RagStudioChatMessage,
     RagMessage,
 )
-from .metadata_apis import session_metadata_api, data_sources_metadata_api
+from .metadata_apis import session_metadata_api
 from .metadata_apis.session_metadata_api import Session
+from .mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run
 from .query import querier
 from .query.query_configuration import QueryConfiguration
 from ..ai.vector_stores.qdrant import QdrantVectorStore
@@ -78,32 +75,22 @@ def v2_chat(
         use_summary_filter=session.query_configuration.enable_summary_filter,
     )
     response_id = str(uuid.uuid4())
-    experiment: Experiment = mlflow.set_experiment(
-        experiment_name=f"session_{session.name}_{session.id}"
+
+    new_chat_message: RagStudioChatMessage = _run_chat(
+        session, response_id, query, query_configuration, user_name
     )
-    # mlflow.set_experiment_tag("session_id", session.id)
-    with mlflow.start_run(
-        experiment_id=experiment.experiment_id, run_name=f"{response_id}"
-    ):
-        new_chat_message: RagStudioChatMessage = _run_chat(
-            session, response_id, query, query_configuration, user_name
-        )
 
     ChatHistoryManager().append_to_history(session_id, [new_chat_message])
     return new_chat_message
 
 
-@mlflow.trace(name="v2_chat")
 def _run_chat(
     session: Session,
     response_id: str,
     query: str,
     query_configuration: QueryConfiguration,
     user_name: str,
 ) -> RagStudioChatMessage:
-    asyncio.run(log_ml_flow_params(session, query_configuration, user_name))
-    mlflow.set_tag("response_id", response_id)
-
     if len(session.data_source_ids) != 1:
         raise HTTPException(
             status_code=400, detail="Only one datasource is supported for chat."
@@ -150,66 +137,11 @@ def _run_chat(
         timestamp=time.time(),
         condensed_question=condensed_question,
     )
-    log_ml_flow_metrics(session, new_chat_message)
-    return new_chat_message
-
-
-def log_ml_flow_metrics(session: Session, message: RagStudioChatMessage) -> None:
-    source_nodes: list[RagPredictSourceNode] = message.source_nodes
-    query = message.rag_message.user
-    response = message.rag_message.assistant
-    for evaluation in message.evaluations:
-        mlflow.log_metric(evaluation.name, evaluation.value)
-
-    mlflow.log_metrics(
-        {
-            "source_nodes_count": len(source_nodes),
-            "max_score": (source_nodes[0].score if source_nodes else 0.0),
-            "input_word_count": len(re.findall(r"\w+", query)),
-            "output_word_count": len(re.findall(r"\w+", response)),
-        }
-    )
-
-    flattened_nodes = [node.model_dump() for node in source_nodes]
-    mlflow.log_table(
-        {
-            "response_id": message.id,
-            "node_id": map(lambda x: x.get("node_id"), flattened_nodes),
-            "doc_id": map(lambda x: x.get("doc_id"), flattened_nodes),
-            "source_file_name": map(
-                lambda x: x.get("source_file_name"), flattened_nodes
-            ),
-            "score": map(lambda x: x.get("score"), flattened_nodes),
-            "query": query,
-            "response": response,
-            "condensed_question": message.condensed_question,
-        },
-        artifact_file="response_details.json",
-    )
-
-
-async def log_ml_flow_params(
-    session: Session, query_configuration: QueryConfiguration, user_name: str
-) -> None:
-    data_source_metadata = data_sources_metadata_api.get_metadata(session.data_source_ids[0])
-    mlflow.log_params(
-        {
-            "top_k": query_configuration.top_k,
-            "inference_model": query_configuration.model_name,
-            "rerank_model_name": query_configuration.rerank_model_name,
-            "exclude_knowledge_base": query_configuration.exclude_knowledge_base,
-            "use_question_condensing": query_configuration.use_question_condensing,
-            "use_hyde": query_configuration.use_hyde,
-            "use_summary_filter": query_configuration.use_summary_filter,
-            "session_id": session.id,
-            "data_source_ids": session.data_source_ids,
-            "user_name": user_name,
-            "embedding_model": data_source_metadata.embedding_model,
-            "chunk_size": data_source_metadata.chunk_size,
-            "summarization_model": data_source_metadata.summarization_model,
-            "chunk_overlap_percent": data_source_metadata.chunk_overlap_percent,
-        }
+    record_rag_mlflow_run(
+        new_chat_message, query_configuration, response_id, session, user_name
     )
+    return new_chat_message
 
 
 def retrieve_chat_history(session_id: int) -> List[RagContext]:
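record_rag_mlflow_run is likewise not shown in this diff. A hedged sketch, assuming it folds the deleted log_ml_flow_params and log_ml_flow_metrics payloads into one spooled record via write_mlflow_run_json (the JSON schema is an assumption, and the param list here is abbreviated relative to the deleted code):

import re

# Hypothetical body; assumes this sits in app/services/mlflow.py next to
# write_mlflow_run_json, with these imports available at module level.
from .chat_store import RagStudioChatMessage
from .metadata_apis import data_sources_metadata_api
from .metadata_apis.session_metadata_api import Session
from .query.query_configuration import QueryConfiguration


def record_rag_mlflow_run(
    message: RagStudioChatMessage,
    query_configuration: QueryConfiguration,
    response_id: str,
    session: Session,
    user_name: str,
) -> None:
    data_source_metadata = data_sources_metadata_api.get_metadata(
        session.data_source_ids[0]
    )
    query = message.rag_message.user
    response = message.rag_message.assistant
    # Same metrics the deleted log_ml_flow_metrics computed inline.
    metrics = {e.name: e.value for e in message.evaluations}
    metrics.update(
        {
            "source_nodes_count": len(message.source_nodes),
            "max_score": message.source_nodes[0].score if message.source_nodes else 0.0,
            "input_word_count": len(re.findall(r"\w+", query)),
            "output_word_count": len(re.findall(r"\w+", response)),
        }
    )
    write_mlflow_run_json(
        f"session_{session.name}_{session.id}",
        response_id,
        {
            "tags": {"response_id": response_id},
            "params": {
                "top_k": str(query_configuration.top_k),
                "inference_model": query_configuration.model_name,
                "session_id": str(session.id),
                "user_name": user_name,
                "embedding_model": data_source_metadata.embedding_model,
            },
            "metrics": metrics,
        },
    )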
@@ -328,39 +260,23 @@ def direct_llm_chat(
     session_id: int, query: str, user_name: str
 ) -> RagStudioChatMessage:
     session = session_metadata_api.get_session(session_id)
-    experiment = mlflow.set_experiment(
-        experiment_name=f"session_{session.name}_{session.id}"
-    )
     response_id = str(uuid.uuid4())
-    with mlflow.start_run(
-        experiment_id=experiment.experiment_id, run_name=f"{response_id}"
-    ):
-        mlflow.set_tag("response_id", response_id)
-        mlflow.set_tag("direct_llm", True)
-        mlflow.log_params(
-            {
-                "inference_model": session.inference_model,
-                "exclude_knowledge_base": True,
-                "session_id": session.id,
-                "data_source_ids": session.data_source_ids,
-                "user_name": user_name,
-            }
-        )
+    record_direct_llm_mlflow_run(response_id, session, user_name)
 
-        chat_response = llm_completion.completion(
-            session_id, query, session.inference_model
-        )
-        new_chat_message = RagStudioChatMessage(
-            id=response_id,
-            source_nodes=[],
-            inference_model=session.inference_model,
-            evaluations=[],
-            rag_message=RagMessage(
-                user=query,
-                assistant=str(chat_response.message.content),
-            ),
-            timestamp=time.time(),
-            condensed_question=None,
-        )
-        ChatHistoryManager().append_to_history(session_id, [new_chat_message])
-        return new_chat_message
+    chat_response = llm_completion.completion(
+        session_id, query, session.inference_model
+    )
+    new_chat_message = RagStudioChatMessage(
+        id=response_id,
+        source_nodes=[],
+        inference_model=session.inference_model,
+        evaluations=[],
+        rag_message=RagMessage(
+            user=query,
+            assistant=str(chat_response.message.content),
+        ),
+        timestamp=time.time(),
+        condensed_question=None,
+    )
+    ChatHistoryManager().append_to_history(session_id, [new_chat_message])
+    return new_chat_message
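The reconciler process itself is also outside the visible diff. A sketch of the consuming side, under the assumption that a single loop drains reconciler/data/ and replays each record through the MLflow client; serializing all writes through one consumer is what removes the race between concurrent FastAPI workers:

import json
import time
from pathlib import Path

import mlflow

# Assumed spool directory; mirrors the hypothetical write_mlflow_run_json
# above, and the JSON schema is likewise not confirmed by this diff.
DATA_DIR = Path("reconciler/data")


def reconcile_once() -> None:
    for record_file in sorted(DATA_DIR.glob("*.json")):
        record = json.loads(record_file.read_text())
        experiment = mlflow.set_experiment(experiment_name=record["experiment_name"])
        with mlflow.start_run(
            experiment_id=experiment.experiment_id, run_name=record["run_name"]
        ):
            for key, value in record.get("tags", {}).items():
                mlflow.set_tag(key, value)
            if record.get("params"):
                mlflow.log_params(record["params"])
            if record.get("metrics"):
                mlflow.log_metrics(record["metrics"])
        record_file.unlink()  # delete only once the run is safely recorded


if __name__ == "__main__":
    # One process, one loop: MLflow writes are serialized, so concurrent
    # request handlers can no longer race on experiment/run creation.
    while True:
        reconcile_once()
        time.sleep(5)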