Commit

add reranking
jpfcabral committed Jan 16, 2025
1 parent 784445c commit 99ed23d
Showing 5 changed files with 203 additions and 0 deletions.
2 changes: 2 additions & 0 deletions libs/aws/langchain_aws/__init__.py
@@ -14,6 +14,7 @@
    InMemorySemanticCache,
    InMemoryVectorStore,
)
from langchain_aws.rerank.rerank import BedrockRerank

__all__ = [
"BedrockEmbeddings",
@@ -29,4 +30,5 @@
"NeptuneGraph",
"InMemoryVectorStore",
"InMemorySemanticCache",
"BedrockRerank"
]
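
These two hunks re-export the new class from the package root, so downstream code can import it the same way the unit tests below do. A minimal illustration (not part of the commit):

from langchain_aws import BedrockRerank

# Region and profile fall back to AWS_DEFAULT_REGION / AWS_PROFILE (or us-west-2),
# per the field defaults defined in rerank.py below.
reranker = BedrockRerank()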
0 changes: 0 additions & 0 deletions libs/aws/langchain_aws/rerank/__init__.py
Empty file.
126 changes: 126 additions & 0 deletions libs/aws/langchain_aws/rerank/rerank.py
@@ -0,0 +1,126 @@
import json
from copy import deepcopy
from typing import Any, Dict, List, Optional, Sequence, Union

import boto3
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import from_env
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self


class BedrockRerank(BaseDocumentCompressor):
    """Document compressor that uses AWS Bedrock Rerank API."""

    client: Any = None
    """Bedrock client to use for compressing documents."""
    top_n: Optional[int] = 3
    """Number of documents to return."""
    model: Optional[str] = "amazon.rerank-v1:0"
    """Model to use for reranking. Default is amazon.rerank-v1:0."""
    aws_region: str = Field(
        default_factory=from_env("AWS_DEFAULT_REGION", default="us-west-2")
    )
    """AWS region to initialize the Bedrock client."""
    aws_profile: Optional[str] = Field(
        default_factory=from_env("AWS_PROFILE", default=None)
    )
    """AWS profile for authentication, optional."""

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
    )

    @model_validator(mode="after")
    def initialize_client(self) -> Self:
        """Initialize the AWS Bedrock client."""
        if not self.client:
            session = (
                boto3.Session(profile_name=self.aws_profile)
                if self.aws_profile
                else boto3.Session()
            )
            self.client = session.client("bedrock-runtime", region_name=self.aws_region)
        return self

    def rerank(
        self,
        documents: Sequence[Union[str, Document, dict]],
        query: str,
        *,
        top_n: Optional[int] = None,
        model: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Returns an ordered list of documents based on their relevance to the query.
        Args:
            query: The query to use for reranking.
            documents: A sequence of documents to rerank.
            top_n: The number of top-ranked results to return. Defaults to self.top_n.
            model: The model to use for reranking. Defaults to self.model.
        Returns:
            List[Dict[str, Any]]: A list of ranked documents with relevance scores.
        """
        if len(documents) == 0:
            return []

        # Serialize documents for the Bedrock API
        serialized_documents = [
            json.dumps(doc)
            if isinstance(doc, dict)
            else doc.page_content
            if isinstance(doc, Document)
            else doc
            for doc in documents
        ]

        body = json.dumps(
            {
                "query": query,
                "documents": serialized_documents,
                "top_n": top_n or self.top_n,
            }
        )

        response = self.client.invoke_model(
            modelId=model or self.model,
            accept="application/json",
            contentType="application/json",
            body=body,
        )

        response_body = json.loads(response.get("body").read())
        results = [
            {"index": result["index"], "relevance_score": result["relevance_score"]}
            for result in response_body["results"]
        ]

        return results

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using Bedrock's rerank API.
        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.
        Returns:
            A sequence of compressed documents.
        """
        compressed = []
        for res in self.rerank(documents, query):
            doc = documents[res["index"]]
            doc_copy = Document(doc.page_content, metadata=deepcopy(doc.metadata))
            doc_copy.metadata["relevance_score"] = res["relevance_score"]
            compressed.append(doc_copy)
        return compressed
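
For context, a short usage sketch of the new compressor (not part of the commit; the region, sample texts, and query are illustrative assumptions, and real calls require AWS credentials with access to the amazon.rerank-v1:0 model in that region):

from langchain_core.documents import Document

from langchain_aws import BedrockRerank

# Assumes Bedrock access in us-west-2 with the default amazon.rerank-v1:0 model enabled.
reranker = BedrockRerank(aws_region="us-west-2", top_n=2)

docs = [
    Document(page_content="Bedrock offers dedicated reranking models."),
    Document(page_content="A reranker reorders documents by relevance to a query."),
    Document(page_content="Unrelated text about cooking."),
]

# Low-level call: a list of {"index": ..., "relevance_score": ...} dicts from the model.
scores = reranker.rerank(docs, "What does a reranker do?")

# Compressor-style call: the top-ranked Documents with "relevance_score" added to metadata.
top_docs = reranker.compress_documents(docs, "What does a reranker do?")
for d in top_docs:
    print(d.metadata["relevance_score"], d.page_content)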
0 changes: 0 additions & 0 deletions libs/aws/tests/unit_tests/rerank/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions libs/aws/tests/unit_tests/rerank/test_rerank.py
@@ -0,0 +1,75 @@
import json
from unittest.mock import MagicMock

import pytest
from langchain_core.documents import Document

from langchain_aws import BedrockRerank


# Mock setup
@pytest.fixture
def mock_bedrock_client():
    mock_client = MagicMock()
    mock_client.invoke_model.return_value = {
        "body": MagicMock(
            read=MagicMock(
                return_value=json.dumps(
                    {
                        "results": [
                            {"index": 0, "relevance_score": 0.95},
                            {"index": 1, "relevance_score": 0.90},
                        ]
                    }
                )
            )
        )
    }
    return mock_client


@pytest.fixture
def bedrock_rerank(mock_bedrock_client):
    return BedrockRerank(client=mock_bedrock_client)


# Test initialize_client
def test_initialize_client_with_profile():
    bedrock_rerank = BedrockRerank(aws_profile="default")
    bedrock_rerank.initialize_client()
    assert bedrock_rerank.client is not None


def test_initialize_client_without_profile():
    bedrock_rerank = BedrockRerank()
    bedrock_rerank.initialize_client()
    assert bedrock_rerank.client is not None


# Test rerank method
def test_rerank_success(bedrock_rerank):
    documents = ["doc1", "doc2", "doc3"]
    query = "Test query"
    results = bedrock_rerank.rerank(documents, query)
    assert len(results) == 2
    assert results[0]["index"] == 0
    assert results[0]["relevance_score"] == 0.95


def test_rerank_empty_documents(bedrock_rerank):
    results = bedrock_rerank.rerank([], "query")
    assert results == []


# Test compress_documents method
def test_compress_documents(bedrock_rerank):
    documents = [
        Document(page_content="doc1"),
        Document(page_content="doc2"),
        Document(page_content="doc3"),
    ]
    query = "Test query"
    compressed = bedrock_rerank.compress_documents(documents, query)
    assert len(compressed) == 2
    assert compressed[0].metadata["relevance_score"] == 0.95
    assert compressed[1].metadata["relevance_score"] == 0.90
