diff --git a/docs/docs/examples/node_parsers/topic_parser.ipynb b/docs/docs/examples/node_parsers/topic_parser.ipynb new file mode 100644 index 0000000000000..32eb8d19adbd8 --- /dev/null +++ b/docs/docs/examples/node_parsers/topic_parser.ipynb @@ -0,0 +1,394 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "dd006f66", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "id": "d617ade9-796f-431f-86ff-6b865e0eb007", + "metadata": {}, + "source": [ + "# TopicNodeParser\n", + "\n", + "[MedGraphRAG](https://arxiv.org/html/2408.04187) aims to improve the capabilities of LLMs in the medical domain by generating evidence-based results through a novel graph-based Retrieval-Augmented Generation framework, improving safety and reliability in handling private medical data.\n", + "\n", + "`TopicNodeParser` implements an approximate version of the chunking technique described in the paper.\n", + "\n", + "Here is the technique as outlined in the paper:\n", + "\n", + "```\n", + "Large medical documents often contain multiple themes or diverse content. To process these effectively, we first segment the document into data chunks that conform to the context limitations of Large Language Models (LLMs). Traditional methods such as chunking based on token size or fixed characters typically fail to detect subtle shifts in topics accurately. Consequently, these chunks may not fully capture the intended context, leading to a loss in the richness of meaning.\n", + "\n", + "To enhance accuracy, we adopt a mixed method of character separation coupled with topic-based segmentation. Specifically, we utilize static characters (line break symbols) to isolate individual paragraphs within the document. Following this, we apply a derived form of the text for semantic chunking. Our approach includes the use of proposition transfer, which extracts standalone statements from a raw text Chen et al. (2023). Through proposition transfer, each paragraph is transformed into self-sustaining statements. We then conduct a sequential analysis of the document to assess each proposition, deciding whether it should merge with an existing chunk or initiate a new one. This decision is made via a zero-shot approach by an LLM. To reduce noise generated by sequential processing, we implement a sliding window technique, managing five paragraphs at a time. We continuously adjust the window by removing the first paragraph and adding the next, maintaining focus on topic consistency. We set a hard threshold that the longest chunk cannot excess the context length limitation of LLM. After chunking the document, we construct graph on each individual of the data chunk.\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d1c5118", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install llama-index llama-index-node-parser-topic" + ] + }, + { + "cell_type": "markdown", + "id": "12dcc784-f2c6-4c37-8771-57a921ff2eab", + "metadata": {}, + "source": [ + "## Setup Data\n", + "\n", + "Here we consider a sample text.\n", + "\n", + "Note: The propositions were created by an LLM, which might lead to longer processing times when creating nodes. Exercise caution while experimenting." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7fdcd874", + "metadata": {}, + "outputs": [], + "source": [ + "text = \"\"\"In this paper, we introduce a novel graph RAG method for applying LLMs to the medical domain, which we refer to as Medical Graph RAG (MedRAG). 
This technique improves LLM performance in the medical domain by response queries with grounded source citations and clear interpretations of medical terminology, boosting the transparency and interpretability of the results. This approach involves a three-tier hierarchical graph construction method. Initially, we use documents provided by users as our top-level source to extract entities. These entities are then linked to a second level consisting of more basic entities previously abstracted from credible medical books and papers. Subsequently, these entities are connected to a third level—the fundamental medical dictionary graph—that provides detailed explanations of each medical term and their semantic relationships. We then construct a comprehensive graph at the highest level by linking entities based on their content and hierarchical connections. This method ensures that the knowledge can be traced back to its sources and the results are factually accurate.\n", + "\n", + "To respond to user queries, we implement a U-retrieve strategy that combines top-down retrieval with bottom-up response generation. The process begins by structuring the query using predefined medical tags and indexing them through the graphs in a top-down manner. The system then generates responses based on these queries, pulling from meta-graphs—nodes retrieved along with their TopK related nodes and relationships—and summarizing the information into a detailed response. This technique maintains a balance between global context awareness and the contextual limitations inherent in LLMs.\n", + "\n", + "Our medical graph RAG provides Intrinsic source citation can enhance LLM transparency, interpretability, and verifiability. The results provides the provenance, or source grounding information, as it generates each response, and demonstrates that an answer is grounded in the dataset. Having the cited source for each assertion readily available also enables a human user to quickly and accurately audit the LLM’s output directly against the original source material. It is super useful in the field of medicine that security is very important, and each of the reasoning should be evidence-based. By using such a method, we construct an evidence-based Medical LLM that the clinician could easiely check the source of the reasoning and calibrate the model response to ensure the safty usage of llm in the clinical senarios.\n", + "\n", + "To evaluate our medical graph RAG, we implemented the method on several popular open and closed-source LLMs, including ChatGPT OpenAI (2023a) and LLaMA Touvron et al. (2023), testing them across mainstream medical Q&A benchmarks such as PubMedQA Jin et al. (2019), MedMCQA Pal et al. (2022), and USMLE Kung et al. (2023). For the RAG process, we supplied a comprehensive medical dictionary as the foundational knowledge layer, the UMLS medical knowledge graph Lindberg et al. (1993) as the foundamental layer detailing semantic relationships, and a curated MedC-K dataset Wu et al. (2023) —comprising the latest medical papers and books—as the intermediate level of data to simulate user-provided private data. Our experiments demonstrate that our model significantly enhances the performance of general-purpose LLMs on medical questions. 
Remarkably, it even surpasses many fine-tuned or specially trained LLMs on medical corpora, solely using the RAG approach without additional training.\n", + "\"\"\"\n", + "\n", + "from llama_index.core import Document\n", + "\n", + "documents = [Document(text=text)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "717bd52c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In this paper, we introduce a novel graph RAG method for applying LLMs to the medical domain, which we refer to as Medical Graph RAG (MedRAG). This technique improves LLM performance in the medical domain by response queries with grounded source citations and clear interpretations of medical terminology, boosting the transparency and interpretability of the results. This approach involves a three-tier hierarchical graph construction method. Initially, we use documents provided by users as our top-level source to extract entities. These entities are then linked to a second level consisting of more basic entities previously abstracted from credible medical books and papers. Subsequently, these entities are connected to a third level—the fundamental medical dictionary graph—that provides detailed explanations of each medical term and their semantic relationships. We then construct a comprehensive graph at the highest level by linking entities based on their content and hierarchical connections. This method ensures that the knowledge can be traced back to its sources and the results are factually accurate.\n", + "\n", + "To respond to user queries, we implement a U-retrieve strategy that combines top-down retrieval with bottom-up response generation. The process begins by structuring the query using predefined medical tags and indexing them through the graphs in a top-down manner. The system then generates responses based on these queries, pulling from meta-graphs—nodes retrieved along with their TopK related nodes and relationships—and summarizing the information into a detailed response. This technique maintains a balance between global context awareness and the contextual limitations inherent in LLMs.\n", + "\n", + "Our medical graph RAG provides Intrinsic source citation can enhance LLM transparency, interpretability, and verifiability. The results provides the provenance, or source grounding information, as it generates each response, and demonstrates that an answer is grounded in the dataset. Having the cited source for each assertion readily available also enables a human user to quickly and accurately audit the LLM’s output directly against the original source material. It is super useful in the field of medicine that security is very important, and each of the reasoning should be evidence-based. By using such a method, we construct an evidence-based Medical LLM that the clinician could easiely check the source of the reasoning and calibrate the model response to ensure the safty usage of llm in the clinical senarios.\n", + "\n", + "To evaluate our medical graph RAG, we implemented the method on several popular open and closed-source LLMs, including ChatGPT OpenAI (2023a) and LLaMA Touvron et al. (2023), testing them across mainstream medical Q&A benchmarks such as PubMedQA Jin et al. (2019), MedMCQA Pal et al. (2022), and USMLE Kung et al. (2023). For the RAG process, we supplied a comprehensive medical dictionary as the foundational knowledge layer, the UMLS medical knowledge graph Lindberg et al. 
(1993) as the foundamental layer detailing semantic relationships, and a curated MedC-K dataset Wu et al. (2023) —comprising the latest medical papers and books—as the intermediate level of data to simulate user-provided private data. Our experiments demonstrate that our model significantly enhances the performance of general-purpose LLMs on medical questions. Remarkably, it even surpasses many fine-tuned or specially trained LLMs on medical corpora, solely using the RAG approach without additional training.\n", + "\n" + ] + } + ], + "source": [ + "print(documents[0].get_content())" + ] + }, + { + "cell_type": "markdown", + "id": "3b7ac8b7", + "metadata": {}, + "source": [ + "## Setup LLM And Embedding Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "886c7682", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\" # Replace with your OpenAI API key" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b082912", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.embeddings.openai import OpenAIEmbedding\n", + "from llama_index.llms.openai import OpenAI\n", + "\n", + "embed_model = OpenAIEmbedding()\n", + "llm = OpenAI(model=\"gpt-4o-mini\")" + ] + }, + { + "cell_type": "markdown", + "id": "dd21470a-a6b4-43aa-94d6-503860706404", + "metadata": {}, + "source": [ + "## Define TopicNodeParser" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b204c43c-f98a-47fb-b84c-1ed5e07c7f4a", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.node_parser.topic import TopicNodeParser" + ] + }, + { + "cell_type": "markdown", + "id": "7b4959ff", + "metadata": {}, + "source": [ + "### LLM based topic similarity." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf398967-74b0-4fe1-a6ee-2da246d33757", + "metadata": {}, + "outputs": [], + "source": [ + "node_parser = TopicNodeParser.from_defaults(\n", + " llm=llm,\n", + " max_chunk_size=1000,\n", + " similarity_method=\"llm\", # can be \"llm\" or \"embedding\"\n", + " window_size=2, # paper suggests window_size=5\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9cc498a-9d75-4a87-b8a8-fcc995872a4b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "efb00dbe0b894c97bd5d33b86c2d6d45", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Parsing nodes: 0%| | 0/1 [00:00 str: + return "TopicNodeParser" + + @classmethod + def from_defaults( + cls, + callback_manager: Optional[CallbackManager] = None, + id_func: Optional[Callable[[int, Document], str]] = None, + tokenizer: Optional[Callable] = None, + max_chunk_size: int = 1000, + window_size: int = 5, + llm: Optional[LLM] = None, + embed_model: Optional[BaseEmbedding] = None, + similarity_method: str = "llm", + similarity_threshold: float = 0.8, + ) -> "TopicNodeParser": + """Initialize with parameters.""" + from llama_index.core import Settings + + callback_manager = callback_manager or CallbackManager([]) + id_func = id_func or default_id_func + tokenizer = tokenizer or get_tokenizer() + llm = llm or Settings.llm + embed_model = embed_model or Settings.embed_model + + return cls( + callback_manager=callback_manager, + id_func=id_func, + tokenizer=tokenizer, + max_chunk_size=max_chunk_size, + window_size=window_size, + llm=llm, + embed_model=embed_model, + similarity_threshold=similarity_threshold, + similarity_method=similarity_method, + ) + + def _parse_nodes( + self, + nodes: Sequence[BaseNode], + show_progress: bool = False, + **kwargs: Any, + ) -> List[BaseNode]: + """Parse document into nodes.""" + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") + + for node in nodes_with_progress: + nodes = self.build_topic_based_nodes_from_documents([node]) + all_nodes.extend(nodes) + + return all_nodes + + def split_into_paragraphs(self, text: str) -> List[str]: + """Split the document into paragraphs based on line breaks.""" + return re.split(r"\n\s*\n", text) + + def proposition_transfer(self, paragraph: str) -> List[str]: + """ + Convert a paragraph into a list of self-sustaining statements using LLM. + """ + messages = [ + ChatMessage(role="system", content=PROPOSITION_SYSTEM_PROMPT), + ChatMessage(role="user", content=paragraph), + ] + + response = str(self.llm.chat(messages)) + + json_start = response.find("[") + json_end = response.rfind("]") + 1 + if json_start != -1 and json_end != -1: + json_content = response[json_start:json_end] + # Parse the JSON response + try: + return json.loads(json_content) + except json.JSONDecodeError: + print(f"Failed to parse JSON: {json_content}") + return [] + else: + print(f"No valid JSON found in the response: {response}") + return [] + + def is_same_topic_llm(self, current_chunk: List[str], new_proposition: str) -> bool: + """ + Use zero-shot classification with LLM to determine if the new proposition belongs to the same topic. 
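+
+        Returns True when the LLM's reply contains the phrase "same topic"
+        (case-insensitive), and False otherwise.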
+ """ + current_text = " ".join(current_chunk) + messages = [ + ChatMessage(role="system", content=TOPIC_CLASSIFICATION_SYSTEM_PROMPT), + ChatMessage( + role="user", + content=f"Text 1: {current_text}\n\nText 2: {new_proposition}", + ), + ] + + response = self.llm.chat(messages) + + return "same topic" in str(response).lower() + + def is_same_topic_embedding( + self, current_chunk: List[str], new_proposition: str + ) -> bool: + """ + Use embedding-based similarity to determine if the new proposition belongs to the same topic. + """ + current_text = " ".join(current_chunk) + current_text_embedding = self.embed_model.get_text_embedding(current_text) + new_proposition_embedding = self.embed_model.get_text_embedding(new_proposition) + + similarity_score = similarity(current_text_embedding, new_proposition_embedding) + return similarity_score > self.similarity_threshold + + def semantic_chunking(self, paragraphs: List[str]) -> List[str]: + """ + Perform semantic chunking on the given paragraphs. + max_chunk_size: It is based on hard threshold of 1000 characters. + As per paper the hard threshold that the longest chunk cannot excess the context length limitation of LLM. + Here we are using 1000 tokens as the threshold. + """ + chunks: List[str] = [] + current_chunk: List[str] = [] + current_chunk_size: int = 0 + half_window = self.window_size // 2 + # Cache for storing propositions + proposition_cache: Dict[int, List[str]] = {} + + for i in range(len(paragraphs)): + # Define the window range + start_idx = max(0, i - half_window) + end_idx = min(len(paragraphs), i + half_window + 1) + + # Generate and cache propositions for paragraphs in the window + window_propositions = [] + for j in range(start_idx, end_idx): + if j not in proposition_cache: + proposition_cache[j] = self.proposition_transfer(paragraphs[j]) + window_propositions.extend(proposition_cache[j]) + + for prop in window_propositions: + if current_chunk: + if self.similarity_method == "llm": + is_same_topic = self.is_same_topic_llm(current_chunk, prop) + elif self.similarity_method == "embedding": + is_same_topic = self.is_same_topic_embedding( + current_chunk, prop + ) + else: + raise ValueError( + "Invalid similarity method. Choose 'llm' or 'embedding'." 
+                        )
+                else:
+                    is_same_topic = True
+
+                if not current_chunk or (
+                    is_same_topic
+                    and current_chunk_size + len(self.tokenizer(prop))
+                    <= self.max_chunk_size
+                ):
+                    current_chunk.append(prop)
+                    # Track chunk size in tokens, consistent with the threshold check above
+                    current_chunk_size += len(self.tokenizer(prop))
+                else:
+                    chunks.append(" ".join(current_chunk))
+                    current_chunk = [prop]
+                    current_chunk_size = len(self.tokenizer(prop))
+
+            # If we've reached the max chunk size, start a new chunk
+            if current_chunk_size >= self.max_chunk_size:
+                chunks.append(" ".join(current_chunk))
+                current_chunk = []
+                current_chunk_size = 0
+
+        if current_chunk:
+            chunks.append(" ".join(current_chunk))
+
+        return chunks
+
+    def build_topic_based_nodes_from_documents(
+        self, documents: Sequence[Document]
+    ) -> List[BaseNode]:
+        """Build topic based nodes from documents."""
+        all_nodes: List[BaseNode] = []
+        for doc in documents:
+            paragraphs = self.split_into_paragraphs(doc.text)
+            chunks = self.semantic_chunking(paragraphs)
+            nodes = build_nodes_from_splits(
+                chunks,
+                doc,
+                id_func=self.id_func,
+            )
+            all_nodes.extend(nodes)
+
+        return all_nodes
diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-topic/pyproject.toml b/llama-index-integrations/node_parser/llama-index-node-parser-topic/pyproject.toml
new file mode 100644
index 0000000000000..838124f51060a
--- /dev/null
+++ b/llama-index-integrations/node_parser/llama-index-node-parser-topic/pyproject.toml
@@ -0,0 +1,66 @@
+[build-system]
+build-backend = "poetry.core.masonry.api"
+requires = ["poetry-core"]
+
+[tool.codespell]
+check-filenames = true
+check-hidden = true
+skip = "*.csv,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb"
+
+[tool.llamahub]
+contains_example = false
+import_path = "llama_index.node_parser.topic"
+
+[tool.llamahub.class_authors]
+TopicNodeParser = "llama-index"
+
+[tool.mypy]
+disallow_untyped_defs = true
+# Remove venv skip when integrated with pre-commit
+exclude = ["_static", "build", "examples", "notebooks", "venv"]
+explicit_package_bases = true
+ignore_missing_imports = true
+namespace_packages = true
+plugins = "pydantic.mypy"
+python_version = "3.8"
+
+[tool.poetry]
+authors = ["llama-index"]
+description = "llama-index node_parser topic node parser integration"
+exclude = ["**/BUILD"]
+license = "MIT"
+name = "llama-index-node-parser-topic"
+readme = "README.md"
+version = "0.1.0"
+
+[tool.poetry.dependencies]
+python = ">=3.8.1,<4.0"
+llama-index-core = "^0.11.0"
+
+[tool.poetry.group.dev.dependencies]
+ipython = "8.10.0"
+jupyter = "^1.0.0"
+mypy = "0.991"
+pre-commit = "3.2.0"
+pylint = "2.15.10"
+pytest = "7.2.1"
+pytest-mock = "3.11.1"
+ruff = "0.0.292"
+tree-sitter-languages = "^1.8.0"
+types-Deprecated = ">=0.1.0"
+types-PyYAML = "^6.0.12.12"
+types-protobuf = "^4.24.0.4"
+types-redis = "4.5.5.0"
+types-requests = "2.28.11.8"
+types-setuptools = "67.1.0.0"
+
+[tool.poetry.group.dev.dependencies.black]
+extras = ["jupyter"]
+version = "<=23.9.1,>=23.7.0"
+
+[tool.poetry.group.dev.dependencies.codespell]
+extras = ["toml"]
+version = ">=v2.2.6"
+
+[[tool.poetry.packages]]
+include = "llama_index/"
diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/BUILD b/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/BUILD
new file mode 100644
index 0000000000000..dabf212d7e716
--- /dev/null
+++ b/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/BUILD
@@ -0,0 +1 @@
+python_tests()
diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/__init__.py
b/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/test_node_parser_topic.py b/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/test_node_parser_topic.py
new file mode 100644
index 0000000000000..a9de114a70260
--- /dev/null
+++ b/llama-index-integrations/node_parser/llama-index-node-parser-topic/tests/test_node_parser_topic.py
@@ -0,0 +1,27 @@
+from llama_index.core import Document, MockEmbedding
+from llama_index.core.llms import MockLLM
+from llama_index.node_parser.topic import TopicNodeParser
+
+
+def test_llm_chunking():
+    llm = MockLLM()
+    embed_model = MockEmbedding(embed_dim=3)
+    node_parser = TopicNodeParser.from_defaults(
+        llm=llm, embed_model=embed_model, similarity_method="llm"
+    )
+
+    nodes = node_parser([Document(text="Hello world!"), Document(text="Hello world!")])
+    print(nodes)
+    assert len(nodes) == 4
+
+
+def test_embedding_chunking():
+    llm = MockLLM()
+    embed_model = MockEmbedding(embed_dim=3)
+    node_parser = TopicNodeParser.from_defaults(
+        llm=llm, embed_model=embed_model, similarity_method="embedding"
+    )
+
+    nodes = node_parser([Document(text="Hello world!"), Document(text="Hello world!")])
+    print(nodes)
+    assert len(nodes) == 4
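+
+
+def test_split_into_paragraphs():
+    # A small, deterministic sketch of an extra check (not part of the original
+    # suite): it exercises only the paragraph-splitting step, i.e. the regex
+    # split on blank lines, so no LLM or embedding calls are made. MockLLM and
+    # MockEmbedding are used purely as placeholders, as in the tests above.
+    node_parser = TopicNodeParser.from_defaults(
+        llm=MockLLM(), embed_model=MockEmbedding(embed_dim=3)
+    )
+
+    paragraphs = node_parser.split_into_paragraphs(
+        "First paragraph.\n\nSecond paragraph."
+    )
+    assert paragraphs == ["First paragraph.", "Second paragraph."]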