Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Vector Search [qa]; make vectordb search configurable #24

Merged
merged 9 commits into from
Dec 22, 2023
Merged
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ ignore_missing_imports = True
[mypy-gradio.*]
ignore_missing_imports = True

[mypy-gradio.langchain.*]
ignore_missing_imports = True

;; gradio is not PEP 561 compliant (no py.typed) yet
[mypy-team_red.frontends.*]
disallow_any_unimported = False
6 changes: 5 additions & 1 deletion team_red/backends/bridge.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging
from typing import Dict, Optional
from typing import Dict, List, Optional

from team_red.config import CONFIG
from team_red.gen import GenerationService
from team_red.qa import QAService

from ..transport import (
DocumentSource,
GenResponse,
PromptConfig,
QAAnswer,
Expand Down Expand Up @@ -39,6 +40,9 @@ def gen(self) -> GenerationService:
def qa_query(self, question: QAQuestion) -> QAAnswer:
    """Forward *question* to the QA service's ``query`` and return its answer."""
    answer = self.qa.query(question)
    return answer

def db_query(self, question: QAQuestion) -> List[DocumentSource]:
    """Forward *question* to the QA service's ``db_query`` and return its documents."""
    documents = self.qa.db_query(question)
    return documents

def add_file(self, file: QAFileUpload) -> QAAnswer:
    """Forward the uploaded *file* to the QA service's ``add_file`` and return its result."""
    result = self.qa.add_file(file)
    return result

Expand Down
65 changes: 53 additions & 12 deletions team_red/frontends/qa_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,28 @@
logging.basicConfig(level=logging.DEBUG)


def query(question: str, search_type: str, k_source: int, search_strategy: str) -> str:
    """Answer *question* through the LLM chain or a raw vector-database lookup.

    Args:
        question: Question text entered in the UI.
        search_type: ``"LLM"`` sends the question through
            ``TRANSPORTER.qa_query``; any other value sends it through
            ``TRANSPORTER.db_query``.
        k_source: Forwarded as ``max_sources`` of the ``QAQuestion``.
        search_strategy: Forwarded as ``search_strategy`` of the ``QAQuestion``.

    Returns:
        The answer text of the LLM query, or the retrieved documents rendered
        one after another as content, ``(name / page)`` and a separator line.

    Raises:
        gr.Error: If the LLM query reports a status other than 200, or the
            database query returns no documents.
    """
    q = QAQuestion(
        question=question, search_strategy=search_strategy, max_sources=k_source
    )
    if search_type == "LLM":
        qa_res = TRANSPORTER.qa_query(q)
        if qa_res.status != 200:
            msg = (
                f"Query was unsuccessful: {qa_res.error_msg}"
                f" (Error Code {qa_res.status})"
            )
            raise gr.Error(msg)
        return qa_res.answer

    db_res = TRANSPORTER.db_query(q)
    if not db_res:
        raise gr.Error("Database query returned empty!")
    # Build the listing in one pass instead of repeated string concatenation.
    return "".join(
        f"{doc.content}\n({doc.name} / {doc.page})\n----------\n\n" for doc in db_res
    )


def upload(file_path: str, progress: Optional[gr.Progress] = None) -> None:
Expand Down Expand Up @@ -61,20 +77,45 @@ def set_prompt(prompt: str, progress: Optional[gr.Progress] = None) -> None:
with demo:
gr.Markdown("# Entlassbriefe QA")
with gr.Row():
file_upload = gr.File(file_count="single", file_types=[".txt"])
with gr.Column():
prompt = gr.TextArea(
value=TRANSPORTER.get_qa_prompt().text, interactive=True, label="Prompt"
with gr.Column(scale=1):
file_upload = gr.File(file_count="single", file_types=[".txt"])
with gr.Column(scale=1):
type_radio = gr.Radio(
choices=["LLM", "VectorDB"],
value="LLM",
label="Suchmodus",
interactive=True,
)
k_slider = gr.Slider(
minimum=1,
maximum=10,
step=1,
value=3,
interactive=True,
label="Quellenanzahl",
)
prompt_submit = gr.Button("Aktualisiere Prompt")
strategy_dropdown = gr.Dropdown(
choices=["similarity", "mmr"],
value="similarity",
interactive=True,
label="Suchmodus",
)
prompt = gr.TextArea(
value=TRANSPORTER.get_qa_prompt().text, interactive=True, label="Prompt"
)
prompt_submit = gr.Button("Aktualisiere Prompt")
inp = gr.Textbox(
label="Stellen Sie eine Frage:", placeholder="Wie heißt der Patient?"
)
out = gr.Textbox(label="Antwort")
file_upload.change(fn=upload, inputs=file_upload, outputs=out)
btn = gr.Button("Frage stellen")
btn.click(fn=query, inputs=inp, outputs=out)
inp.submit(fn=query, inputs=inp, outputs=out)
btn.click(
fn=query, inputs=[inp, type_radio, k_slider, strategy_dropdown], outputs=out
)
inp.submit(
fn=query, inputs=[inp, type_radio, k_slider, strategy_dropdown], outputs=out
)
prompt_submit.click(fn=set_prompt, inputs=prompt, outputs=out)

if __name__ == "__main__":
Expand Down
21 changes: 20 additions & 1 deletion team_red/qa/qa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING, Iterable, List, Optional, Protocol

from langchain.chains.retrieval_qa.base import BaseRetrievalQA
from langchain.docstore.document import Document
from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
Expand All @@ -25,7 +26,6 @@
from team_red.utils import build_retrieval_qa

if TYPE_CHECKING:
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader


Expand All @@ -46,6 +46,9 @@ def add_texts(
) -> List[str]:
pass

def search(self, query: str, search_type: str, k: int) -> List[Document]:
    """Protocol stub for the vector store lookup.

    ``QAService.db_query`` calls this with the question text as *query*, the
    question's search strategy as *search_type* and its maximum number of
    sources as *k*.
    """
    ...


class QAService:
def __init__(self, config: QAConfig) -> None:
Expand All @@ -69,6 +72,22 @@ def __init__(self, config: QAConfig) -> None:
config.embedding.db_path, self._embeddings
)

def db_query(self, question: QAQuestion) -> List[DocumentSource]:
    """Search the vector store for *question* and return the hits as sources.

    Returns an empty list when no vector store is set. Otherwise each hit is
    converted to a ``DocumentSource``; the metadata keys ``"source"`` and
    ``"page"`` fall back to ``"unknown"`` and ``1`` when missing.
    """
    if not self._vectorstore:
        return []

    hits = self._vectorstore.search(
        question.question,
        search_type=question.search_strategy,
        k=question.max_sources,
    )
    sources: List[DocumentSource] = []
    for hit in hits:
        sources.append(
            DocumentSource(
                content=hit.page_content,
                name=hit.metadata.get("source", "unknown"),
                page=hit.metadata.get("page", 1),
            )
        )
    return sources

def query(self, question: QAQuestion) -> QAAnswer:
if not self._database:
if not self._vectorstore:
Expand Down
5 changes: 5 additions & 0 deletions team_red/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class GenResponse(BaseModel):

class QAQuestion(BaseModel):
    """Question payload handed to the QA transport (``qa_query`` / ``db_query``)."""

    # Text of the question.
    question: str
    # Passed as ``search_type`` to the vector store search in
    # ``QAService.db_query``; the QA frontend offers "similarity" and "mmr".
    search_strategy: str = "similarity"
    # Passed as ``k`` to the vector store search in ``QAService.db_query``.
    max_sources: int = 3


class DocumentSource(BaseModel):
Expand Down Expand Up @@ -43,6 +45,9 @@ class Transport(Protocol):
def qa_query(self, query: QAQuestion) -> QAAnswer:
    """Protocol stub: take a ``QAQuestion`` and return a ``QAAnswer``."""
    ...

def db_query(self, question: QAQuestion) -> List[DocumentSource]:
    """Protocol stub: take a ``QAQuestion`` and return a list of ``DocumentSource``."""
    ...

def add_file(self, file: QAFileUpload) -> QAAnswer:
    """Protocol stub: take a ``QAFileUpload`` and return a ``QAAnswer``."""
    ...

Expand Down
12 changes: 12 additions & 0 deletions tests/qa/test_qa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,15 @@ def test_query(qa_service_cajal: QAService) -> None:
res = qa_service_cajal.query(QAQuestion(question="Wer ist der Patient?"))
assert res.status == 200
assert res.answer


def test_db_query(qa_service_cajal: QAService) -> None:
    """Check result count, source name and chunk size of raw database queries."""
    # Default question: expects exactly the default number of sources.
    default_question = QAQuestion(question="Wer ist der Patient?")
    default_hits = qa_service_cajal.db_query(default_question)
    assert len(default_hits) == default_question.max_sources
    top_hit = default_hits[0]
    assert top_hit.name == "Cajal.txt"
    assert len(top_hit.content) <= CONFIG.qa.embedding.chunk_size

    # Explicit limit of one source.
    single_question = QAQuestion(question="Wie heißt das Krankenhaus", max_sources=1)
    single_hits = qa_service_cajal.db_query(single_question)
    assert len(single_hits) == single_question.max_sources
    assert "Diakonissenkrankenhaus Berlin" in single_hits[0].content