Skip to content

Commit

Permalink
Merge branch 'dev' into feature/exa_web_search_driver
Browse files Browse the repository at this point in the history
  • Loading branch information
william-price01 authored Oct 2, 2024
2 parents 904addb + 65bed0e commit 7e97136
Show file tree
Hide file tree
Showing 12 changed files with 309 additions and 268 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: Renamed parameter `pipe` on `HuggingFacePipelinePromptDriver` to `pipeline`.
- Several places where API clients are initialized are now lazy loaded.
- `Structure.output`'s type is now `BaseArtifact` and raises an exception if the output is `None`.
- **BREAKING**: Update `pypdf` dependency to `^5.0.1`.
- **BREAKING**: Update `redis` dependency to `^5.1.0`.
- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags

### Fixed
- Anthropic native Tool calling

## [0.32.0] - 2024-09-17

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,5 @@ def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
def load(self) -> tuple[list[Run], dict[str, Any]]:
memory_json = self.client.hget(self.index, self.conversation_id)
if memory_json is not None:
return self._from_params_dict(json.loads(memory_json))
return self._from_params_dict(json.loads(memory_json)) # pyright: ignore[reportArgumentType] https://github.com/redis/redis-py/issues/2399
return [], {}
9 changes: 8 additions & 1 deletion griptape/drivers/prompt/anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __to_anthropic_role(self, message: Message) -> str:
return "user"

