diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a8cb3296..542b6ca14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support for `BranchTask` in `StructureVisualizer`. - `EvalEngine` for evaluating the performance of an LLM's output against a given input. - `BaseFileLoader.save()` method for saving an Artifact to a destination. +- Support for `GenericMessageContent` in `AnthropicPromptDriver` and `AmazonBedrockPromptDriver`. ### Changed diff --git a/docs/examples/src/talk_to_a_document.py b/docs/examples/src/talk_to_a_document.py new file mode 100644 index 000000000..108055df3 --- /dev/null +++ b/docs/examples/src/talk_to_a_document.py @@ -0,0 +1,42 @@ +import base64 + +import requests +from attrs import define + +from griptape.artifacts import GenericArtifact, TextArtifact +from griptape.drivers import AnthropicPromptDriver +from griptape.structures import Agent + + +@define +class DocumentArtifact(GenericArtifact): + """Artifact for storing a document. + + Subclassing `GenericArtifact` to avoid printing out `self.value` as `GenericArtifact` does. + """ + + def to_text(self) -> str: + return f"Document: {self.value['source']['media_type']}" + + +doc_bytes = requests.get("https://arxiv.org/pdf/1706.03762.pdf").content + + +agent = Agent( + prompt_driver=AnthropicPromptDriver(model="claude-3-5-sonnet-20240620", max_attempts=0), + input=[ + DocumentArtifact( + { + "type": "document", + "source": { + "type": "base64", + "media_type": "application/pdf", + "data": base64.b64encode(doc_bytes).decode("utf-8"), + }, + } + ), + TextArtifact("{{ args[0] }}"), + ], +) + +agent.run("What is the title and who are the authors of this paper?") diff --git a/docs/examples/talk-to-a-document.md b/docs/examples/talk-to-a-document.md new file mode 100644 index 000000000..8f4eb070e --- /dev/null +++ b/docs/examples/talk-to-a-document.md @@ -0,0 +1,22 @@ +Some LLM providers such as [Anthropic](https://docs.anthropic.com/en/api/messages#body-messages-content) and [Amazon Bedrock](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_DocumentBlock.html) offer the ability to pass documents directly to the LLM. + +In this example, we create a custom `DocumentArtifact` to pass a PDF document to the Agent. The Agent then uses the document to answer questions about the paper. + +```python +--8<-- "docs/examples/src/talk_to_a_document.py" +``` + +``` +[12/23/24 09:37:47] INFO PromptTask cc77e4c193c84a5986a4e02e56614d6b + Input: Document: application/pdf + + What is the title and who are the authors of this paper? +[12/23/24 09:37:57] INFO PromptTask cc77e4c193c84a5986a4e02e56614d6b + Output: The title of this paper is "Attention Is All You Need" and the authors are: + + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Ɓukasz Kaiser, and Illia + Polosukhin. + + The paper is from Google Brain, Google Research, and the University of Toronto. It introduces the Transformer model + architecture for sequence transduction tasks like machine translation. +``` diff --git a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py index b108180d2..54278c895 100644 --- a/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py +++ b/griptape/drivers/prompt/amazon_bedrock_prompt_driver.py @@ -188,7 +188,7 @@ def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict: }, } else: - raise ValueError(f"Unsupported content type: {type(content)}") + return content.artifact.value def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict: if isinstance(artifact, ImageArtifact): diff --git a/griptape/drivers/prompt/anthropic_prompt_driver.py b/griptape/drivers/prompt/anthropic_prompt_driver.py index 3341006a1..060b8151d 100644 --- a/griptape/drivers/prompt/anthropic_prompt_driver.py +++ b/griptape/drivers/prompt/anthropic_prompt_driver.py @@ -192,7 +192,7 @@ def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict: "is_error": isinstance(artifact, ErrorArtifact), } else: - raise ValueError(f"Unsupported prompt content type: {type(content)}") + return content.artifact.value def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict: if isinstance(artifact, ImageArtifact): diff --git a/mkdocs.yml b/mkdocs.yml index d48b3622c..8f510064b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -176,6 +176,7 @@ nav: - Talk to a PDF: "examples/talk-to-a-pdf.md" - Talk to a Video: "examples/talk-to-a-video.md" - Talk to an Image: "examples/talk-to-an-image.md" + - Talk to a Document: "examples/talk-to-a-document.md" - Multi Agent Workflows: "examples/multi-agent-workflow.md" - Shared Memory Between Agents: "examples/multiple-agent-shared-memory.md" - Chat Sessions with Amazon DynamoDB: "examples/amazon-dynamodb-sessions.md" diff --git a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py index ada192aae..939b86c5e 100644 --- a/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_amazon_bedrock_prompt_driver.py @@ -1,6 +1,6 @@ import pytest -from griptape.artifacts import ActionArtifact, ErrorArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, ErrorArtifact, GenericArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import AmazonBedrockPromptDriver from tests.mocks.mock_tool.tool import MockTool @@ -299,6 +299,7 @@ def prompt_stack(self, request): ] ) ) + prompt_stack.add_user_message(GenericArtifact("video-file")) return prompt_stack @@ -354,6 +355,7 @@ def messages(self): ], "role": "user", }, + {"content": ["video-file"], "role": "user"}, ] @pytest.mark.parametrize("use_native_tools", [True, False]) diff --git a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py index 2b84b5a17..b611b5e1c 100644 --- a/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_anthropic_prompt_driver.py @@ -2,7 +2,7 @@ import pytest -from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.artifacts import ActionArtifact, GenericArtifact, ImageArtifact, ListArtifact, TextArtifact from griptape.artifacts.error_artifact import ErrorArtifact from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction from griptape.drivers import AnthropicPromptDriver @@ -270,6 +270,7 @@ def prompt_stack(self, request): ] ) ) + prompt_stack.add_user_message(GenericArtifact({"type": "document"})) return prompt_stack @@ -337,6 +338,12 @@ def messages(self): ], "role": "user", }, + { + "content": [ + {"type": "document"}, + ], + "role": "user", + }, ] def test_init(self):