From e6f5b01eb801aad789c63535d2b927b28e1d524f Mon Sep 17 00:00:00 2001 From: watany <76135106+watany-dev@users.noreply.github.com> Date: Fri, 15 Nov 2024 06:13:44 +0000 Subject: [PATCH] fix(bedrock): Correct Cross-Region Inference Identifiers --- libs/aws/langchain_aws/llms/bedrock.py | 2 +- .../unit_tests/chat_models/test_bedrock.py | 18 ++++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/libs/aws/langchain_aws/llms/bedrock.py b/libs/aws/langchain_aws/llms/bedrock.py index 1348e7f6..beac164f 100644 --- a/libs/aws/langchain_aws/llms/bedrock.py +++ b/libs/aws/langchain_aws/llms/bedrock.py @@ -728,7 +728,7 @@ def _get_provider(self) -> str: parts = self.model_id.split(".", maxsplit=2) return ( parts[1] - if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "ap", "sa"}) + if (len(parts) > 1 and parts[0].lower() in {"eu", "us", "apac", "sa"}) else parts[0] ) diff --git a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py index 2a869f36..bf445707 100644 --- a/libs/aws/tests/unit_tests/chat_models/test_bedrock.py +++ b/libs/aws/tests/unit_tests/chat_models/test_bedrock.py @@ -429,30 +429,40 @@ def test_standard_tracing_params() -> None: @pytest.mark.parametrize( - "model_id, provider, expected_provider, expectation", + "model_id, provider, expected_provider, expectation, region_name", [ ( "eu.anthropic.claude-3-haiku-20240307-v1:0", None, "anthropic", nullcontext(), + "us-west-2", ), - ("meta.llama3-1-405b-instruct-v1:0", None, "meta", nullcontext()), + ( + "apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + None, + "anthropic", + nullcontext(), + "ap-northeast-1", + ), + ("meta.llama3-1-405b-instruct-v1:0", None, "meta", nullcontext(), "us-west-2"), ( "arn:aws:bedrock:us-east-1::custom-model/cohere.command-r-v1:0/MyCustomModel2", "cohere", "cohere", nullcontext(), + "us-west-2", ), ( "arn:aws:bedrock:us-east-1::custom-model/cohere.command-r-v1:0/MyCustomModel2", None, "cohere", pytest.raises(ValueError), + "us-west-2", ), ], ) -def test__get_provider(model_id, provider, expected_provider, expectation) -> None: - llm = ChatBedrock(model_id=model_id, provider=provider, region_name="us-west-2") +def test__get_provider(model_id, provider, expected_provider, expectation, region_name) -> None: + llm = ChatBedrock(model_id=model_id, provider=provider, region_name=region_name) with expectation: assert llm._get_provider() == expected_provider