From af7bc232897d09bad00397998b1324495a3342d1 Mon Sep 17 00:00:00 2001
From: varchanaiyer
Date: Mon, 20 May 2024 19:17:02 +0800
Subject: [PATCH] added validation using tonic validate

---
 api/Pipfile                 |  4 ++--
 api/src/stampy_chat/chat.py | 32 +++++++++++++++++++++++++++++++-
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/api/Pipfile b/api/Pipfile
index afb2857..8eb28d8 100644
--- a/api/Pipfile
+++ b/api/Pipfile
@@ -6,7 +6,7 @@ name = "pypi"
 [packages]
 stampy-chat = {editable = true, path = "."}
 # ---- ----
-flask = "==1.1.2"
+flask = ">=2.2.2"
 gunicorn = "==20.0.4"
 jinja2 = "==2.11.3"
 markupsafe = "==1.1.1"
@@ -25,7 +25,7 @@
 requests = "*"
 alembic = "*"
 sqlalchemy = "*"
 mysql-connector-python = "*"
-langchain = "*"
+langchain = "==0.1.20"
 transformers = "*"
 langchain-anthropic = "*"
diff --git a/api/src/stampy_chat/chat.py b/api/src/stampy_chat/chat.py
index b6e42d9..9099ac9 100644
--- a/api/src/stampy_chat/chat.py
+++ b/api/src/stampy_chat/chat.py
@@ -1,5 +1,8 @@
 from typing import Any, Callable, Dict, List
 
+from tonic_validate import BenchmarkItem, ValidateScorer, LLMResponse
+from tonic_validate.metrics import AnswerConsistencyMetric, RetrievalPrecisionMetric, AugmentationAccuracyMetric
+
 from langchain.chains import LLMChain, OpenAIModerationChain
 from langchain_community.chat_models import ChatOpenAI
 from langchain_anthropic import ChatAnthropic
@@ -17,7 +20,7 @@
 from stampy_chat.settings import Settings, MODELS, OPENAI, ANTRHROPIC
 from stampy_chat.callbacks import StampyCallbackHandler, BroadcastCallbackHandler, LoggerCallbackHandler
 from stampy_chat.followups import StampyChain
-from stampy_chat.citations import make_example_selector
+from stampy_chat.citations import make_example_selector, get_top_k_blocks
 
 from langsmith import Client
 
@@ -322,6 +325,7 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
     :param Callable[[Any], None] callback: an optional callback that will be called at various key parts of the chain
     :returns: the result of the chain
     """
+
     callbacks = [LoggerCallbackHandler(session_id=session_id, query=query, history=history)]
     if callback:
         callbacks += [BroadcastCallbackHandler(callback)]
@@ -340,9 +344,35 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
         prompt=make_prompt(settings, chat_model, callbacks),
         memory=make_memory(settings, history, callbacks)
     )
+
     if followups:
         chain = chain | StampyChain(callbacks=callbacks)
+
     result = chain.invoke({"query": query, 'history': history}, {'callbacks': []})
+
+    # Validate the response with tonic_validate. Note that this re-runs
+    # retrieval, so the scored contexts may not be the exact blocks the
+    # chain saw when generating the answer.
+    contexts = [c['text'] for c in get_top_k_blocks(query, 5)]
+
+    benchmark = BenchmarkItem(question=query)
+
+    llm_response = LLMResponse(
+        llm_answer=result['text'],
+        llm_context_list=contexts,
+        benchmark_item=benchmark
+    )
+
+    # Score the response for answer consistency, augmentation accuracy
+    # and retrieval precision
+    scorer = ValidateScorer([
+        AnswerConsistencyMetric(),
+        AugmentationAccuracyMetric(),
+        RetrievalPrecisionMetric()
+    ])
+    run = scorer.score_responses([llm_response])
+    print(run.overall_scores)
+
     if callback:
         callback({'state': 'done'})
         callback(None) # make sure the callback handler know that things have ended
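
The validation block above scores every chat response inline: each tonic_validate metric issues its own grading call to an LLM, so run_query now blocks on several extra model round-trips before returning. A minimal sketch of moving the scoring off the request path with a worker thread; it reuses only the tonic_validate calls that appear in the patch (ValidateScorer, score_responses, overall_scores), while the executor wiring and the score_in_background helper are illustrative assumptions, not part of the codebase:

from concurrent.futures import ThreadPoolExecutor
from typing import List

from tonic_validate import BenchmarkItem, LLMResponse, ValidateScorer
from tonic_validate.metrics import (
    AnswerConsistencyMetric,
    AugmentationAccuracyMetric,
    RetrievalPrecisionMetric,
)

# Single shared worker so scoring runs serially in the background
# instead of blocking the request handler.
_executor = ThreadPoolExecutor(max_workers=1)


def score_in_background(query: str, answer: str, contexts: List[str]) -> None:
    """Queue a tonic_validate run without blocking the chat response."""
    response = LLMResponse(
        llm_answer=answer,
        llm_context_list=contexts,
        benchmark_item=BenchmarkItem(question=query),
    )

    def _score() -> None:
        scorer = ValidateScorer([
            AnswerConsistencyMetric(),
            AugmentationAccuracyMetric(),
            RetrievalPrecisionMetric(),
        ])
        run = scorer.score_responses([response])
        # In a deployed API, persisting these scores (e.g. alongside the
        # session in the database) would be more useful than printing.
        print(run.overall_scores)

    _executor.submit(_score)

With a helper like this, the end of run_query would call score_in_background(query, result['text'], contexts) and return immediately; whether the scores are then logged, stored per session, or only sampled for a fraction of requests to cap evaluation cost is a deployment choice the patch leaves open.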