diff --git a/docs/targets/aws/bedrock_knowledgebases.md b/docs/targets/aws/bedrock_knowledgebases.md new file mode 100644 index 0000000..e192a04 --- /dev/null +++ b/docs/targets/aws/bedrock_knowledgebases.md @@ -0,0 +1,30 @@ +# Knowledge bases for Amazon Bedrock + +Knowledge bases for Amazon Bedrock provides you the capability of amassing data sources into a repository of information. With knowledge bases, you can easily build an application that takes advantage of retrieval augmented generation (RAG), a technique in which the retrieval of information from data sources augments the generation of model responses. For more information, visit the AWS documentation [here](https://docs.aws.amazon.com/bedrock/latest/userguide/knowledge-base.html). + +## Prerequisites + +The principal must have the following permissions: + +- [RetrieveAndGenerate](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_RetrieveAndGenerate.html) + +## Configurations + +```yaml +target: + type: bedrock-knowledgebase + model_id: + knowledge_base_id: +``` + +`model_id` *(string)* + +The unique identifier of the foundation model used to generate a response. + +--- + +`knowledge_base_id` *(string)* + +The unique identifier of the knowledge base that is queried and the foundation model used for generation. + +--- \ No newline at end of file diff --git a/docs/targets/index.md b/docs/targets/index.md index a627564..a0decf4 100644 --- a/docs/targets/index.md +++ b/docs/targets/index.md @@ -45,6 +45,7 @@ The maximum number of retry attempts. The default is `10`. - [Agents for Amazon Bedrock](./aws/bedrock_agents.md) - [Amazon Q for Business](./aws/q_business.md) - [Amazon SageMaker endpoints](./aws/sagemaker_endpoints.md) +- [Knowledge bases for Amazon Bedrock](./aws/bedrock_knowledgebases.md) --- diff --git a/src/agenteval/plan.py b/src/agenteval/plan.py index a809188..20980d3 100644 --- a/src/agenteval/plan.py +++ b/src/agenteval/plan.py @@ -17,6 +17,7 @@ from agenteval.targets import BaseTarget from agenteval.targets.aws import ( BedrockAgentTarget, + BedrockKnowledgebaseTarget, QBusinessTarget, SageMakerEndpointTarget, ) @@ -47,6 +48,7 @@ "bedrock-agent": BedrockAgentTarget, "q-business": QBusinessTarget, "sagemaker-endpoint": SageMakerEndpointTarget, + "bedrock-knowledgebase": BedrockKnowledgebaseTarget, } sys.path.append(".") diff --git a/src/agenteval/targets/aws/__init__.py b/src/agenteval/targets/aws/__init__.py index 58d95a0..88a9fe7 100644 --- a/src/agenteval/targets/aws/__init__.py +++ b/src/agenteval/targets/aws/__init__.py @@ -1,5 +1,6 @@ from .aws_target import AWSTarget from .bedrock_agent_target import BedrockAgentTarget +from .bedrock_knowledgebase_target import BedrockKnowledgebaseTarget from .q_business_target import QBusinessTarget from .sagemaker_endpoint_target import SageMakerEndpointTarget @@ -8,4 +9,5 @@ "BedrockAgentTarget", "QBusinessTarget", "SageMakerEndpointTarget", + "BedrockKnowledgebaseTarget", ] diff --git a/src/agenteval/targets/aws/bedrock_knowledgebase_target.py b/src/agenteval/targets/aws/bedrock_knowledgebase_target.py new file mode 100644 index 0000000..7b96369 --- /dev/null +++ b/src/agenteval/targets/aws/bedrock_knowledgebase_target.py @@ -0,0 +1,38 @@ +from agenteval.targets import TargetResponse +from agenteval.targets.aws import AWSTarget + +_SERVICE_NAME = "bedrock-agent-runtime" + + +class BedrockKnowledgebaseTarget(AWSTarget): + def __init__(self, knowledge_base_id: str, model_id: str, **kwargs): + super().__init__(boto3_service_name=_SERVICE_NAME, **kwargs) + aws_region = self.boto3_client.meta.region_name + self._knowledge_base_id = knowledge_base_id + self._model_arn = f"arn:aws:bedrock:{aws_region}::foundation-model/{model_id}" + self._session_id: str = None + + def invoke(self, prompt: str) -> TargetResponse: + args = { + "input": { + "text": prompt, + }, + "retrieveAndGenerateConfiguration": { + "type": "KNOWLEDGE_BASE", + "knowledgeBaseConfiguration": { + "knowledgeBaseId": self._knowledge_base_id, + "modelArn": self._model_arn, + }, + }, + } + if self._session_id: + args["sessionId"] = self._session_id + + response = self.boto3_client.retrieve_and_generate(**args) + generated_text = response["output"]["text"] + citations = response["citations"] + self._session_id = response["sessionId"] + + return TargetResponse( + response=generated_text, data={"bedrock_knowledgebase_citations": citations} + ) diff --git a/tests/src/agenteval/targets/aws/test_bedrock_knowlegebase_target.py b/tests/src/agenteval/targets/aws/test_bedrock_knowlegebase_target.py new file mode 100644 index 0000000..b6a520f --- /dev/null +++ b/tests/src/agenteval/targets/aws/test_bedrock_knowlegebase_target.py @@ -0,0 +1,73 @@ +import pytest + +from src.agenteval.targets.aws import bedrock_knowledgebase_target +from src.agenteval.utils import aws + + +@pytest.fixture +def bedrock_knowledgebase_fixture(mocker): + mock_session = mocker.patch.object(aws.boto3, "Session") + mocker.patch.object(mock_session.return_value, "client") + + fixture = bedrock_knowledgebase_target.BedrockKnowledgebaseTarget( + model_id="bedrock-model-id", + knowledge_base_id="bedrock-knowledge-base-id", + aws_profile="test-profile", + aws_region="us-west-2", + ) + + return fixture + + +class TestBedrockKnowledgebaseTarget: + + _GENERATED_TEXT = "generated text" + _SESSION_ID = "session-id" + + def test_invoke(self, mocker, bedrock_knowledgebase_fixture): + mock_knowlegebase_retrieve_and_generate = mocker.patch.object( + bedrock_knowledgebase_fixture.boto3_client, "retrieve_and_generate" + ) + + mock_knowlegebase_retrieve_and_generate.return_value = { + "citations": [ + { + "generatedResponsePart": { + "textResponsePart": { + "span": {"end": 10, "start": 0}, + "text": "generated text from citation 1", + } + }, + "retrievedReferences": [ + { + "content": {"text": "referenced text"}, + "location": {"s3Location": {"uri": "s3://"}, "type": "s3"}, + "metadata": {"string": None}, + } + ], + }, + { + "generatedResponsePart": { + "textResponsePart": { + "span": {"end": 20, "start": 10}, + "text": "generated text from citation 2", + } + }, + "retrievedReferences": [ + { + "content": {"text": "referenced text"}, + "location": {"s3Location": {"uri": "s3://"}, "type": "s3"}, + "metadata": {"string": None}, + } + ], + }, + ], + "output": {"text": self._GENERATED_TEXT}, + "sessionId": self._SESSION_ID, + } + + response = bedrock_knowledgebase_fixture.invoke("test prompt") + + assert response.response == self._GENERATED_TEXT + assert bedrock_knowledgebase_fixture._session_id == self._SESSION_ID + assert len(response.data.get("bedrock_knowledgebase_citations")) == 2