def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
return [
tool_schemas = [
{
"name": tool.to_native_tool_name(activity),
"description": tool.activity_description(activity),
Expand All @@ -137,6 +137,13 @@ def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
for activity in tool.activities()
]

# Anthropic doesn't support $schema and $id
for tool_schema in tool_schemas:
del tool_schema["input_schema"]["$schema"]
del tool_schema["input_schema"]["$id"]

return tool_schemas

def __to_anthropic_content(self, message: Message) -> str | list[dict]:
if message.has_all_content_type(TextMessageContent):
return message.to_text()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class MarkdownifyWebScraperDriver(BaseWebScraperDriver):
the browser has emitted the "load" event.
"""

DEFAULT_EXCLUDE_TAGS = ["script", "style", "head"]
DEFAULT_EXCLUDE_TAGS = ["script", "style", "head", "audio", "img", "picture", "source", "video"]

include_links: bool = field(default=True, kw_only=True)
exclude_tags: list[str] = field(
Expand Down
479 changes: 250 additions & 229 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ snowflake-sqlalchemy = { version = "^1.6.1", optional = true }
pinecone-client = { version = "^3", optional = true }
pymongo = { version = "^4.8.0", optional = true }
marqo = { version = "^3.7.0", optional = true }
redis = { version = "^4.6.0", optional = true }
redis = { version = "^5.1.0", optional = true }
opensearch-py = { version = "^2.3.1", optional = true }
pgvector = { version = ">=0.2.3,<0.4.0", optional = true }
psycopg2-binary = { version = "^2.9.9", optional = true }
Expand Down Expand Up @@ -69,7 +69,7 @@ exa-py = {version = "^1.1.4", optional = true}

# loaders
pandas = {version = "^1.3", optional = true}
pypdf = {version = "^3.9", optional = true}
pypdf = {version = "^5.0.1", optional = true}
pillow = {version = "^10.2.0", optional = true}
mail-parser = {version = "^3.15.0", optional = true}
filetype = {version = "^1.2", optional = true}
Expand Down Expand Up @@ -218,7 +218,7 @@ pytest-mock = "^3.1.4"
mongomock = "^4.1.2"

twine = "^5.1.1"
moto = {extras = ["dynamodb", "iotdata", "sqs"], version = "^4.2.13"}
moto = {extras = ["dynamodb", "iotdata", "sqs"], version = "^5.0.16"}
pytest-xdist = "^3.3.1"
pytest-cov = "^5.0.0"
pytest-env = "^1.1.1"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import boto3
import pytest
from moto import mock_sqs
from moto import mock_aws

from griptape.drivers.event_listener.amazon_sqs_event_listener_driver import AmazonSqsEventListenerDriver
from tests.mocks.mock_event import MockEvent
Expand All @@ -14,7 +14,7 @@ def _run_before_and_after_tests(self):

@pytest.fixture()
def driver(self):
mock = mock_sqs()
mock = mock_aws()
mock.start()

session = boto3.Session(region_name="us-east-1")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import boto3
import pytest
from moto import mock_iotdata
from moto import mock_aws

from griptape.drivers.event_listener.aws_iot_core_event_listener_driver import AwsIotCoreEventListenerDriver
from tests.mocks.mock_event import MockEvent
from tests.utils.aws import mock_aws_credentials


@mock_iotdata
@mock_aws
class TestAwsIotCoreEventListenerDriver:
@pytest.fixture(autouse=True)
def _run_before_and_after_tests(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import boto3
import pytest
from moto import mock_s3
from moto import mock_aws

from griptape.artifacts import InfoArtifact, ListArtifact, TextArtifact
from griptape.drivers import AmazonS3FileManagerDriver
Expand All @@ -18,7 +18,7 @@ def _set_aws_credentials(self):

@pytest.fixture()
def session(self):
mock = mock_s3()
mock = mock_aws()
mock.start()
yield boto3.Session(region_name="us-east-1")
mock.stop()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import boto3
import pytest
from moto import mock_dynamodb
from moto import mock_aws

from griptape.drivers import AmazonDynamoDbConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
Expand All @@ -11,15 +11,18 @@

class TestDynamoDbConversationMemoryDriver:
DYNAMODB_TABLE_NAME = "griptape"
DYNAMODB_COMPOSITE_TABLE_NAME = "griptape_composite"
DYNAMODB_PARTITION_KEY = "entryId"
DYNAMODB_SORT_KEY = "sortKey"
AWS_REGION = "us-west-2"
VALUE_ATTRIBUTE_KEY = "foo"
PARTITION_KEY_VALUE = "bar"
SORT_KEY_VALUE = "baz"

@pytest.fixture(autouse=True)
def _run_before_and_after_tests(self):
mock_aws_credentials()
self.mock_dynamodb = mock_dynamodb()
self.mock_dynamodb = mock_aws()
self.mock_dynamodb.start()

dynamodb = boto3.Session(region_name=self.AWS_REGION).client("dynamodb")
Expand All @@ -30,9 +33,23 @@ def _run_before_and_after_tests(self):
BillingMode="PAY_PER_REQUEST",
)

dynamodb.create_table(
TableName=self.DYNAMODB_COMPOSITE_TABLE_NAME,
KeySchema=[
{"AttributeName": self.DYNAMODB_PARTITION_KEY, "KeyType": "HASH"},
{"AttributeName": self.DYNAMODB_SORT_KEY, "KeyType": "RANGE"},
],
AttributeDefinitions=[
{"AttributeName": self.DYNAMODB_PARTITION_KEY, "AttributeType": "S"},
{"AttributeName": self.DYNAMODB_SORT_KEY, "AttributeType": "S"},
],
BillingMode="PAY_PER_REQUEST",
)

yield

dynamodb.delete_table(TableName=self.DYNAMODB_TABLE_NAME)
dynamodb.delete_table(TableName=self.DYNAMODB_COMPOSITE_TABLE_NAME)
self.mock_dynamodb.stop()

def test_store(self):
Expand Down Expand Up @@ -62,27 +79,31 @@ def test_store(self):
def test_store_with_sort_key(self):
session = boto3.Session(region_name=self.AWS_REGION)
dynamodb = session.resource("dynamodb")
table = dynamodb.Table(self.DYNAMODB_TABLE_NAME)
table = dynamodb.Table(self.DYNAMODB_COMPOSITE_TABLE_NAME)
memory_driver = AmazonDynamoDbConversationMemoryDriver(
session=session,
table_name=self.DYNAMODB_TABLE_NAME,
table_name=self.DYNAMODB_COMPOSITE_TABLE_NAME,
partition_key=self.DYNAMODB_PARTITION_KEY,
value_attribute_key=self.VALUE_ATTRIBUTE_KEY,
partition_key_value=self.PARTITION_KEY_VALUE,
sort_key="sortKey",
sort_key_value="foo",
sort_key=self.DYNAMODB_SORT_KEY,
sort_key_value=self.SORT_KEY_VALUE,
)
memory = ConversationMemory(conversation_memory_driver=memory_driver)
pipeline = Pipeline(conversation_memory=memory)

pipeline.add_task(PromptTask("test"))

response = table.get_item(TableName=self.DYNAMODB_TABLE_NAME, Key={"entryId": "bar", "sortKey": "foo"})
response = table.get_item(
TableName=self.DYNAMODB_COMPOSITE_TABLE_NAME, Key={"entryId": "bar", "sortKey": "baz"}
)
assert "Item" not in response

pipeline.run()

response = table.get_item(TableName=self.DYNAMODB_TABLE_NAME, Key={"entryId": "bar", "sortKey": "foo"})
response = table.get_item(
TableName=self.DYNAMODB_COMPOSITE_TABLE_NAME, Key={"entryId": "bar", "sortKey": "baz"}
)
assert "Item" in response

def test_load(self):
Expand All @@ -109,12 +130,12 @@ def test_load(self):
def test_load_with_sort_key(self):
memory_driver = AmazonDynamoDbConversationMemoryDriver(
session=boto3.Session(region_name=self.AWS_REGION),
table_name=self.DYNAMODB_TABLE_NAME,
table_name=self.DYNAMODB_COMPOSITE_TABLE_NAME,
partition_key=self.DYNAMODB_PARTITION_KEY,
value_attribute_key=self.VALUE_ATTRIBUTE_KEY,
partition_key_value=self.PARTITION_KEY_VALUE,
sort_key="sortKey",
sort_key_value="foo",
sort_key=self.DYNAMODB_SORT_KEY,
sort_key_value=self.SORT_KEY_VALUE,
)
memory = ConversationMemory(conversation_memory_driver=memory_driver, meta={"foo": "bar"})
pipeline = Pipeline(conversation_memory=memory)
Expand Down
14 changes: 0 additions & 14 deletions tests/unit/drivers/prompt/test_anthropic_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description: foo",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {
"values": {
Expand All @@ -34,8 +32,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description: foo",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {
"values": {
Expand All @@ -54,8 +50,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description: foo",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {
"values": {
Expand All @@ -74,8 +68,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {},
"required": [],
Expand All @@ -86,8 +78,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {},
"required": [],
Expand All @@ -98,8 +88,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description: foo",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {
"values": {
Expand All @@ -118,8 +106,6 @@ class TestAnthropicPromptDriver:
{
"description": "test description",
"input_schema": {
"$id": "Input Schema",
"$schema": "http://json-schema.org/draft-07/schema#",
"additionalProperties": False,
"properties": {
"values": {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/loaders/test_pdf_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_load(self, loader, create_source):

artifacts = loader.load(source)

assert len(artifacts) == 151
assert len(artifacts) == 156
assert artifacts[0].value.startswith("Bitcoin: A Peer-to-Peer")
assert artifacts[-1].value.endswith('its applications," 1957.\n9')
assert artifacts[0].embedding == [0, 1]
Expand All @@ -37,7 +37,7 @@ def test_load_collection(self, loader, create_source):

for key in keys:
artifact = collection[key]
assert len(artifact) == 151
assert len(artifact) == 156
assert artifact[0].value.startswith("Bitcoin: A Peer-to-Peer")
assert artifact[-1].value.endswith('its applications," 1957.\n9')
assert artifact[0].embedding == [0, 1]

0 comments on commit 7e97136

Please sign in to comment.