Skip to content

Commit

Permalink
added node-postprocessors to retriever_tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Izuki Matsuba committed Mar 29, 2024
1 parent 0cfb2c9 commit b9ba4e5
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
28 changes: 25 additions & 3 deletions llama-index-core/llama_index/core/tools/retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Retriever tool."""

from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, List, Optional


from llama_index.core.base.base_retriever import BaseRetriever

if TYPE_CHECKING:
from llama_index.core.langchain_helpers.agents.tools import LlamaIndexTool
from llama_index.core.schema import MetadataMode
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput
from llama_index.core.postprocessor.types import BaseNodePostprocessor

DEFAULT_NAME = "retriever_tool"
DEFAULT_DESCRIPTION = """Useful for running a natural language query
Expand All @@ -23,28 +25,37 @@ class RetrieverTool(AsyncBaseTool):
Args:
retriever (BaseRetriever): A retriever.
metadata (ToolMetadata): The associated metadata of the query engine.
node_postprocessors (Optional[List[BaseNodePostprocessor]]): A list of
node postprocessors.
"""

def __init__(
    self,
    retriever: BaseRetriever,
    metadata: ToolMetadata,
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
) -> None:
    """Store the retriever, the tool metadata and the optional postprocessors.

    Args:
        retriever (BaseRetriever): Retriever kept on ``self._retriever``.
        metadata (ToolMetadata): Metadata kept on ``self._metadata``.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Postprocessors kept on ``self._node_postprocessors``. ``None``
            (or an empty list) is stored as a new empty list.
    """
    if node_postprocessors:
        postprocessors = node_postprocessors
    else:
        # Normalise "nothing supplied" so later loops never see None.
        postprocessors = []

    self._node_postprocessors = postprocessors
    self._metadata = metadata
    self._retriever = retriever

@classmethod
def from_defaults(
    cls,
    retriever: BaseRetriever,
    node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> "RetrieverTool":
    """Build a tool from a retriever, filling in default metadata.

    Args:
        retriever (BaseRetriever): The retriever the tool wraps.
        node_postprocessors (Optional[List[BaseNodePostprocessor]]):
            Postprocessors forwarded unchanged to the constructor.
        name (Optional[str]): Tool name; ``DEFAULT_NAME`` is used when the
            value is ``None`` or empty.
        description (Optional[str]): Tool description;
            ``DEFAULT_DESCRIPTION`` is used when the value is ``None`` or
            empty.

    Returns:
        RetrieverTool: A new instance of ``cls``.
    """
    name = name or DEFAULT_NAME
    description = description or DEFAULT_DESCRIPTION

    metadata = ToolMetadata(name=name, description=description)
    # Single return path: the postprocessors must reach the constructor,
    # otherwise the tool would silently ignore them.
    return cls(
        retriever=retriever,
        metadata=metadata,
        node_postprocessors=node_postprocessors,
    )

@property
def retriever(self) -> BaseRetriever:
Expand All @@ -66,6 +77,7 @@ def call(self, *args: Any, **kwargs: Any) -> ToolOutput:
raise ValueError("Cannot call query engine without inputs")

docs = self._retriever.retrieve(query_str)
docs = self._apply_node_postprocessors(docs, query_str)
content = ""
for doc in docs:
node_copy = doc.node.copy()
Expand All @@ -91,6 +103,7 @@ async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:
raise ValueError("Cannot call query engine without inputs")
docs = await self._retriever.aretrieve(query_str)
content = ""
docs = self._apply_node_postprocessors(docs, query_str)
for doc in docs:
node_copy = doc.node.copy()
node_copy.text_template = "{metadata_str}\n{content}"
Expand All @@ -105,3 +118,12 @@ async def acall(self, *args: Any, **kwargs: Any) -> ToolOutput:

def as_langchain_tool(self) -> "LlamaIndexTool":
    """Unsupported conversion: this method always raises.

    Raises:
        NotImplementedError: On every call.
    """
    message = "`as_langchain_tool` not implemented here."
    raise NotImplementedError(message)

def _apply_node_postprocessors(
    self, nodes: List[NodeWithScore], query_bundle: QueryBundle
) -> List[NodeWithScore]:
    """Run every configured node postprocessor over ``nodes`` in order.

    Args:
        nodes (List[NodeWithScore]): Retrieved nodes to postprocess.
        query_bundle (QueryBundle): The query the nodes were retrieved for.
            A plain query string is also accepted and is wrapped in a
            ``QueryBundle`` before being handed to the postprocessors.

    Returns:
        List[NodeWithScore]: The output of the last postprocessor, or
        ``nodes`` unchanged when no postprocessors are configured.
    """
    # `call` / `acall` hand over `query_str` rather than a bundle; wrap it
    # so postprocessors always receive the type their keyword promises.
    if isinstance(query_bundle, str):
        query_bundle = QueryBundle(query_bundle)

    for node_postprocessor in self._node_postprocessors:
        nodes = node_postprocessor.postprocess_nodes(
            nodes, query_bundle=query_bundle
        )
    return nodes
53 changes: 53 additions & 0 deletions llama-index-core/tests/tools/test_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Test retriever tool."""
from typing import List, Optional

from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.schema import NodeWithScore, TextNode, QueryBundle
from llama_index.core.tools import RetrieverTool
from llama_index.core.postprocessor.types import BaseNodePostprocessor


class MockRetriever(BaseRetriever):
    """Custom retriever for testing.

    Every query yields exactly one node whose text is the query text
    prefixed with ``mock_`` and whose score is ``0.9``.
    """

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Mock retrieval.

        Reads the query text from the bundle explicitly instead of relying
        on the bundle's string formatting.
        """
        text = f"mock_{query_bundle.query_str}"
        return [NodeWithScore(node=TextNode(text=text), score=0.9)]

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Mock async retrieval; returns the same result as ``_retrieve``."""
        return self._retrieve(query_bundle)


class MockPostProcessor(BaseNodePostprocessor):
    """Postprocessor for testing that prefixes each node's text with ``processed_``."""

    @classmethod
    def class_name(cls) -> str:
        """Return the name of this class."""
        return "MockPostProcessor"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Prefix the text of every node in place and return the same list.

        ``query_bundle`` is accepted for interface compatibility and is not
        used.
        """
        prefix = "processed_"
        for n in nodes:
            n.node.text = prefix + n.node.text
        return nodes


def test_retriever_tool() -> None:
    """Exercise RetrieverTool with and without node postprocessors."""
    query = "hello world"
    mock_retriever = MockRetriever()

    # Plain retrieval: the tool output is the formatted node content and the
    # raw output exposes the retrieved nodes.
    plain_tool = RetrieverTool.from_defaults(retriever=mock_retriever)
    plain_output = plain_tool(query)
    assert str(plain_output) == "mock_hello world\n\n\n\n"
    assert plain_output.raw_output[0].node.text == "mock_hello world\n\n"

    # Retrieval followed by postprocessing: the mock postprocessor's prefix
    # must show up in the tool output.
    processed_tool = RetrieverTool.from_defaults(
        retriever=mock_retriever,
        node_postprocessors=[MockPostProcessor()],
    )
    processed_output = processed_tool(query)
    assert str(processed_output) == "processed_mock_hello world\n\n\n\n"

0 comments on commit b9ba4e5

Please sign in to comment.