Skip to content

Commit

Permalink
added validation using tonic validate
Browse files Browse the repository at this point in the history
  • Loading branch information
varchanaiyer committed May 20, 2024
1 parent 6258bb1 commit af7bc23
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
4 changes: 2 additions & 2 deletions api/Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ name = "pypi"
[packages]
stampy-chat = {editable = true, path = "."}
# ---- <flask stuff> ----
flask = "==1.1.2"
flask = ">=2.2.2"
gunicorn = "==20.0.4"
jinja2 = "==2.11.3"
markupsafe = "==1.1.1"
Expand All @@ -25,7 +25,7 @@ requests = "*"
alembic = "*"
sqlalchemy = "*"
mysql-connector-python = "*"
langchain = "*"
langchain = "==0.1.20"
transformers = "*"
langchain-anthropic = "*"

Expand Down
35 changes: 34 additions & 1 deletion api/src/stampy_chat/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Any, Callable, Dict, List

from tonic_validate import BenchmarkItem, ValidateScorer, LLMResponse
from tonic_validate.metrics import AnswerConsistencyMetric, RetrievalPrecisionMetric, AugmentationAccuracyMetric

from langchain.chains import LLMChain, OpenAIModerationChain
from langchain_community.chat_models import ChatOpenAI
from langchain_anthropic import ChatAnthropic
Expand All @@ -17,7 +20,7 @@
from stampy_chat.settings import Settings, MODELS, OPENAI, ANTRHROPIC
from stampy_chat.callbacks import StampyCallbackHandler, BroadcastCallbackHandler, LoggerCallbackHandler
from stampy_chat.followups import StampyChain
from stampy_chat.citations import make_example_selector
from stampy_chat.citations import make_example_selector, get_top_k_blocks

from langsmith import Client

Expand Down Expand Up @@ -322,6 +325,7 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
:param Callable[[Any], None] callback: an optional callback that will be called at various key parts of the chain
:returns: the result of the chain
"""

callbacks = [LoggerCallbackHandler(session_id=session_id, query=query, history=history)]
if callback:
callbacks += [BroadcastCallbackHandler(callback)]
Expand All @@ -340,9 +344,38 @@ def run_query(session_id: str, query: str, history: List[Dict], settings: Settin
prompt=make_prompt(settings, chat_model, callbacks),
memory=make_memory(settings, history, callbacks)
)

if followups:
chain = chain | StampyChain(callbacks=callbacks)

result = chain.invoke({"query": query, 'history': history}, {'callbacks': []})

#Validate results
contexts=[c['text'] for c in get_top_k_blocks(query, 5)]
rag_response={
"llm_answer": result['text'],
"llm_context_list": contexts
}

benchmark = BenchmarkItem(question=query)

llm_response = LLMResponse(
llm_answer=result['text'],
llm_context_list=contexts,
benchmark_item=benchmark
)

# Score the responses
scorer = ValidateScorer([
AnswerConsistencyMetric(),
AugmentationAccuracyMetric(),
RetrievalPrecisionMetric()
])
run = scorer.score_responses([llm_response])
print(run.overall_scores)



if callback:
callback({'state': 'done'})
callback(None) # make sure the callback handler know that things have ended
Expand Down

0 comments on commit af7bc23

Please sign in to comment.