diff --git a/examples/advanced_rag/dynamic_section_retrieval.ipynb b/examples/advanced_rag/dynamic_section_retrieval.ipynb
new file mode 100644
index 0000000..0728390
--- /dev/null
+++ b/examples/advanced_rag/dynamic_section_retrieval.ipynb
@@ -0,0 +1,1161 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "f4b2c37d-3b5a-47aa-95b9-d28e0bc83f77",
+   "metadata": {},
+   "source": [
+    "# Dynamic Section Retrieval with LlamaParse\n",
+    "\n",
+    "This notebook showcases a concept called \"dynamic section retrieval\".\n",
+    "\n",
+    "A common problem with naive RAG approaches is that each document is hierarchically organized by section, but standard chunking/retrieval returns chunks that may be only fragments of a section, missing relevant surrounding context.\n",
+    "\n",
+    "Dynamic section retrieval treats entire contiguous sections as the unit of retrieval, avoiding the problem of retrieving section fragments. \n",
+    "1. First, tag chunks of a long document with the sections they correspond to, through structured extraction.\n",
+    "2. Do two-pass retrieval. After initial semantic search, dynamically pull in the entire section through metadata filtering.\n",
+    "\n",
+    "![](dynamic_section_retrieval_img.png)\n",
+    "\n",
+    "As a result, the LLM sees the full section each retrieved chunk belongs to, rather than an arbitrary fragment of it."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e4f707a-c7b5-473f-b4a6-881e2245e82d",
+   "metadata": {},
+   "source": [
+    "## Setup\n",
+    "\n",
+    "Install core packages and download relevant files. Here we load some popular ICLR 2024 papers."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "71bd0714-324f-48b3-8a93-72c6c3a10b53",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import nest_asyncio\n",
+    "\n",
+    "nest_asyncio.apply()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9aa458bc-bc8d-46fe-9a57-021dd8d9e525",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index\n",
+    "!pip install llama-index-core\n",
+    "!pip install llama-index-vector-stores-chroma\n",
+    "!pip install llama-parse"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "79821400-caaf-42f1-99d8-74c184c19e29",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# NOTE: uncomment more papers if you want to do research over a larger subset of docs\n",
+    "\n",
+    "urls = [\n",
+    "    # \"https://openreview.net/pdf?id=VtmBAGCN7o\",\n",
+    "    # \"https://openreview.net/pdf?id=6PmJoRfdaK\",\n",
+    "    # \"https://openreview.net/pdf?id=LzPWWPAdY4\",\n",
+    "    \"https://openreview.net/pdf?id=VTF8yNQM66\",\n",
+    "    \"https://openreview.net/pdf?id=hSyW5go0v8\",\n",
+    "    # \"https://openreview.net/pdf?id=9WD9KwssyT\",\n",
+    "    # \"https://openreview.net/pdf?id=yV6fD7LYkF\",\n",
+    "    # \"https://openreview.net/pdf?id=hnrB5YHoYu\",\n",
+    "    # \"https://openreview.net/pdf?id=WbWtOYIzIK\",\n",
+    "    \"https://openreview.net/pdf?id=c5pwL0Soay\",\n",
+    "    # \"https://openreview.net/pdf?id=TpD2aG1h0D\",\n",
+    "]\n",
+    "\n",
+    "papers = [\n",
+    "    # \"metagpt.pdf\",\n",
+    "    # \"longlora.pdf\",\n",
+    "    # \"loftq.pdf\",\n",
+    "    \"swebench.pdf\",\n",
+    "    \"selfrag.pdf\",\n",
+    "    # \"zipformer.pdf\",\n",
+    "    # \"values.pdf\",\n",
+    "    # \"finetune_fair_diffusion.pdf\",\n",
+    "    # \"knowledge_card.pdf\",\n",
+    "    \"metra.pdf\",\n",
+    "    # \"vr_mcl.pdf\",\n",
+    "]\n",
+    "\n",
+    "data_dir = \"iclr_docs\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "80137d15-f22b-47eb-adce-ac295ced7e71",
+
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mkdir: iclr_docs: File exists\n", + "--2024-11-10 16:18:56-- https://openreview.net/pdf?id=VTF8yNQM66\n", + "Resolving openreview.net (openreview.net)... 35.184.86.251\n", + "Connecting to openreview.net (openreview.net)|35.184.86.251|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 2680380 (2.6M) [application/pdf]\n", + "Saving to: ‘iclr_docs/swebench.pdf’\n", + "\n", + "iclr_docs/swebench. 100%[===================>] 2.56M 7.22MB/s in 0.4s \n", + "\n", + "2024-11-10 16:18:57 (7.22 MB/s) - ‘iclr_docs/swebench.pdf’ saved [2680380/2680380]\n", + "\n", + "--2024-11-10 16:18:57-- https://openreview.net/pdf?id=hSyW5go0v8\n", + "Resolving openreview.net (openreview.net)... 35.184.86.251\n", + "Connecting to openreview.net (openreview.net)|35.184.86.251|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1244749 (1.2M) [application/pdf]\n", + "Saving to: ‘iclr_docs/selfrag.pdf’\n", + "\n", + "iclr_docs/selfrag.p 100%[===================>] 1.19M 4.21MB/s in 0.3s \n", + "\n", + "2024-11-10 16:18:58 (4.21 MB/s) - ‘iclr_docs/selfrag.pdf’ saved [1244749/1244749]\n", + "\n", + "--2024-11-10 16:18:58-- https://openreview.net/pdf?id=c5pwL0Soay\n", + "Resolving openreview.net (openreview.net)... 35.184.86.251\n", + "Connecting to openreview.net (openreview.net)|35.184.86.251|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 4775879 (4.6M) [application/pdf]\n", + "Saving to: ‘iclr_docs/metra.pdf’\n", + "\n", + "iclr_docs/metra.pdf 100%[===================>] 4.55M 4.06MB/s in 1.1s \n", + "\n", + "2024-11-10 16:19:00 (4.06 MB/s) - ‘iclr_docs/metra.pdf’ saved [4775879/4775879]\n", + "\n" + ] + } + ], + "source": [ + "!mkdir \"{data_dir}\"\n", + "for url, paper in zip(urls, papers):\n", + " !wget \"{url}\" -O \"{data_dir}/{paper}\"" + ] + }, + { + "cell_type": "markdown", + "id": "974ce0a5-931a-4c1f-b8f3-af670c08eb0f", + "metadata": {}, + "source": [ + "#### Define LLM and Embedding Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75a05e99-56e2-4db9-baae-f9401100dcc3", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core import Settings\n", + "from llama_index.llms.openai import OpenAI\n", + "from llama_index.embeddings.openai import OpenAIEmbedding\n", + "\n", + "embed_model = OpenAIEmbedding(model=\"text-embedding-3-large\")\n", + "llm = OpenAI(model=\"gpt-4o\")\n", + "\n", + "Settings.embed_model = embed_model\n", + "Settings.llm = llm" + ] + }, + { + "cell_type": "markdown", + "id": "2f16859f-c69e-4edf-acb6-0a5a0784275a", + "metadata": {}, + "source": [ + "#### Parse Documents" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6cd2cc9-673f-4f53-81fb-cc990950d345", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_parse import LlamaParse\n", + "\n", + "parser = LlamaParse(result_type=\"markdown\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9d6f0e8-323e-4786-a4a8-e393441ecd61", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Started parsing the file under job_id 827f328d-b72e-4b70-8b4b-47dbba859d69\n", + "Started parsing the file under job_id d3104cd5-731e-4def-bdbc-889e8731989c\n", + "Started parsing the file under job_id 6046274e-e522-46af-9185-3c036e9c3ad6\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "\n", + 
"paper_dicts = {}\n", + "\n", + "for paper_path in papers:\n", + " paper_base = Path(paper_path).stem\n", + " full_paper_path = str(Path(data_dir) / paper_path)\n", + " md_json_objs = parser.get_json_result(full_paper_path)\n", + " json_dicts = md_json_objs[0][\"pages\"]\n", + " paper_dicts[paper_path] = {\n", + " \"paper_path\": full_paper_path,\n", + " \"json_dicts\": json_dicts,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "2d52878b-aabf-418e-a4c7-9903a77dd8c8", + "metadata": {}, + "source": [ + "#### Get Text Nodes\n", + "\n", + "Convert the dictionary above into TextNode objects that we can put into a vector store." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c24174-05ce-417f-8dd2-79c3f375db03", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.core.schema import TextNode\n", + "from typing import Optional" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e331dfe-a627-4e23-8c57-70ab1d9342e4", + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: these are utility functions to sort the dumped images by the page number\n", + "# (they are formatted like \"{uuid}-{page_num}.jpg\"\n", + "import re\n", + "\n", + "\n", + "def get_page_number(file_name):\n", + " match = re.search(r\"-page-(\\d+)\\.jpg$\", str(file_name))\n", + " if match:\n", + " return int(match.group(1))\n", + " return 0\n", + "\n", + "\n", + "def _get_sorted_image_files(image_dir):\n", + " \"\"\"Get image files sorted by page.\"\"\"\n", + " raw_files = [f for f in list(Path(image_dir).iterdir()) if f.is_file()]\n", + " sorted_files = sorted(raw_files, key=get_page_number)\n", + " return sorted_files" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "346fe5ef-171e-4a54-9084-7a7805103a13", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "from pathlib import Path\n", + "\n", + "\n", + "# attach image metadata to the text nodes\n", + "def get_text_nodes(json_dicts, paper_path):\n", + " \"\"\"Split docs into nodes, by separator.\"\"\"\n", + " nodes = []\n", + "\n", + " md_texts = [d[\"md\"] for d in json_dicts]\n", + "\n", + " for idx, md_text in enumerate(md_texts):\n", + " chunk_metadata = {\n", + " \"page_num\": idx + 1,\n", + " \"paper_path\": paper_path,\n", + " }\n", + " node = TextNode(\n", + " text=md_text,\n", + " metadata=chunk_metadata,\n", + " )\n", + " nodes.append(node)\n", + "\n", + " return nodes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f591669c-5a8e-491d-9cef-0b754abbf26f", + "metadata": {}, + "outputs": [], + "source": [ + "# this will combine all nodes from all papers into a single list\n", + "all_text_nodes = []\n", + "text_nodes_dict = {}\n", + "for paper_path, paper_dict in paper_dicts.items():\n", + " json_dicts = paper_dict[\"json_dicts\"]\n", + " text_nodes = get_text_nodes(json_dicts, paper_dict[\"paper_path\"])\n", + " all_text_nodes.extend(text_nodes)\n", + " text_nodes_dict[paper_path] = text_nodes" + ] + }, + { + "cell_type": "markdown", + "id": "3b25f253-3aa0-4689-be6e-d0c722b8b48c", + "metadata": {}, + "source": [ + "## Add Section Metadata\n", + "\n", + "The first step is to extract out a map of all sections from the text of each document. We create a workflow that extracts out if a section heading exists on each page, and merges it together into a combined list. 
We then run a reflection step that reviews the extracted sections and removes false positives.\n",
+    "\n",
+    "Once we have a map of all the sections and the page numbers they start at, we can add the appropriate section ID as metadata to each chunk."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d8fdb689-cc94-4da2-ba12-9267e8ee8623",
+   "metadata": {},
+   "source": [
+    "#### Define Section Schema to Extract Into\n",
+    "\n",
+    "Here we define the output schema for section metadata. Extracting this from every page gives us a full table of contents for each document."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "66358783-1d7f-489d-a85b-35bcb9620912",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from pydantic import BaseModel, Field\n",
+    "from typing import List, Optional\n",
+    "\n",
+    "\n",
+    "class SectionOutput(BaseModel):\n",
+    "    \"\"\"The metadata for a given section. Includes the section name, title, page that it starts on, and more.\"\"\"\n",
+    "\n",
+    "    section_name: str = Field(\n",
+    "        ..., description=\"The current section number (e.g. section_name='3.2')\"\n",
+    "    )\n",
+    "    section_title: str = Field(\n",
+    "        ...,\n",
+    "        description=\"The current section title associated with the number (e.g. section_title='Experimental Results')\",\n",
+    "    )\n",
+    "\n",
+    "    start_page_number: int = Field(..., description=\"The start page number.\")\n",
+    "    is_subsection: bool = Field(\n",
+    "        ...,\n",
+    "        description=\"True if it's a subsection (e.g. Section 3.2). False if it's not a subsection (e.g. Section 3)\",\n",
+    "    )\n",
+    "    description: Optional[str] = Field(\n",
+    "        None,\n",
+    "        description=\"The extracted line from the source text that indicates this is a relevant section.\",\n",
+    "    )\n",
+    "\n",
+    "    def get_section_id(self):\n",
+    "        \"\"\"Get section id.\"\"\"\n",
+    "        return f\"{self.section_name}: {self.section_title}\"\n",
+    "\n",
+    "\n",
+    "class SectionsOutput(BaseModel):\n",
+    "    \"\"\"A list of all sections.\"\"\"\n",
+    "\n",
+    "    sections: List[SectionOutput]\n",
+    "\n",
+    "\n",
+    "class ValidSections(BaseModel):\n",
+    "    \"\"\"A list of indexes, each corresponding to a valid section.\"\"\"\n",
+    "\n",
+    "    valid_indexes: List[int] = Field(\n",
+    "        ...,\n",
+    "        description=\"List of valid section indexes. Do NOT include sections to remove.\",\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bff90c77-f92e-4c5e-a441-70f81adb68fb",
+   "metadata": {},
+   "source": [
+    "#### Extract into Section Outputs\n",
+    "\n",
+    "Use LlamaIndex's structured output capabilities to iterate through each page and extract relevant section metadata. Note: some pages may contain no section metadata (no section begins on that page)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dcfcd3a6-4739-4624-a6ed-678e41119575",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms.openai import OpenAI\n",
+    "from llama_index.core.prompts import ChatPromptTemplate, ChatMessage\n",
+    "from llama_index.core.llms import LLM\n",
+    "from llama_index.core.async_utils import run_jobs, asyncio_run\n",
+    "import json\n",
+    "\n",
+    "\n",
+    "async def aget_sections(\n",
+    "    doc_text: str, llm: Optional[LLM] = None\n",
+    ") -> List[SectionOutput]:\n",
+    "    \"\"\"Get extracted sections from a provided text.\"\"\"\n",
+    "\n",
+    "    system_prompt = \"\"\"\\\n",
+    "    You are an AI document assistant tasked with extracting section metadata from a document text. 
\n", + " \n", + "- You should ONLY extract out metadata if the document text contains the beginning of a section.\n", + "- The metadata schema is listed below - you should extract out the section_name, section_title, start page number, description.\n", + "- A valid section MUST begin with a hashtag (#) and have a number (e.g. \"1 Introduction\" or \"Section 1 Introduction\"). \\\n", + "Note: Not all hashtag (#) lines are valid sections. \n", + "\n", + "- You can extract out multiple section metadata if there are multiple sections on the page. \n", + "- If there are no sections that begin in this document text, do NOT extract out any sections. \n", + "- A valid section MUST be clearly delineated in the document text. Do NOT extract out a section if it is mentioned, \\\n", + "but is not actually the start of a section in the document text.\n", + "- A Figure or Table does NOT count as a section.\n", + " \n", + " The user will give the document text below.\n", + " \n", + " \"\"\"\n", + " llm = llm or OpenAI(model=\"gpt-4o\")\n", + "\n", + " chat_template = ChatPromptTemplate(\n", + " [\n", + " ChatMessage.from_str(system_prompt, \"system\"),\n", + " ChatMessage.from_str(\"Document text: {doc_text}\", \"user\"),\n", + " ]\n", + " )\n", + " result = await llm.astructured_predict(\n", + " SectionsOutput, chat_template, doc_text=doc_text\n", + " )\n", + " return result.sections\n", + "\n", + "\n", + "async def arefine_sections(\n", + " sections: List[SectionOutput], llm: Optional[LLM] = None\n", + ") -> List[SectionOutput]:\n", + " \"\"\"Refine sections based on extracted text.\"\"\"\n", + "\n", + " system_prompt = \"\"\"\\\n", + " You are an AI review assistant tasked with reviewing and correcting another agent's work in extracting sections from a document.\n", + "\n", + " Below is the list of sections with indexes. The sections may be incorrect in the following manner:\n", + " - There may be false positive sections - some sections may be wrongly extracted - you can tell by the sequential order of the rest of the sections\n", + " - Some sections may be incorrectly marked as subsections and vice-versa\n", + " - You can use the description which contains extracted text from the source document to see if it actually qualifies as a section.\n", + "\n", + " Given this, return the list of indexes that are valid. 
Do NOT include the indexes to be removed.\n",
+    "    \n",
+    "    \"\"\"\n",
+    "    llm = llm or OpenAI(model=\"gpt-4o\")\n",
+    "\n",
+    "    chat_template = ChatPromptTemplate(\n",
+    "        [\n",
+    "            ChatMessage.from_str(system_prompt, \"system\"),\n",
+    "            ChatMessage.from_str(\"Sections in text:\\n\\n{sections}\", \"user\"),\n",
+    "        ]\n",
+    "    )\n",
+    "\n",
+    "    section_texts = \"\\n\".join(\n",
+    "        [f\"{idx}: {json.dumps(s.model_dump())}\" for idx, s in enumerate(sections)]\n",
+    "    )\n",
+    "\n",
+    "    result = await llm.astructured_predict(\n",
+    "        ValidSections, chat_template, sections=section_texts\n",
+    "    )\n",
+    "    valid_indexes = result.valid_indexes\n",
+    "\n",
+    "    new_sections = [s for idx, s in enumerate(sections) if idx in valid_indexes]\n",
+    "    return new_sections\n",
+    "\n",
+    "\n",
+    "async def acreate_sections(text_nodes_dict):\n",
+    "    sections_dict = {}\n",
+    "    for paper_path, text_nodes in text_nodes_dict.items():\n",
+    "        # extract candidate sections from every page concurrently, then prune false positives\n",
+    "        tasks = [aget_sections(n.get_content(metadata_mode=\"all\")) for n in text_nodes]\n",
+    "\n",
+    "        async_results = await run_jobs(tasks, workers=8, show_progress=True)\n",
+    "        all_sections = [s for r in async_results for s in r]\n",
+    "\n",
+    "        all_sections = await arefine_sections(all_sections)\n",
+    "        sections_dict[paper_path] = all_sections\n",
+    "    return sections_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6e360a5c-29bd-4d86-9a21-f46013bab39a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████████████████████████████████████████████████████████████████| 51/51 [00:11<00:00, 4.35it/s]\n",
+      "100%|██████████████████████████████████████████████████████████████████████| 30/30 [00:09<00:00, 3.05it/s]\n",
+      "100%|██████████████████████████████████████████████████████████████████████| 25/25 [00:07<00:00, 3.22it/s]\n"
+     ]
+    }
+   ],
+   "source": [
+    "sections_dict = asyncio_run(acreate_sections(text_nodes_dict))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d930f0e5-5295-46b0-b54b-e2da4fb25fe5",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[SectionOutput(section_name='1', section_title='INTRODUCTION', start_page_number=1, is_subsection=False, description='# 1 INTRODUCTION'),\n",
+       " SectionOutput(section_name='2', section_title='BENCHMARK CONSTRUCTION', start_page_number=2, is_subsection=False, description='# BENCHMARK CONSTRUCTION'),\n",
+       " SectionOutput(section_name='2.2', section_title='TASK FORMULATION', start_page_number=3, is_subsection=True, description='# 2.2 TASK FORMULATION'),\n",
+       " SectionOutput(section_name='2.3', section_title='FEATURES OF SWE-BENCH', start_page_number=3, is_subsection=True, description='# 2.3 FEATURES OF SWE-BENCH'),\n",
+       " SectionOutput(section_name='3', section_title='SWE-LLAMA: FINE-TUNING CODELLAMA FOR SWE-BENCH', start_page_number=3, is_subsection=False, description='# 3 SWE-LLAMA: FINE-TUNING CODELLAMA FOR SWE-BENCH'),\n",
+       " SectionOutput(section_name='4', section_title='EXPERIMENTAL SETUP', start_page_number=4, is_subsection=False, description='# 4 EXPERIMENTAL SETUP'),\n",
+       " SectionOutput(section_name='4.1', section_title='RETRIEVAL-BASED APPROACH', start_page_number=4, is_subsection=True, description='# 4.1 RETRIEVAL-BASED APPROACH'),\n",
+       " SectionOutput(section_name='4.2', section_title='INPUT FORMAT', start_page_number=5, is_subsection=True, description='# 4.2 INPUT FORMAT'),\n",
+       " SectionOutput(section_name='4.3', section_title='MODELS', 
start_page_number=5, is_subsection=True, description='# 4.3 MODELS'),\n", + " SectionOutput(section_name='5', section_title='RESULTS', start_page_number=5, is_subsection=False, description='# 5 RESULTS'),\n", + " SectionOutput(section_name='5.1', section_title='A QUALITATIVE ANALYSIS OF SWE-LLAMA GENERATIONS', start_page_number=8, is_subsection=True, description='# 5.1 A QUALITATIVE ANALYSIS OF SWE-LLAMA GENERATIONS'),\n", + " SectionOutput(section_name='6', section_title='RELATED WORK', start_page_number=8, is_subsection=False, description='# 6 RELATED WORK'),\n", + " SectionOutput(section_name='7', section_title='DISCUSSION', start_page_number=9, is_subsection=False, description='# 7 DISCUSSION'),\n", + " SectionOutput(section_name='8', section_title='ETHICS STATEMENT', start_page_number=10, is_subsection=False, description='# 8 ETHICS STATEMENT'),\n", + " SectionOutput(section_name='9', section_title='REPRODUCIBILITY STATEMENT', start_page_number=10, is_subsection=False, description='# 9 REPRODUCIBILITY STATEMENT'),\n", + " SectionOutput(section_name='10', section_title='ACKNOWLEDGEMENTS', start_page_number=10, is_subsection=False, description='# 10 ACKNOWLEDGEMENTS'),\n", + " SectionOutput(section_name='A', section_title='BENCHMARK DETAILS', start_page_number=15, is_subsection=False, description='# A BENCHMARK DETAILS'),\n", + " SectionOutput(section_name='A.1', section_title='HIGH LEVEL OVERVIEW', start_page_number=15, is_subsection=True, description='# A.1 HIGH LEVEL OVERVIEW'),\n", + " SectionOutput(section_name='A.2', section_title='CONSTRUCTION PROCESS', start_page_number=16, is_subsection=True, description='# A.2 CONSTRUCTION PROCESS'),\n", + " SectionOutput(section_name='A.3', section_title='Execution-Based Validation', start_page_number=18, is_subsection=True, description='# A.3 EXECUTION-BASED VALIDATION'),\n", + " SectionOutput(section_name='A.5', section_title='Evaluation Test Set Characterization', start_page_number=20, is_subsection=True, description='# A.5 EVALUATION TEST SET CHARACTERIZATION'),\n", + " SectionOutput(section_name='A.6', section_title='DEVELOPMENT SET CHARACTERIZATION', start_page_number=23, is_subsection=True, description='# A.6 DEVELOPMENT SET CHARACTERIZATION'),\n", + " SectionOutput(section_name='B', section_title='ADDITIONAL DETAILS ON TRAINING SWE-LLAMA', start_page_number=24, is_subsection=False, description='# B ADDITIONAL DETAILS ON TRAINING SWE-LLAMA'),\n", + " SectionOutput(section_name='B.1', section_title='TRAINING DETAILS', start_page_number=24, is_subsection=True, description='# B.1 TRAINING DETAILS'),\n", + " SectionOutput(section_name='D', section_title='ADDITIONAL EXPERIMENTAL DETAILS', start_page_number=28, is_subsection=False, description='# D ADDITIONAL EXPERIMENTAL DETAILS'),\n", + " SectionOutput(section_name='D.1', section_title='RETRIEVAL DETAILS', start_page_number=28, is_subsection=True, description='# D.1 RETRIEVAL DETAILS'),\n", + " SectionOutput(section_name='D.2', section_title='INFERENCE SETTINGS', start_page_number=29, is_subsection=True, description='# D.2 INFERENCE SETTINGS'),\n", + " SectionOutput(section_name='D.3', section_title='PROMPT TEMPLATE EXAMPLE', start_page_number=29, is_subsection=True, description='# D.3 PROMPT TEMPLATE EXAMPLE'),\n", + " SectionOutput(section_name='E', section_title='Societal Impact', start_page_number=31, is_subsection=False, description='# E SOCIETAL IMPACT'),\n", + " SectionOutput(section_name='F', section_title='In-Depth Analysis of SWE-Llama Generations', start_page_number=31, 
is_subsection=False, description='# F IN-DEPTH ANALYSIS OF SWE-LLAMA GENERATIONS')]"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sections_dict[\"swebench.pdf\"]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8c6f237b-df2a-4d9e-91bf-b0bbb88ef183",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# [Optional] SAVE\n",
+    "import pickle\n",
+    "\n",
+    "pickle.dump(sections_dict, open(\"sections_dict.pkl\", \"wb\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7497b614-250e-4f3e-8940-b361996a00b6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# [Optional] LOAD\n",
+    "sections_dict = pickle.load(open(\"sections_dict.pkl\", \"rb\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "28b01141-d9c1-424c-937a-8707867180b1",
+   "metadata": {},
+   "source": [
+    "#### Annotate each chunk with the section metadata\n",
+    "\n",
+    "In the section above we've extracted a TOC of all sections/subsections and their start pages. Given this we can do one forward pass through the chunks, annotating each with the section it corresponds to (i.e. the section/subsection with the highest start page that is less than or equal to the chunk's page number)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "38133abe-7800-424a-9259-71df6d154d31",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def annotate_chunks_with_sections(chunks, sections):\n",
+    "    main_sections = [s for s in sections if not s.is_subsection]\n",
+    "    # subsections include the main sections too (some sections have no subsections etc.)\n",
+    "    sub_sections = sections\n",
+    "\n",
+    "    main_section_idx, sub_section_idx = 0, 0\n",
+    "    for c in chunks:\n",
+    "        cur_page = c.metadata[\"page_num\"]\n",
+    "        # advance each pointer while the next section starts on or before this page\n",
+    "        while (\n",
+    "            main_section_idx + 1 < len(main_sections)\n",
+    "            and main_sections[main_section_idx + 1].start_page_number <= cur_page\n",
+    "        ):\n",
+    "            main_section_idx += 1\n",
+    "        while (\n",
+    "            sub_section_idx + 1 < len(sub_sections)\n",
+    "            and sub_sections[sub_section_idx + 1].start_page_number <= cur_page\n",
+    "        ):\n",
+    "            sub_section_idx += 1\n",
+    "\n",
+    "        cur_main_section = main_sections[main_section_idx]\n",
+    "        cur_sub_section = sub_sections[sub_section_idx]\n",
+    "\n",
+    "        c.metadata[\"section_id\"] = cur_main_section.get_section_id()\n",
+    "        c.metadata[\"sub_section_id\"] = cur_sub_section.get_section_id()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d125b0c0-acb0-4f56-9ef3-f06d452ae3cd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for paper_path, text_nodes in text_nodes_dict.items():\n",
+    "    sections = sections_dict[paper_path]\n",
+    "    annotate_chunks_with_sections(text_nodes, sections)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b1ab80d2-cac4-417d-aaac-7ea9dfed49f7",
+   "metadata": {},
+   "source": [
+    "You can choose to save these nodes if you'd like."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2272ae05-89f6-46a9-9b9f-915e15908128",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# SAVE\n",
+    "import pickle\n",
+    "\n",
+    "pickle.dump(text_nodes_dict, open(\"iclr_text_nodes.pkl\", \"wb\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8ebf0173-af45-4fae-aca4-2ceb266f8357",
+   "metadata": {},
+   "source": [
+    "**LOAD**: If you've already saved nodes, run the below cell to load from an existing file."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c1e5425b-4872-47b3-86f5-f6a068788a2b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# LOAD\n",
+    "import pickle\n",
+    "\n",
+    "text_nodes_dict = pickle.load(open(\"iclr_text_nodes.pkl\", \"rb\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "642e90f1-1d32-4925-b37d-2af8a0ca9712",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "all_text_nodes = []\n",
+    "for paper_path, text_nodes in text_nodes_dict.items():\n",
+    "    all_text_nodes.extend(text_nodes)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8d7b566b-5ec1-4e49-b4d4-e863af2aabc6",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "106"
+      ]
+     },
+     "execution_count": null,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "len(all_text_nodes)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d03a4de3-39ce-40c3-b37b-b6bbc597ddb1",
+   "metadata": {},
+   "source": [
+    "### Build Indexes\n",
+    "\n",
+    "Once the text nodes are ready, we feed them into our vector store index abstraction, backed here by a Chroma vector store (of course, you should definitely check out our 40+ vector store integrations!)\n",
+    "\n",
+    "The Chroma collection is persisted to local disk, so the index can be reloaded in later sessions without re-parsing or re-inserting the nodes."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "add64e3e-12df-4d5a-beba-b3018325e15b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.vector_stores.chroma import ChromaVectorStore\n",
+    "from llama_index.core import VectorStoreIndex\n",
+    "\n",
+    "persist_dir = \"storage_chroma\"\n",
+    "\n",
+    "# create (or reopen) a persistent Chroma collection\n",
+    "vector_store = ChromaVectorStore.from_params(\n",
+    "    collection_name=\"text_nodes\", persist_dir=persist_dir\n",
+    ")\n",
+    "index = VectorStoreIndex.from_vector_store(vector_store)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1e46583a-6c6b-4a5e-b78a-d06721ae7d1c",
+   "metadata": {},
+   "source": [
+    "**NOTE**: Don't run the block below if you've already inserted the nodes. Only run it the first time!"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9777f302-699a-4417-99b8-2be4e7cd60f5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "index.insert_nodes(all_text_nodes)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d46f14ff-45b1-41f4-84e2-a6e5d6637809",
+   "metadata": {},
+   "source": [
+    "## Setup Dynamic, Section-Level Retrieval\n",
+    "\n",
+    "We now set up a retriever that retrieves an entire contiguous section of a document, instead of a chunk of it. This is useful for preserving the entire context within a doc.\n",
+    "\n",
+    "- Step 1: Do chunk-level retrieval to find the relevant chunks.\n",
+    "- Step 2: For each chunk, identify the section that it corresponds to.\n",
+    "- Step 3: Do a second retrieval pass using metadata filters to fetch every chunk of that section, returned in page order as one contiguous span (see the sketch below).\n",
+    "- Step 4: Feed the contiguous sections into the LLM.\n",
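+    "\n",
+    "The heart of the second pass is a metadata-filtered fetch against the vector store. As a rough sketch (here `chunk` stands for one node returned by the first-pass retriever; the full retriever below also deduplicates overlapping sections and sorts nodes by page):\n",
+    "\n",
+    "```python\n",
+    "filters = MetadataFilters.from_dicts(\n",
+    "    [\n",
+    "        {\"key\": \"section_id\", \"value\": chunk.metadata[\"section_id\"], \"operator\": \"==\"},\n",
+    "        {\"key\": \"paper_path\", \"value\": chunk.metadata[\"paper_path\"], \"operator\": \"==\"},\n",
+    "    ],\n",
+    "    condition=FilterCondition.AND,\n",
+    ")\n",
+    "section_chunks = index.vector_store.get_nodes(node_ids=None, filters=filters)\n",
+    "```"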
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "652cb067-da39-42cb-a303-faa346f72e13",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.llms.openai import OpenAI\n",
+    "\n",
+    "llm = OpenAI(model=\"gpt-4o\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "253f0c57-f5b4-4dbd-a0a0-62a42bd5bbdc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chunk_retriever = index.as_retriever(similarity_top_k=3)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b0a564cb-bfdb-48a5-9d67-10390c3a6c28",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.core.vector_stores.types import (\n",
+    "    MetadataFilters,\n",
+    "    FilterCondition,\n",
+    ")\n",
+    "from llama_index.core.schema import NodeWithScore\n",
+    "\n",
+    "\n",
+    "def section_retrieve(query: str, verbose: bool = False) -> List[NodeWithScore]:\n",
+    "    \"\"\"Retrieve full sections for the top chunk-level matches.\"\"\"\n",
+    "    if verbose:\n",
+    "        print(\">> Identifying the right sections to retrieve\")\n",
+    "    chunk_nodes = chunk_retriever.retrieve(query)\n",
+    "\n",
+    "    all_section_nodes = {}\n",
+    "    for node in chunk_nodes:\n",
+    "        section_id = node.node.metadata[\"section_id\"]\n",
+    "        if verbose:\n",
+    "            print(f\">> Retrieving section: {section_id}\")\n",
+    "        filters = MetadataFilters.from_dicts(\n",
+    "            [\n",
+    "                {\"key\": \"section_id\", \"value\": section_id, \"operator\": \"==\"},\n",
+    "                {\n",
+    "                    \"key\": \"paper_path\",\n",
+    "                    \"value\": node.node.metadata[\"paper_path\"],\n",
+    "                    \"operator\": \"==\",\n",
+    "                },\n",
+    "            ],\n",
+    "            condition=FilterCondition.AND,\n",
+    "        )\n",
+    "\n",
+    "        # TODO: make node_ids not positional\n",
+    "        section_nodes_raw = index.vector_store.get_nodes(node_ids=None, filters=filters)\n",
+    "        section_nodes = [NodeWithScore(node=n) for n in section_nodes_raw]\n",
+    "        # order nodes by page; the dict below also dedupes nodes shared across results\n",
+    "        section_nodes_sorted = sorted(\n",
+    "            section_nodes, key=lambda x: x.metadata[\"page_num\"]\n",
+    "        )\n",
+    "\n",
+    "        all_section_nodes.update({n.id_: n for n in section_nodes_sorted})\n",
+    "    return list(all_section_nodes.values())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f721e770-ce4c-4511-96d5-8a89d16c7281",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">> Identifying the right sections to retrieve\n",
+      ">> Retrieving section: A: BENCHMARK DETAILS\n",
+      ">> Retrieving section: 2: BENCHMARK CONSTRUCTION\n",
+      ">> Retrieving section: A: BENCHMARK DETAILS\n"
+     ]
+    }
+   ],
+   "source": [
+    "nodes = section_retrieve(\n",
+    "    \"Give me a full overview of the benchmark details in SWE Bench\", verbose=True\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e99eaa71-7d93-40c0-bba0-a9c983a6cbd3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'page_num': 15, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.1: HIGH LEVEL OVERVIEW'}\n",
+      "{'page_num': 16, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.2: CONSTRUCTION PROCESS'}\n",
+      "{'page_num': 17, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.2: CONSTRUCTION PROCESS'}\n",
+      "{'page_num': 18, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.3: 
Execution-Based Validation'}\n",
+      "{'page_num': 19, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.3: Execution-Based Validation'}\n",
+      "{'page_num': 20, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.5: Evaluation Test Set Characterization'}\n",
+      "{'page_num': 21, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.5: Evaluation Test Set Characterization'}\n",
+      "{'page_num': 22, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.5: Evaluation Test Set Characterization'}\n",
+      "{'page_num': 23, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.6: DEVELOPMENT SET CHARACTERIZATION'}\n",
+      "{'page_num': 2, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': '2: BENCHMARK CONSTRUCTION', 'sub_section_id': '2: BENCHMARK CONSTRUCTION'}\n"
+     ]
+    }
+   ],
+   "source": [
+    "for n in nodes:\n",
+    "    print(n.node.metadata)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "509de5ae-4d51-4b39-b67e-698cb84acd73",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">> Identifying the right sections to retrieve\n",
+      ">> Retrieving section: F: ADDITIONAL RESULTS\n",
+      ">> Retrieving section: 5: EXPERIMENTS\n",
+      ">> Retrieving section: F: ADDITIONAL RESULTS\n"
+     ]
+    }
+   ],
+   "source": [
+    "nodes = section_retrieve(\n",
+    "    \"Give me details of all additional experimental results in the Metra paper\",\n",
+    "    verbose=True,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "db64a838-5f19-46e0-b874-859a125f8dcd",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'page_num': 21, 'paper_path': 'iclr_docs/metra.pdf', 'section_id': 'F: ADDITIONAL RESULTS', 'sub_section_id': 'F.1: FULL QUALITATIVE RESULTS'}\n",
+      "{'page_num': 22, 'paper_path': 'iclr_docs/metra.pdf', 'section_id': 'F: ADDITIONAL RESULTS', 'sub_section_id': 'F.4: Additional Baselines'}\n",
+      "{'page_num': 6, 'paper_path': 'iclr_docs/metra.pdf', 'section_id': '5: EXPERIMENTS', 'sub_section_id': '5: EXPERIMENTS'}\n",
+      "{'page_num': 7, 'paper_path': 'iclr_docs/metra.pdf', 'section_id': '5: EXPERIMENTS', 'sub_section_id': '5.2: QUALITATIVE COMPARISON'}\n",
+      "{'page_num': 8, 'paper_path': 'iclr_docs/metra.pdf', 'section_id': '5: EXPERIMENTS', 'sub_section_id': '5.3: QUANTITATIVE COMPARISON'}\n"
+     ]
+    }
+   ],
+   "source": [
+    "for n in nodes:\n",
+    "    print(n.node.metadata)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d67303e6-ec65-499b-85bb-8189d220b466",
+   "metadata": {},
+   "source": [
+    "### Try out Section-Level Retrieval as a Full RAG Pipeline\n",
+    "\n",
+    "Now that we've defined the retriever, we can plug the retrieved results into an LLM to create a full RAG pipeline!\n",
+    "\n",
+    "The response synthesizer takes care of packing the retrieved sections into the LLM prompt while respecting context-window limits.\n",
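+    "\n",
+    "We use `TreeSummarize` below, which summarizes hierarchically: it packs the retrieved nodes into prompt-sized batches, answers the query over each batch, and then combines those answers until a single response remains. As a rough standalone sketch (reusing the `llm` and `nodes` objects from this notebook):\n",
+    "\n",
+    "```python\n",
+    "from llama_index.core.response_synthesizers import TreeSummarize\n",
+    "\n",
+    "synth = TreeSummarize(llm=llm)\n",
+    "response = synth.synthesize(\"What is SWE-bench?\", nodes)\n",
+    "print(str(response))\n",
+    "```"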
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "bb382809-d38e-4f03-bf26-6e1bf0d98df6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llama_index.core.query_engine import CustomQueryEngine\n",
+    "from llama_index.core.response_synthesizers import TreeSummarize, BaseSynthesizer\n",
+    "\n",
+    "\n",
+    "class SectionRetrieverRAGEngine(CustomQueryEngine):\n",
+    "    \"\"\"RAG query engine that synthesizes over full retrieved sections.\"\"\"\n",
+    "\n",
+    "    synthesizer: BaseSynthesizer\n",
+    "    verbose: bool = True\n",
+    "\n",
+    "    def __init__(self, **kwargs):\n",
+    "        # forward kwargs (e.g. verbose=False) instead of silently dropping them\n",
+    "        super().__init__(synthesizer=TreeSummarize(llm=llm), **kwargs)\n",
+    "\n",
+    "    def custom_query(self, query_str: str):\n",
+    "        nodes = section_retrieve(query_str, verbose=self.verbose)\n",
+    "        response_obj = self.synthesizer.synthesize(query_str, nodes)\n",
+    "        return response_obj"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "426d9426-a145-4f50-ad37-4dd82b5c7ae8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine = SectionRetrieverRAGEngine()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d1ec3f98-7181-4850-8b37-1e0aa751bf54",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">> Identifying the right sections to retrieve\n",
+      ">> Retrieving section: A: BENCHMARK DETAILS\n",
+      ">> Retrieving section: 5: RESULTS\n",
+      ">> Retrieving section: A: BENCHMARK DETAILS\n",
+      "In SWEBench, difficulty correlates with context length in a way that as the total context length increases, model performance tends to drop. This is observed across various models, including Claude 2, which shows a significant decrease in performance with longer context lengths. The models often struggle to localize the problematic code that needs updating when presented with a lot of code that may not be directly related to the issue at hand. This suggests that models can become distracted by additional context, which aligns with findings from other studies indicating that models may be sensitive to the relative location of target sequences. Even when increasing the maximum context size improves recall with respect to the oracle files, performance still drops, indicating that models are ineffective at localizing the necessary code changes.\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = query_engine.query(\n",
+    "    \"Tell me more about how difficulty correlates with context length in SWEBench\"\n",
+    ")\n",
+    "print(str(response))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "483f5615-ab58-4bc7-968b-7a9e116756e1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ">> Identifying the right sections to retrieve\n",
+      ">> Retrieving section: A: BENCHMARK DETAILS\n",
+      ">> Retrieving section: 2: BENCHMARK CONSTRUCTION\n",
+      ">> Retrieving section: A: BENCHMARK DETAILS\n",
+      "SWE-bench is a benchmark designed to evaluate language models in a realistic software engineering setting by using GitHub issues and pull requests from popular repositories. The benchmark involves generating a pull request that addresses a given issue and passes related tests. The construction of SWE-bench involves a three-stage pipeline:\n",
+      "\n",
+      "1. **Repo Selection and Data Scraping**: Pull requests are collected from 12 popular open-source Python repositories on GitHub, resulting in approximately 90,000 PRs. 
These repositories are chosen for their better maintenance, clear contributor guidelines, and comprehensive test coverage.\n", + "\n", + "2. **Attribute-Based Filtering**: Candidate tasks are created by selecting merged PRs that resolve a GitHub issue and contribute tests. This indicates that the user likely added tests to verify the resolution of the issue.\n", + "\n", + "3. **Execution-Based Filtering**: For each candidate task, the PR's test content is applied, and test results are logged before and after applying the PR's other content. Tasks are filtered out if they do not have at least one test that changes from fail to pass or if they result in installation or runtime errors.\n", + "\n", + "The benchmark is designed to be extensible, allowing for updates with new task instances as new language models are released. It includes a robust framework for execution-based evaluation, ensuring that generated solutions can be verified by running unit tests. SWE-bench also provides a training dataset, SWE-bench-train, and fine-tuned models like SWE-Llama 7b and 13b, which are based on the CodeLlama model. These models are evaluated on their ability to resolve issues, with SWE-Llama 13b showing competitive performance in some settings.\n" + ] + } + ], + "source": [ + "response = query_engine.query(\n", + " \"Give me a full overview of the benchmark details in SWE Bench\"\n", + ")\n", + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d747bf8-0ed2-4c10-8108-9d0e8d53a4fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'page_num': 15, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.1: HIGH LEVEL OVERVIEW'}\n", + "{'page_num': 16, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.2: CONSTRUCTION PROCESS'}\n", + "{'page_num': 17, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.2: CONSTRUCTION PROCESS'}\n", + "{'page_num': 18, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.3: Execution-Based Validation'}\n", + "{'page_num': 19, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.3: Execution-Based Validation'}\n", + "{'page_num': 20, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.5: Evaluation Test Set Characterization'}\n", + "{'page_num': 21, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.5: Evaluation Test Set Characterization'}\n", + "{'page_num': 22, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.5: Evaluation Test Set Characterization'}\n", + "{'page_num': 23, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': 'A: BENCHMARK DETAILS', 'sub_section_id': 'A.6: DEVELOPMENT SET CHARACTERIZATION'}\n", + "{'page_num': 2, 'paper_path': 'iclr_docs/swebench.pdf', 'section_id': '2: BENCHMARK CONSTRUCTION', 'sub_section_id': '2: BENCHMARK CONSTRUCTION'}\n" + ] + } + ], + "source": [ + "for n in response.source_nodes:\n", + " print(n.metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62b11a23-df6a-4d83-b35c-691bb4d125c0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + ">> Identifying the right sections to retrieve\n", + ">> 
Retrieving section: F: ADDITIONAL RESULTS\n", + ">> Retrieving section: 5: EXPERIMENTS\n", + ">> Retrieving section: F: ADDITIONAL RESULTS\n", + "The additional experimental results in the METRA paper include several key findings:\n", + "\n", + "1. **Full Qualitative Results**: METRA discovers diverse locomotion behaviors across different environments, including state-based Ant and HalfCheetah, and pixel-based Quadruped and Humanoid. The results are consistent across multiple random seeds, indicating robustness in behavior discovery.\n", + "\n", + "2. **Latent Space Visualization**: METRA effectively captures the most temporally spread-out dimensions in the state space, such as x-y coordinates, in its latent space. This is demonstrated in both state-based and pixel-based environments, with higher-dimensional latent spaces capturing more diverse behaviors.\n", + "\n", + "3. **Ablation Study of Latent Space Sizes**: The study shows that increasing the size of the latent space generally enhances the diversity of skills learned by METRA. Different dimensions of continuous and discrete skills were tested on Ant and HalfCheetah.\n", + "\n", + "4. **Comparison with Additional Baselines**: METRA was compared with DGPO, a method focused on finding diverse behaviors that maximize task rewards. The comparison was conducted in a controlled Markov process setting without external rewards, using only intrinsic rewards.\n", + "\n", + "These results highlight METRA's ability to discover diverse and meaningful behaviors in various environments, its effective use of latent spaces, and its performance relative to other methods.\n" + ] + } + ], + "source": [ + "response = query_engine.query(\n", + " \"Give me details of all additional experimental results in the Metra paper\"\n", + ")\n", + "print(str(response))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llama_index_v3", + "language": "python", + "name": "llama_index_v3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/advanced_rag/dynamic_section_retrieval_img.png b/examples/advanced_rag/dynamic_section_retrieval_img.png new file mode 100644 index 0000000..9e662a6 Binary files /dev/null and b/examples/advanced_rag/dynamic_section_retrieval_img.png differ