diff --git a/mypy.ini b/mypy.ini index a1c36cd..24bfd81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,9 @@ ignore_missing_imports = True [mypy-gradio.*] ignore_missing_imports = True +[mypy-gradio.langchain.*] +ignore_missing_imports = True + ;; gradio is not PEP 561 compliant (no py.typed) yet [mypy-team_red.frontends.*] disallow_any_unimported = False \ No newline at end of file diff --git a/team_red/backends/bridge.py b/team_red/backends/bridge.py index 682cf2a..2ded033 100644 --- a/team_red/backends/bridge.py +++ b/team_red/backends/bridge.py @@ -1,11 +1,12 @@ import logging -from typing import Dict, Optional +from typing import Dict, List, Optional from team_red.config import CONFIG from team_red.gen import GenerationService from team_red.qa import QAService from ..transport import ( + DocumentSource, GenResponse, PromptConfig, QAAnswer, @@ -39,6 +40,9 @@ def gen(self) -> GenerationService: def qa_query(self, question: QAQuestion) -> QAAnswer: return self.qa.query(question) + def db_query(self, question: QAQuestion) -> List[DocumentSource]: + return self.qa.db_query(question) + def add_file(self, file: QAFileUpload) -> QAAnswer: return self.qa.add_file(file) diff --git a/team_red/frontends/qa_frontend.py b/team_red/frontends/qa_frontend.py index ba7db5e..b3bcfda 100644 --- a/team_red/frontends/qa_frontend.py +++ b/team_red/frontends/qa_frontend.py @@ -13,12 +13,28 @@ logging.basicConfig(level=logging.DEBUG) -def query(question: str) -> str: - res = TRANSPORTER.qa_query(QAQuestion(question=question)) - if res.status != 200: - msg = f"Query was unsuccessful: {res.error_msg} (Error Code {res.status})" +def query(question: str, search_type: str, k_source: int, search_strategy: str) -> str: + q = QAQuestion( + question=question, search_strategy=search_strategy, max_sources=k_source + ) + if search_type == "LLM": + qa_res = TRANSPORTER.qa_query(q) + if qa_res.status != 200: + msg = ( + f"Query was unsuccessful: {qa_res.error_msg}" + f" (Error Code {qa_res.status})" + ) + raise gr.Error(msg) + return qa_res.answer + db_res = TRANSPORTER.db_query(q) + if not db_res: + msg = f"Database query returned empty!" raise gr.Error(msg) - return res.answer + output = "" + for doc in db_res: + output += f"{doc.content}\n" + output += f"({doc.name} / {doc.page})\n----------\n\n" + return output def upload(file_path: str, progress: Optional[gr.Progress] = None) -> None: @@ -61,20 +77,45 @@ def set_prompt(prompt: str, progress: Optional[gr.Progress] = None) -> None: with demo: gr.Markdown("# Entlassbriefe QA") with gr.Row(): - file_upload = gr.File(file_count="single", file_types=[".txt"]) - with gr.Column(): - prompt = gr.TextArea( - value=TRANSPORTER.get_qa_prompt().text, interactive=True, label="Prompt" + with gr.Column(scale=1): + file_upload = gr.File(file_count="single", file_types=[".txt"]) + with gr.Column(scale=1): + type_radio = gr.Radio( + choices=["LLM", "VectorDB"], + value="LLM", + label="Suchmodus", + interactive=True, + ) + k_slider = gr.Slider( + minimum=1, + maximum=10, + step=1, + value=3, + interactive=True, + label="Quellenanzahl", ) - prompt_submit = gr.Button("Aktualisiere Prompt") + strategy_dropdown = gr.Dropdown( + choices=["similarity", "mmr"], + value="similarity", + interactive=True, + label="Suchmodus", + ) + prompt = gr.TextArea( + value=TRANSPORTER.get_qa_prompt().text, interactive=True, label="Prompt" + ) + prompt_submit = gr.Button("Aktualisiere Prompt") inp = gr.Textbox( label="Stellen Sie eine Frage:", placeholder="Wie heißt der Patient?" ) out = gr.Textbox(label="Antwort") file_upload.change(fn=upload, inputs=file_upload, outputs=out) btn = gr.Button("Frage stellen") - btn.click(fn=query, inputs=inp, outputs=out) - inp.submit(fn=query, inputs=inp, outputs=out) + btn.click( + fn=query, inputs=[inp, type_radio, k_slider, strategy_dropdown], outputs=out + ) + inp.submit( + fn=query, inputs=[inp, type_radio, k_slider, strategy_dropdown], outputs=out + ) prompt_submit.click(fn=set_prompt, inputs=prompt, outputs=out) if __name__ == "__main__": diff --git a/team_red/qa/qa_service.py b/team_red/qa/qa_service.py index 4a2733b..868432b 100644 --- a/team_red/qa/qa_service.py +++ b/team_red/qa/qa_service.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Protocol from langchain.chains.retrieval_qa.base import BaseRetrievalQA +from langchain.docstore.document import Document from langchain.document_loaders import PyPDFLoader, TextLoader from langchain.embeddings import HuggingFaceEmbeddings from langchain.prompts import PromptTemplate @@ -25,7 +26,6 @@ from team_red.utils import build_retrieval_qa if TYPE_CHECKING: - from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader @@ -46,6 +46,9 @@ def add_texts( ) -> List[str]: pass + def search(self, query: str, search_type: str, k: int) -> List[Document]: + pass + class QAService: def __init__(self, config: QAConfig) -> None: @@ -69,6 +72,22 @@ def __init__(self, config: QAConfig) -> None: config.embedding.db_path, self._embeddings ) + def db_query(self, question: QAQuestion) -> List[DocumentSource]: + if not self._vectorstore: + return [] + return [ + DocumentSource( + content=doc.page_content, + name=doc.metadata.get("source", "unknown"), + page=doc.metadata.get("page", 1), + ) + for doc in self._vectorstore.search( + question.question, + search_type=question.search_strategy, + k=question.max_sources, + ) + ] + def query(self, question: QAQuestion) -> QAAnswer: if not self._database: if not self._vectorstore: diff --git a/team_red/transport.py b/team_red/transport.py index d454eaf..a698d06 100644 --- a/team_red/transport.py +++ b/team_red/transport.py @@ -14,6 +14,8 @@ class GenResponse(BaseModel): class QAQuestion(BaseModel): question: str + search_strategy: str = "similarity" + max_sources: int = 3 class DocumentSource(BaseModel): @@ -43,6 +45,9 @@ class Transport(Protocol): def qa_query(self, query: QAQuestion) -> QAAnswer: pass + def db_query(self, question: QAQuestion) -> List[DocumentSource]: + pass + def add_file(self, file: QAFileUpload) -> QAAnswer: pass diff --git a/tests/qa/test_qa_service.py b/tests/qa/test_qa_service.py index e7e36b7..7f0f4fb 100644 --- a/tests/qa/test_qa_service.py +++ b/tests/qa/test_qa_service.py @@ -36,3 +36,15 @@ def test_query(qa_service_cajal: QAService) -> None: res = qa_service_cajal.query(QAQuestion(question="Wer ist der Patient?")) assert res.status == 200 assert res.answer + + +def test_db_query(qa_service_cajal: QAService) -> None: + q = QAQuestion(question="Wer ist der Patient?") + res = qa_service_cajal.db_query(q) + assert len(res) == q.max_sources + assert res[0].name == "Cajal.txt" + assert len(res[0].content) <= CONFIG.qa.embedding.chunk_size + q = QAQuestion(question="Wie heißt das Krankenhaus", max_sources=1) + res = qa_service_cajal.db_query(q) + assert len(res) == q.max_sources + assert "Diakonissenkrankenhaus Berlin" in res[0].content