From d26b09b39b360a0d2b0f304035a11e1597ce3d52 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:05:38 +0100 Subject: [PATCH 1/9] add search_type and max_sources as QAQuestion parameters --- team_red/transport.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/team_red/transport.py b/team_red/transport.py index d454eaf..a05ea25 100644 --- a/team_red/transport.py +++ b/team_red/transport.py @@ -14,6 +14,8 @@ class GenResponse(BaseModel): class QAQuestion(BaseModel): question: str + search_type: str = "similarity" + max_sources: int = 3 class DocumentSource(BaseModel): From dda4de127b9d45d4fbe33c1f49ca436ff2821f78 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:06:10 +0100 Subject: [PATCH 2/9] implement db_query for qna --- team_red/backends/bridge.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/team_red/backends/bridge.py b/team_red/backends/bridge.py index 682cf2a..2ded033 100644 --- a/team_red/backends/bridge.py +++ b/team_red/backends/bridge.py @@ -1,11 +1,12 @@ import logging -from typing import Dict, Optional +from typing import Dict, List, Optional from team_red.config import CONFIG from team_red.gen import GenerationService from team_red.qa import QAService from ..transport import ( + DocumentSource, GenResponse, PromptConfig, QAAnswer, @@ -39,6 +40,9 @@ def gen(self) -> GenerationService: def qa_query(self, question: QAQuestion) -> QAAnswer: return self.qa.query(question) + def db_query(self, question: QAQuestion) -> List[DocumentSource]: + return self.qa.db_query(question) + def add_file(self, file: QAFileUpload) -> QAAnswer: return self.qa.add_file(file) From dd2d451228cd3d8eb738f8d53d59346d42a3acd7 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:06:36 +0100 Subject: [PATCH 3/9] implement db_query [qa] --- team_red/qa/qa_service.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/team_red/qa/qa_service.py 
b/team_red/qa/qa_service.py index 4a2733b..465f1b8 100644 --- a/team_red/qa/qa_service.py +++ b/team_red/qa/qa_service.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Protocol from langchain.chains.retrieval_qa.base import BaseRetrievalQA +from langchain.docstore.document import Document from langchain.document_loaders import PyPDFLoader, TextLoader from langchain.embeddings import HuggingFaceEmbeddings from langchain.prompts import PromptTemplate @@ -25,7 +26,6 @@ from team_red.utils import build_retrieval_qa if TYPE_CHECKING: - from langchain.docstore.document import Document from langchain.document_loaders.base import BaseLoader @@ -46,6 +46,9 @@ def add_texts( ) -> List[str]: pass + def search(self, query: str, search_type: str, k: int) -> List[Document]: + pass + class QAService: def __init__(self, config: QAConfig) -> None: @@ -69,6 +72,27 @@ def __init__(self, config: QAConfig) -> None: config.embedding.db_path, self._embeddings ) + def db_query(self, question: QAQuestion) -> List[DocumentSource]: + if not self._vectorstore: + return [] + return [ + DocumentSource( + content=doc.page_content, + name=doc.metadata.get("source", "unknown"), + page=doc.metadata.get("page", 1), + ) + for doc in self._vectorstore.search( + question.question, + search_type=question.search_type, + k=question.max_sources, + ) + ] + + @classmethod + def is_lc_serializable(cls) -> bool: + """Return whether this class is serializable.""" + return True + def query(self, question: QAQuestion) -> QAAnswer: if not self._database: if not self._vectorstore: From 6ae0d40e3977fac83fdbd00c66e5f7bdfcd1edb2 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:06:53 +0100 Subject: [PATCH 4/9] add test for db query --- tests/qa/test_qa_service.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/qa/test_qa_service.py b/tests/qa/test_qa_service.py index e7e36b7..7f0f4fb 100644 --- a/tests/qa/test_qa_service.py +++ 
b/tests/qa/test_qa_service.py @@ -36,3 +36,15 @@ def test_query(qa_service_cajal: QAService) -> None: res = qa_service_cajal.query(QAQuestion(question="Wer ist der Patient?")) assert res.status == 200 assert res.answer + + +def test_db_query(qa_service_cajal: QAService) -> None: + q = QAQuestion(question="Wer ist der Patient?") + res = qa_service_cajal.db_query(q) + assert len(res) == q.max_sources + assert res[0].name == "Cajal.txt" + assert len(res[0].content) <= CONFIG.qa.embedding.chunk_size + q = QAQuestion(question="Wie heißt das Krankenhaus", max_sources=1) + res = qa_service_cajal.db_query(q) + assert len(res) == q.max_sources + assert "Diakonissenkrankenhaus Berlin" in res[0].content From ae8f1666c8df6235e105310ccbaafb0ca4a56de2 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:49:58 +0100 Subject: [PATCH 5/9] rename qa field to search_strategy --- team_red/transport.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/team_red/transport.py b/team_red/transport.py index a05ea25..a698d06 100644 --- a/team_red/transport.py +++ b/team_red/transport.py @@ -14,7 +14,7 @@ class GenResponse(BaseModel): class QAQuestion(BaseModel): question: str - search_type: str = "similarity" + search_strategy: str = "similarity" max_sources: int = 3 @@ -45,6 +45,9 @@ class Transport(Protocol): def qa_query(self, query: QAQuestion) -> QAAnswer: pass + def db_query(self, question: QAQuestion) -> List[DocumentSource]: + pass + def add_file(self, file: QAFileUpload) -> QAAnswer: pass From 03be710f500ac253df044bd15545f578460274d8 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:50:24 +0100 Subject: [PATCH 6/9] extend qa prompt with search mode --- team_red/frontends/qa_frontend.py | 62 +++++++++++++++++++++++++------ 1 file changed, 50 insertions(+), 12 deletions(-) diff --git a/team_red/frontends/qa_frontend.py b/team_red/frontends/qa_frontend.py index ba7db5e..176e42e 100644 --- 
a/team_red/frontends/qa_frontend.py +++ b/team_red/frontends/qa_frontend.py @@ -13,12 +13,25 @@ logging.basicConfig(level=logging.DEBUG) -def query(question: str) -> str: - res = TRANSPORTER.qa_query(QAQuestion(question=question)) - if res.status != 200: - msg = f"Query was unsuccessful: {res.error_msg} (Error Code {res.status})" +def query(question: str, search_type: str, k_source: int, search_strategy: str) -> str: + q = QAQuestion( + question=question, search_strategy=search_strategy, max_sources=k_source + ) + if search_type == "LLM": + res = TRANSPORTER.qa_query(q) + if res.status != 200: + msg = f"Query was unsuccessful: {res.error_msg} (Error Code {res.status})" + raise gr.Error(msg) + return res.answer + res = TRANSPORTER.db_query(q) + if not res: + msg = f"Database query returned empty!" raise gr.Error(msg) - return res.answer + output = "" + for doc in res: + output += f"{doc.content}\n" + output += f"({doc.name} / {doc.page})\n----------\n\n" + return output def upload(file_path: str, progress: Optional[gr.Progress] = None) -> None: @@ -61,20 +74,45 @@ def set_prompt(prompt: str, progress: Optional[gr.Progress] = None) -> None: with demo: gr.Markdown("# Entlassbriefe QA") with gr.Row(): - file_upload = gr.File(file_count="single", file_types=[".txt"]) - with gr.Column(): - prompt = gr.TextArea( - value=TRANSPORTER.get_qa_prompt().text, interactive=True, label="Prompt" + with gr.Column(scale=1): + file_upload = gr.File(file_count="single", file_types=[".txt"]) + with gr.Column(scale=1): + type_radio = gr.Radio( + choices=["LLM", "VectorDB"], + value="LLM", + label="Suchmodus", + interactive=True, + ) + k_slider = gr.Slider( + minimum=1, + maximum=10, + step=1, + value=3, + interactive=True, + label="Quellenanzahl", + ) + strategy_dropdown = gr.Dropdown( + choices=["similarity", "mmr"], + value="similarity", + interactive=True, + label="Suchmodus", ) - prompt_submit = gr.Button("Aktualisiere Prompt") + prompt = gr.TextArea( + 
value=TRANSPORTER.get_qa_prompt().text, interactive=True, label="Prompt" + ) + prompt_submit = gr.Button("Aktualisiere Prompt") inp = gr.Textbox( label="Stellen Sie eine Frage:", placeholder="Wie heißt der Patient?" ) out = gr.Textbox(label="Antwort") file_upload.change(fn=upload, inputs=file_upload, outputs=out) btn = gr.Button("Frage stellen") - btn.click(fn=query, inputs=inp, outputs=out) - inp.submit(fn=query, inputs=inp, outputs=out) + btn.click( + fn=query, inputs=[inp, type_radio, k_slider, strategy_dropdown], outputs=out + ) + inp.submit( + fn=query, inputs=[inp, type_radio, k_slider, strategy_dropdown], outputs=out + ) prompt_submit.click(fn=set_prompt, inputs=prompt, outputs=out) if __name__ == "__main__": From b6704b338e036dd0dba78e899a6e733d0f43f64a Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:50:56 +0100 Subject: [PATCH 7/9] adjust qarequest param name in qa service --- team_red/qa/qa_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/team_red/qa/qa_service.py b/team_red/qa/qa_service.py index 465f1b8..21d0f45 100644 --- a/team_red/qa/qa_service.py +++ b/team_red/qa/qa_service.py @@ -83,7 +83,7 @@ def db_query(self, question: QAQuestion) -> List[DocumentSource]: ) for doc in self._vectorstore.search( question.question, - search_type=question.search_type, + search_type=question.search_strategy, k=question.max_sources, ) ] From ee1484ddb0c28ecaed2477a2c78f786bb4badb15 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 16:56:38 +0100 Subject: [PATCH 8/9] fix mypy issues, ignore langchain missing imports --- mypy.ini | 3 +++ team_red/frontends/qa_frontend.py | 17 ++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mypy.ini b/mypy.ini index a1c36cd..24bfd81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -17,6 +17,9 @@ ignore_missing_imports = True [mypy-gradio.*] ignore_missing_imports = True +[mypy-langchain.*] +ignore_missing_imports = True + ;; 
gradio is not PEP 561 compliant (no py.typed) yet [mypy-team_red.frontends.*] disallow_any_unimported = False \ No newline at end of file diff --git a/team_red/frontends/qa_frontend.py b/team_red/frontends/qa_frontend.py index 176e42e..b3bcfda 100644 --- a/team_red/frontends/qa_frontend.py +++ b/team_red/frontends/qa_frontend.py @@ -18,17 +18,20 @@ def query(question: str, search_type: str, k_source: int, search_strategy: str) question=question, search_strategy=search_strategy, max_sources=k_source ) if search_type == "LLM": - res = TRANSPORTER.qa_query(q) - if res.status != 200: - msg = f"Query was unsuccessful: {res.error_msg} (Error Code {res.status})" + qa_res = TRANSPORTER.qa_query(q) + if qa_res.status != 200: + msg = ( + f"Query was unsuccessful: {qa_res.error_msg}" + f" (Error Code {qa_res.status})" + ) raise gr.Error(msg) - return res.answer - res = TRANSPORTER.db_query(q) - if not res: + return qa_res.answer + db_res = TRANSPORTER.db_query(q) + if not db_res: msg = f"Database query returned empty!" raise gr.Error(msg) output = "" - for doc in res: + for doc in db_res: output += f"{doc.content}\n" output += f"({doc.name} / {doc.page})\n----------\n\n" return output From f41896f33ee73639248cf36a6a2f813709a42b37 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 20 Dec 2023 17:00:47 +0100 Subject: [PATCH 9/9] remove is_lc_serializable got no idea where this came from... --- team_red/qa/qa_service.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/team_red/qa/qa_service.py b/team_red/qa/qa_service.py index 21d0f45..868432b 100644 --- a/team_red/qa/qa_service.py +++ b/team_red/qa/qa_service.py @@ -88,11 +88,6 @@ def db_query(self, question: QAQuestion) -> List[DocumentSource]: ) ] - @classmethod - def is_lc_serializable(cls) -> bool: - """Return whether this class is serializable.""" - return True - def query(self, question: QAQuestion) -> QAAnswer: if not self._database: if not self._vectorstore: