diff --git a/pyproject.toml b/pyproject.toml index c86380c6ddc..0fef7c1aac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,7 +84,7 @@ openai = ["openai>=1.0.0"] pandas-test = ["pandas>=2.2.3"] modal = ["modal", "python-dotenv"] vertexai = ["vertexai>=1.70.0"] -bedrock = ["boto3"] +bedrock = ["boto3", "moto[bedrock]>=5.0.24"] test = [ "nox", "pytest>=8.2.0", diff --git a/tests/integrations/bedrock/bedrock_test.py b/tests/integrations/bedrock/bedrock_test.py index 4e7328e373a..27965df1ca8 100644 --- a/tests/integrations/bedrock/bedrock_test.py +++ b/tests/integrations/bedrock/bedrock_test.py @@ -1,14 +1,13 @@ import json - import boto3 +import botocore import pytest +from unittest.mock import patch import weave from weave.integrations.bedrock import patch_client model_id = "anthropic.claude-3-5-sonnet-20240620-v1:0" - - system_message = "You are an expert software engineer that knows a lot of programming. You prefer short answers." messages = [ { @@ -24,124 +23,177 @@ } ] +# Mock responses +MOCK_CONVERSE_RESPONSE = { + 'ResponseMetadata': { + 'RequestId': '917ceb8d-3a0a-4649-b3bb-527494c17a69', + 'HTTPStatusCode': 200, + 'HTTPHeaders': { + 'date': 'Fri, 20 Dec 2024 16:44:08 GMT', + 'content-type': 'application/json', + 'content-length': '323', + 'connection': 'keep-alive', + 'x-amzn-requestid': '917ceb8d-3a0a-4649-b3bb-527494c17a69' + }, + 'RetryAttempts': 0 + }, + 'output': { + 'message': { + 'role': 'assistant', + 'content': [ + { + 'text': 'To list all text files in the current directory (excluding subdirectories) ' + 'that have been modified in the last month using Bash, you can use' + } + ] + } + }, + 'stopReason': 'max_tokens', + 'usage': {'inputTokens': 40, 'outputTokens': 30, 'totalTokens': 70}, + 'metrics': {'latencyMs': 838} +} + +MOCK_STREAM_EVENTS = [ + {'messageStart': {'role': 'assistant'}}, + {'contentBlockDelta': {'delta': {'text': 'To'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ' list all text files'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ' in the current directory'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ' modifie'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': 'd in the last month'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ', use'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': ':'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': '\n\n```bash'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': '\nfind . -max'}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': 'depth '}, 'contentBlockIndex': 0}}, + {'contentBlockDelta': {'delta': {'text': '1'}, 'contentBlockIndex': 0}}, + {'contentBlockStop': {'contentBlockIndex': 0}}, + {'messageStop': {'stopReason': 'max_tokens'}}, + { + 'metadata': { + 'usage': { + 'inputTokens': 55, + 'outputTokens': 30, + 'totalTokens': 85 + }, + 'metrics': { + 'latencyMs': 926 + } + } + }, +] -def _remove_body_from_response(response): - if "body" in response: - response["body"] = None # Remove the body content - return response +# Original botocore _make_api_call function +orig = botocore.client.BaseClient._make_api_call +def mock_make_api_call(self, operation_name, kwarg): + if operation_name == 'Converse': + return MOCK_CONVERSE_RESPONSE + elif operation_name == 'ConverseStream': + class MockStream: + def __iter__(self): + for event in MOCK_STREAM_EVENTS: + yield event -@pytest.mark.skip_clickhouse_client -@pytest.mark.vcr( - filter_headers=[ - "authorization", - "content-type", - "user-agent", - "x-amz-date", - "x-amz-security-token", - "x-amz-sso_bearer_token", - "amz-sdk-invocation-id", - ], - before_record_response=_remove_body_from_response, - allowed_hosts=[ - "api.wandb.ai", - "localhost", - ], -) -def test_bedrock_converse(client: weave.trace.weave_client.WeaveClient) -> None: - bedrock_client = boto3.client("bedrock-runtime") - patch_client(bedrock_client) - - response = bedrock_client.converse( - modelId=model_id, - system=[{"text": system_message}], # it needs a list for some reason - messages=messages, - inferenceConfig={"maxTokens": 30}, - ) - - # Verify the response structure - assert response is not None - assert "output" in response - assert "message" in response["output"] - assert "content" in response["output"]["message"] + return {'stream': MockStream()} + return orig(self, operation_name, kwarg) @pytest.mark.skip_clickhouse_client -@pytest.mark.vcr( - filter_headers=[ - "authorization", - "content-type", - "user-agent", - "x-amz-date", - "x-amz-security-token", - "x-amz-sso_bearer_token", - "amz-sdk-invocation-id", - ], - allowed_hosts=["api.wandb.ai", "localhost"], -) -def test_bedrock_invoke(client: weave.trace.weave_client.WeaveClient) -> None: +def test_bedrock_converse(client: weave.trace.weave_client.WeaveClient) -> None: bedrock_client = boto3.client("bedrock-runtime") patch_client(bedrock_client) - body = json.dumps( - { - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": 30, - "temperature": 0.7, - "messages": [{"role": "user", "content": "What is the capital of France?"}], - } - ) - - response = bedrock_client.invoke_model( - modelId=model_id, - body=body, - contentType="application/json", - accept="application/json", + with patch('botocore.client.BaseClient._make_api_call', new=mock_make_api_call): + response = bedrock_client.converse( + modelId=model_id, + system=[{"text": system_message}], + messages=messages, + inferenceConfig={"maxTokens": 30}, + ) + + # Existing checks + assert response is not None + assert "output" in response + assert "message" in response["output"] + assert "content" in response["output"]["message"] + + # Now verify that a trace was captured. + calls = list(client.calls()) + assert len(calls) == 1, "Expected exactly one trace call" + call = calls[0] + + assert call.exception is None + assert call.ended_at is not None + + # Inspect the captured output if desired + output = call.output + + # Confirm we see the same text as in the mock response + assert ( + output["output"]["message"]["content"][0]["text"] + == "To list all text files in the current directory (excluding subdirectories) that have been modified in the last month using Bash, you can use" ) - invoke_output = json.loads(response.get("body").read()) - - # Verify the response structure - assert invoke_output is not None - assert "content" in invoke_output + # Check usage in a style similar to mistral tests + summary = call.summary + assert summary is not None, "Summary should not be None" + # We'll reference usage by the model_id + model_usage = summary["usage"][model_id] + assert model_usage["requests"] == 1, "Expected exactly one request increment" + # Map the tokens to pydantic usage fields + # "inputTokens" -> prompt_tokens, "outputTokens" -> completion_tokens + assert output["usage"]["inputTokens"] == model_usage["prompt_tokens"] == 40 + assert output["usage"]["outputTokens"] == model_usage["completion_tokens"] == 30 + assert output["usage"]["totalTokens"] == model_usage["total_tokens"] == 70 @pytest.mark.skip_clickhouse_client -@pytest.mark.vcr( - filter_headers=[ - "authorization", - "content-type", - "user-agent", - "x-amz-date", - "x-amz-security-token", - "x-amz-sso_bearer_token", - "amz-sdk-invocation-id", - ], - allowed_hosts=[ - "api.wandb.ai", - "localhost", - ], -) def test_bedrock_converse_stream(client: weave.trace.weave_client.WeaveClient) -> None: bedrock_client = boto3.client("bedrock-runtime") patch_client(bedrock_client) - response = bedrock_client.converse_stream( - modelId=model_id, - system=[{"text": system_message}], - messages=messages, - inferenceConfig={"maxTokens": 30}, + with patch('botocore.client.BaseClient._make_api_call', new=mock_make_api_call): + response = bedrock_client.converse_stream( + modelId=model_id, + system=[{"text": system_message}], + messages=messages, + inferenceConfig={"maxTokens": 30}, + ) + + # Existing checks + stream = response.get('stream') + assert stream is not None, "Stream not found in response" + + # Accumulate the streamed response + final_response = "" + for event in stream: + if 'contentBlockDelta' in event: + final_response += event['contentBlockDelta']['delta']['text'] + + assert final_response is not None + + # Now verify that a trace was captured. + calls = list(client.calls()) + assert len(calls) == 1, "Expected exactly one trace call for the stream test" + call = calls[0] + + assert call.exception is None + assert call.ended_at is not None + + output = call.output + # For a streaming response, you might confirm final partial text is present + # in the final output or usage data is recorded + print(output) + + assert ( + "To list all text files" in output["content"] ) - # Access the stream from the response - stream = response.get('stream') - assert stream is not None, "Stream not found in response" - - # Accumulate the streamed response - final_response = None - for event in stream: - # Process each event (the accumulator handles this internally) - final_response = event # The final event after accumulation - - assert final_response is not None - # assert "content" in final_response - # assert len(final_response["content"]) > 0 \ No newline at end of file + + # Check usage in a style similar to mistral tests + summary = call.summary + assert summary is not None, "Summary should not be None" + model_usage = summary["usage"][model_id] + assert model_usage["requests"] == 1 + assert output["usage"]["inputTokens"] == model_usage["prompt_tokens"] == 55 + assert output["usage"]["outputTokens"] == model_usage["completion_tokens"] == 30 + assert output["usage"]["totalTokens"] == model_usage["total_tokens"] == 85 diff --git a/tests/integrations/bedrock/cassettes/bedrock_test/test_bedrock_converse.yaml b/tests/integrations/bedrock/cassettes/bedrock_test/test_bedrock_converse.yaml deleted file mode 100644 index 0d89c6e5be3..00000000000 --- a/tests/integrations/bedrock/cassettes/bedrock_test/test_bedrock_converse.yaml +++ /dev/null @@ -1,65 +0,0 @@ -interactions: -- request: - body: null - headers: - amz-sdk-request: - - !!binary | - YXR0ZW1wdD0x - method: GET - uri: https://portal.sso.us-east-2.amazonaws.com/federation/credentials?role_name=SageMaker&account_id=372108735839 - response: - body: null - headers: - Access-Control-Expose-Headers: - - RequestId - - x-amzn-RequestId - Cache-Control: - - no-cache - Connection: - - keep-alive - Content-Length: - - '1064' - Content-Type: - - application/json - Date: - - Thu, 05 Dec 2024 09:53:43 GMT - RequestId: - - 82774d9d-05c6-4030-bebc-165f1e1b7b73 - Server: - - AWS SSO - x-amzn-RequestId: - - 82774d9d-05c6-4030-bebc-165f1e1b7b73 - status: - code: 200 - message: OK -- request: - body: '{"system": [{"text": "You are an expert software engineer that knows a - lot of programming. You prefer short answers."}], "messages": [{"role": "user", - "content": [{"text": "In Bash, how do I list all text files in the current directory - (excluding subdirectories) that have been modified in the last month?"}]}], - "inferenceConfig": {"maxTokens": 30}}' - headers: - Content-Length: - - '349' - amz-sdk-request: - - !!binary | - YXR0ZW1wdD0x - method: POST - uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20240620-v1%3A0/converse - response: - body: null - headers: - Connection: - - keep-alive - Content-Length: - - '319' - Content-Type: - - application/json - Date: - - Thu, 05 Dec 2024 09:53:45 GMT - x-amzn-RequestId: - - 776196f3-ea9a-455a-b64b-53d171e5a866 - status: - code: 200 - message: OK -version: 1 diff --git a/tests/integrations/bedrock/cassettes/bedrock_test/test_bedrock_invoke.yaml b/tests/integrations/bedrock/cassettes/bedrock_test/test_bedrock_invoke.yaml deleted file mode 100644 index c41d34fb8b7..00000000000 --- a/tests/integrations/bedrock/cassettes/bedrock_test/test_bedrock_invoke.yaml +++ /dev/null @@ -1,40 +0,0 @@ -interactions: -- request: - body: '{"anthropic_version": "bedrock-2023-05-31", "max_tokens": 30, "temperature": - 0.7, "messages": [{"role": "user", "content": "What is the capital of France?"}]}' - headers: - Accept: - - !!binary | - YXBwbGljYXRpb24vanNvbg== - Content-Length: - - '158' - amz-sdk-request: - - !!binary | - YXR0ZW1wdD0x - method: POST - uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20240620-v1%3A0/invoke - response: - body: - string: '{"id":"msg_bdrk_01BgG5ZL1hCouMqNyv7ZmLHn","type":"message","role":"assistant","model":"claude-3-5-sonnet-20240620","content":[{"type":"text","text":"The - capital of France is Paris."}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":14,"output_tokens":10}}' - headers: - Connection: - - keep-alive - Content-Length: - - '277' - Content-Type: - - application/json - Date: - - Thu, 05 Dec 2024 09:53:45 GMT - X-Amzn-Bedrock-Input-Token-Count: - - '14' - X-Amzn-Bedrock-Invocation-Latency: - - '373' - X-Amzn-Bedrock-Output-Token-Count: - - '10' - x-amzn-RequestId: - - 8264bbb0-26d8-4c90-bc4d-db3de3de6d80 - status: - code: 200 - message: OK -version: 1