Skip to content

Commit

Permalink
update bedrock tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tcapelle committed Dec 20, 2024
1 parent 9cd9b61 commit 6f205bc
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 211 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ openai = ["openai>=1.0.0"]
pandas-test = ["pandas>=2.2.3"]
modal = ["modal", "python-dotenv"]
vertexai = ["vertexai>=1.70.0"]
bedrock = ["boto3"]
bedrock = ["boto3", "moto[bedrock]>=5.0.24"]
test = [
"nox",
"pytest>=8.2.0",
Expand Down
262 changes: 157 additions & 105 deletions tests/integrations/bedrock/bedrock_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import json

import boto3
import botocore
import pytest
from unittest.mock import patch

import weave
from weave.integrations.bedrock import patch_client

model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0"


system_message = "You are an expert software engineer that knows a lot of programming. You prefer short answers."
messages = [
{
Expand All @@ -24,124 +23,177 @@
}
]

# Mock responses
# Canned payload returned in place of a real Bedrock ``Converse`` call.
# The request id appears both in the metadata and in the echoed HTTP headers,
# so it is named once and reused.
_MOCK_CONVERSE_REQUEST_ID = "917ceb8d-3a0a-4649-b3bb-527494c17a69"

# Assistant reply text; it stops mid-sentence because the response reports
# ``stopReason == "max_tokens"``.
_MOCK_CONVERSE_TEXT = (
    "To list all text files in the current directory (excluding subdirectories) "
    "that have been modified in the last month using Bash, you can use"
)

MOCK_CONVERSE_RESPONSE = {
    "ResponseMetadata": {
        "RequestId": _MOCK_CONVERSE_REQUEST_ID,
        "HTTPStatusCode": 200,
        "HTTPHeaders": {
            "date": "Fri, 20 Dec 2024 16:44:08 GMT",
            "content-type": "application/json",
            "content-length": "323",
            "connection": "keep-alive",
            "x-amzn-requestid": _MOCK_CONVERSE_REQUEST_ID,
        },
        "RetryAttempts": 0,
    },
    "output": {
        "message": {
            "role": "assistant",
            "content": [{"text": _MOCK_CONVERSE_TEXT}],
        },
    },
    "stopReason": "max_tokens",
    # Token counts asserted against the traced usage summary in the tests.
    "usage": {"inputTokens": 40, "outputTokens": 30, "totalTokens": 70},
    "metrics": {"latencyMs": 838},
}

def _text_delta_event(text):
    """Build one ``contentBlockDelta`` stream event carrying *text* for block 0."""
    return {"contentBlockDelta": {"delta": {"text": text}, "contentBlockIndex": 0}}


# Text fragments in the order they are streamed; concatenated they form the
# (truncated) assistant reply.
_MOCK_STREAM_TEXT_CHUNKS = (
    "To",
    " list all text files",
    " in the current directory",
    " modifie",
    "d in the last month",
    ", use",
    ":",
    "\n\n```bash",
    "\nfind . -max",
    "depth ",
    "1",
)

# Event sequence yielded by the fake ``ConverseStream`` response: a message
# start, one delta per text fragment, the block/message stop markers, and a
# trailing metadata event with usage and latency figures.
MOCK_STREAM_EVENTS = [
    {"messageStart": {"role": "assistant"}},
    *(_text_delta_event(chunk) for chunk in _MOCK_STREAM_TEXT_CHUNKS),
    {"contentBlockStop": {"contentBlockIndex": 0}},
    {"messageStop": {"stopReason": "max_tokens"}},
    {
        "metadata": {
            "usage": {"inputTokens": 55, "outputTokens": 30, "totalTokens": 85},
            "metrics": {"latencyMs": 926},
        },
    },
]

def _remove_body_from_response(response):
if "body" in response:
response["body"] = None # Remove the body content
return response
# Original botocore _make_api_call function.
# Captured at import time, before any test patches the method, so that
# mock_make_api_call can fall through to the real implementation for every
# operation it does not fake.
orig = botocore.client.BaseClient._make_api_call

def mock_make_api_call(self, operation_name, kwarg):
if operation_name == 'Converse':
return MOCK_CONVERSE_RESPONSE
elif operation_name == 'ConverseStream':
class MockStream:
def __iter__(self):
for event in MOCK_STREAM_EVENTS:
yield event

@pytest.mark.skip_clickhouse_client
@pytest.mark.vcr(
filter_headers=[
"authorization",
"content-type",
"user-agent",
"x-amz-date",
"x-amz-security-token",
"x-amz-sso_bearer_token",
"amz-sdk-invocation-id",
],
before_record_response=_remove_body_from_response,
allowed_hosts=[
"api.wandb.ai",
"localhost",
],
)
def test_bedrock_converse(client: weave.trace.weave_client.WeaveClient) -> None:
bedrock_client = boto3.client("bedrock-runtime")
patch_client(bedrock_client)

response = bedrock_client.converse(
modelId=model_id,
system=[{"text": system_message}], # it needs a list for some reason
messages=messages,
inferenceConfig={"maxTokens": 30},
)

# Verify the response structure
assert response is not None
assert "output" in response
assert "message" in response["output"]
assert "content" in response["output"]["message"]
return {'stream': MockStream()}
return orig(self, operation_name, kwarg)


@pytest.mark.skip_clickhouse_client
@pytest.mark.vcr(
filter_headers=[
"authorization",
"content-type",
"user-agent",
"x-amz-date",
"x-amz-security-token",
"x-amz-sso_bearer_token",
"amz-sdk-invocation-id",
],
allowed_hosts=["api.wandb.ai", "localhost"],
)
def test_bedrock_invoke(client: weave.trace.weave_client.WeaveClient) -> None:
def test_bedrock_converse(client: weave.trace.weave_client.WeaveClient) -> None:
bedrock_client = boto3.client("bedrock-runtime")
patch_client(bedrock_client)

body = json.dumps(
{
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 30,
"temperature": 0.7,
"messages": [{"role": "user", "content": "What is the capital of France?"}],
}
)

response = bedrock_client.invoke_model(
modelId=model_id,
body=body,
contentType="application/json",
accept="application/json",
with patch('botocore.client.BaseClient._make_api_call', new=mock_make_api_call):
response = bedrock_client.converse(
modelId=model_id,
system=[{"text": system_message}],
messages=messages,
inferenceConfig={"maxTokens": 30},
)

# Existing checks
assert response is not None
assert "output" in response
assert "message" in response["output"]
assert "content" in response["output"]["message"]

# Now verify that a trace was captured.
calls = list(client.calls())
assert len(calls) == 1, "Expected exactly one trace call"
call = calls[0]

assert call.exception is None
assert call.ended_at is not None

# Inspect the captured output if desired
output = call.output

# Confirm we see the same text as in the mock response
assert (
output["output"]["message"]["content"][0]["text"]
== "To list all text files in the current directory (excluding subdirectories) that have been modified in the last month using Bash, you can use"
)

invoke_output = json.loads(response.get("body").read())

# Verify the response structure
assert invoke_output is not None
assert "content" in invoke_output
# Check usage in a style similar to mistral tests
summary = call.summary
assert summary is not None, "Summary should not be None"
# We'll reference usage by the model_id
model_usage = summary["usage"][model_id]
assert model_usage["requests"] == 1, "Expected exactly one request increment"
# Map the tokens to pydantic usage fields
# "inputTokens" -> prompt_tokens, "outputTokens" -> completion_tokens
assert output["usage"]["inputTokens"] == model_usage["prompt_tokens"] == 40
assert output["usage"]["outputTokens"] == model_usage["completion_tokens"] == 30
assert output["usage"]["totalTokens"] == model_usage["total_tokens"] == 70


@pytest.mark.skip_clickhouse_client
@pytest.mark.vcr(
filter_headers=[
"authorization",
"content-type",
"user-agent",
"x-amz-date",
"x-amz-security-token",
"x-amz-sso_bearer_token",
"amz-sdk-invocation-id",
],
allowed_hosts=[
"api.wandb.ai",
"localhost",
],
)
def test_bedrock_converse_stream(client: weave.trace.weave_client.WeaveClient) -> None:
bedrock_client = boto3.client("bedrock-runtime")
patch_client(bedrock_client)

response = bedrock_client.converse_stream(
modelId=model_id,
system=[{"text": system_message}],
messages=messages,
inferenceConfig={"maxTokens": 30},
with patch('botocore.client.BaseClient._make_api_call', new=mock_make_api_call):
response = bedrock_client.converse_stream(
modelId=model_id,
system=[{"text": system_message}],
messages=messages,
inferenceConfig={"maxTokens": 30},
)

# Existing checks
stream = response.get('stream')
assert stream is not None, "Stream not found in response"

# Accumulate the streamed response
final_response = ""
for event in stream:
if 'contentBlockDelta' in event:
final_response += event['contentBlockDelta']['delta']['text']

assert final_response is not None

# Now verify that a trace was captured.
calls = list(client.calls())
assert len(calls) == 1, "Expected exactly one trace call for the stream test"
call = calls[0]

assert call.exception is None
assert call.ended_at is not None

output = call.output
# For a streaming response, you might confirm final partial text is present
# in the final output or usage data is recorded
print(output)

assert (
"To list all text files" in output["content"]
)
# Access the stream from the response
stream = response.get('stream')
assert stream is not None, "Stream not found in response"

# Accumulate the streamed response
final_response = None
for event in stream:
# Process each event (the accumulator handles this internally)
final_response = event # The final event after accumulation

assert final_response is not None
# assert "content" in final_response
# assert len(final_response["content"]) > 0

# Check usage in a style similar to mistral tests
summary = call.summary
assert summary is not None, "Summary should not be None"
model_usage = summary["usage"][model_id]
assert model_usage["requests"] == 1
assert output["usage"]["inputTokens"] == model_usage["prompt_tokens"] == 55
assert output["usage"]["outputTokens"] == model_usage["completion_tokens"] == 30
assert output["usage"]["totalTokens"] == model_usage["total_tokens"] == 85

This file was deleted.

Loading

0 comments on commit 6f205bc

Please sign in to comment.