Skip to content

Commit

Permalink
add bedrock knowledgebase as target
Browse files Browse the repository at this point in the history
  • Loading branch information
Sharon Li committed Apr 18, 2024
1 parent 60c43f8 commit 6d03b0d
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 0 deletions.
30 changes: 30 additions & 0 deletions docs/targets/aws/bedrock_knowledgebases.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Knowledge bases for Amazon Bedrock

Knowledge bases for Amazon Bedrock provides you the capability of amassing data sources into a repository of information. With knowledge bases, you can easily build an application that takes advantage of retrieval augmented generation (RAG), a technique in which the retrieval of information from data sources augments the generation of model responses. For more information, visit the AWS documentation [here](https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base.html).

## Prerequisites

The principal must have the following permissions:

- [RetrieveAndGenerate](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_RetrieveAndGenerate.html)

## Configurations

```yaml
target:
type: bedrock-knowledgebase
model_id:
knowledge_base_id:
```
`model_id` *(string)*

The unique identifier of the foundation model used to generate a response.

---

`knowledge_base_id` *(string)*

The unique identifier of the knowledge base that is queried and the foundation model used for generation.

---
1 change: 1 addition & 0 deletions docs/targets/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ The maximum number of retry attempts. The default is `10`.
- [Agents for Amazon Bedrock](./aws/bedrock_agents.md)
- [Amazon Q for Business](./aws/q_business.md)
- [Amazon SageMaker endpoints](./aws/sagemaker_endpoints.md)
- [Knowledge bases for Amazon Bedrock](./aws/bedrock_knowledgebases.md)

---

Expand Down
2 changes: 2 additions & 0 deletions src/agenteval/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from agenteval.targets import BaseTarget
from agenteval.targets.aws import (
BedrockAgentTarget,
BedrockKnowledgebaseTarget,
QBusinessTarget,
SageMakerEndpointTarget,
)
Expand Down Expand Up @@ -47,6 +48,7 @@
"bedrock-agent": BedrockAgentTarget,
"q-business": QBusinessTarget,
"sagemaker-endpoint": SageMakerEndpointTarget,
"bedrock-knowledgebase": BedrockKnowledgebaseTarget,
}

sys.path.append(".")
Expand Down
2 changes: 2 additions & 0 deletions src/agenteval/targets/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .aws_target import AWSTarget
from .bedrock_agent_target import BedrockAgentTarget
from .bedrock_knowledgebase_target import BedrockKnowledgebaseTarget
from .q_business_target import QBusinessTarget
from .sagemaker_endpoint_target import SageMakerEndpointTarget

Expand All @@ -8,4 +9,5 @@
"BedrockAgentTarget",
"QBusinessTarget",
"SageMakerEndpointTarget",
"BedrockKnowledgebaseTarget",
]
38 changes: 38 additions & 0 deletions src/agenteval/targets/aws/bedrock_knowledgebase_target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from agenteval.targets import TargetResponse
from agenteval.targets.aws import AWSTarget

_SERVICE_NAME = "bedrock-agent-runtime"


class BedrockKnowledgebaseTarget(AWSTarget):
def __init__(self, knowledge_base_id: str, model_id: str, **kwargs):
super().__init__(boto3_service_name=_SERVICE_NAME, **kwargs)
aws_region = self.boto3_client.meta.region_name
self._knowledge_base_id = knowledge_base_id
self._model_arn = f"arn:aws:bedrock:{aws_region}::foundation-model/{model_id}"
self._session_id: str = None

def invoke(self, prompt: str) -> TargetResponse:
args = {
"input": {
"text": prompt,
},
"retrieveAndGenerateConfiguration": {
"type": "KNOWLEDGE_BASE",
"knowledgeBaseConfiguration": {
"knowledgeBaseId": self._knowledge_base_id,
"modelArn": self._model_arn,
},
},
}
if self._session_id:
args["sessionId"] = self._session_id

response = self.boto3_client.retrieve_and_generate(**args)
generated_text = response["output"]["text"]
citations = response["citations"]
self._session_id = response["sessionId"]

return TargetResponse(
response=generated_text, data={"bedrock_knowledgebase_citations": citations}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from src.agenteval.targets.aws import bedrock_knowledgebase_target
from src.agenteval.utils import aws


@pytest.fixture
def bedrock_knowledgebase_fixture(mocker):
mock_session = mocker.patch.object(aws.boto3, "Session")
mocker.patch.object(mock_session.return_value, "client")

fixture = bedrock_knowledgebase_target.BedrockKnowledgebaseTarget(
model_id="bedrock-model-id",
knowledge_base_id="bedrock-knowledge-base-id",
aws_profile="test-profile",
aws_region="us-west-2",
)

return fixture


class TestBedrockKnowledgebaseTarget:

_GENERATED_TEXT = "generated text"
_SESSION_ID = "session-id"

def test_invoke(self, mocker, bedrock_knowledgebase_fixture):
mock_knowlegebase_retrieve_and_generate = mocker.patch.object(
bedrock_knowledgebase_fixture.boto3_client, "retrieve_and_generate"
)

mock_knowlegebase_retrieve_and_generate.return_value = {
"citations": [
{
"generatedResponsePart": {
"textResponsePart": {
"span": {"end": 10, "start": 0},
"text": "generated text from citation 1",
}
},
"retrievedReferences": [
{
"content": {"text": "referenced text"},
"location": {"s3Location": {"uri": "s3://"}, "type": "s3"},
"metadata": {"string": None},
}
],
},
{
"generatedResponsePart": {
"textResponsePart": {
"span": {"end": 20, "start": 10},
"text": "generated text from citation 2",
}
},
"retrievedReferences": [
{
"content": {"text": "referenced text"},
"location": {"s3Location": {"uri": "s3://"}, "type": "s3"},
"metadata": {"string": None},
}
],
},
],
"output": {"text": self._GENERATED_TEXT},
"sessionId": self._SESSION_ID,
}

response = bedrock_knowledgebase_fixture.invoke("test prompt")

assert response.response == self._GENERATED_TEXT
assert bedrock_knowledgebase_fixture._session_id == self._SESSION_ID
assert len(response.data.get("bedrock_knowledgebase_citations")) == 2

0 comments on commit 6d03b0d

Please sign in to comment.