From a0e80ca95c10232f67f760155febb0d49586deef Mon Sep 17 00:00:00 2001 From: Jeff Tang Date: Wed, 17 Jul 2024 16:02:58 -0700 Subject: [PATCH] code formatting fixes --- aimodels/client/multi_fm_client.py | 2 +- aimodels/providers/aws_bedrock_interface.py | 50 +++++++++++++-------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/aimodels/client/multi_fm_client.py b/aimodels/client/multi_fm_client.py index 820aab7d..939c9086 100644 --- a/aimodels/client/multi_fm_client.py +++ b/aimodels/client/multi_fm_client.py @@ -47,7 +47,7 @@ def __init__(self): "replicate": ReplicateInterface, "together": TogetherInterface, "octo": OctoInterface, - "aws": AWSBedrockInterface + "aws": AWSBedrockInterface, } def get_provider_interface(self, model): diff --git a/aimodels/providers/aws_bedrock_interface.py b/aimodels/providers/aws_bedrock_interface.py index 6fe01352..24fa5ad1 100644 --- a/aimodels/providers/aws_bedrock_interface.py +++ b/aimodels/providers/aws_bedrock_interface.py @@ -7,14 +7,15 @@ from ..framework.provider_interface import ProviderInterface + def convert_messages_to_llama3_prompt(messages): """ Convert a list of messages to a prompt in Llama 3 instruction format. - + Args: - messages (list of dict): List of messages where each message is a dictionary + messages (list of dict): List of messages where each message is a dictionary with 'role' ('system', 'user', 'assistant') and 'content'. - + Returns: str: Formatted prompt for Llama 3. """ @@ -23,8 +24,9 @@ def convert_messages_to_llama3_prompt(messages): prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>{message['content']}<|eot_id|>\n" prompt += "<|start_header_id|>assistant<|end_header_id|>" - - return prompt + + return prompt + class RecursiveNamespace: """ @@ -32,12 +34,16 @@ class RecursiveNamespace: This class is used to simulate the OpenAI chat.completions.create's return type, so response.choices[0].message.content works consistenly for AWS Bedrock's LLM return of a string. """ + def __init__(self, data): for key, value in data.items(): if isinstance(value, dict): value = RecursiveNamespace(value) elif isinstance(value, list): - value = [RecursiveNamespace(item) if isinstance(item, dict) else item for item in value] + value = [ + RecursiveNamespace(item) if isinstance(item, dict) else item + for item in value + ] setattr(self, key, value) @classmethod @@ -50,10 +56,14 @@ def to_dict(self): if isinstance(value, RecursiveNamespace): value = value.to_dict() elif isinstance(value, list): - value = [item.to_dict() if isinstance(item, RecursiveNamespace) else item for item in value] + value = [ + item.to_dict() if isinstance(item, RecursiveNamespace) else item + for item in value + ] result[key] = value return result + class AWSBedrockInterface(ProviderInterface): """Implements the ProviderInterface for interacting with AWS Bedrock's APIs.""" @@ -64,7 +74,7 @@ def __init__(self): region_name="us-west-2", aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), - ) + ) def chat_completion_create(self, messages=None, model=None, temperature=0): """Request chat completions from the AWS Bedrock API. @@ -80,15 +90,19 @@ def chat_completion_create(self, messages=None, model=None, temperature=0): The API response with the completion result. """ - body = json.dumps({ - "prompt": convert_messages_to_llama3_prompt(messages), - "temperature": temperature - }) - accept = 'application/json' - content_type = 'application/json' - response = self.aws_bedrock_client.invoke_model(body=body, modelId=model, accept=accept, contentType=content_type) - response_body = json.loads(response.get('body').read()) - generation = response_body.get('generation') + body = json.dumps( + { + "prompt": convert_messages_to_llama3_prompt(messages), + "temperature": temperature, + } + ) + accept = "application/json" + content_type = "application/json" + response = self.aws_bedrock_client.invoke_model( + body=body, modelId=model, accept=accept, contentType=content_type + ) + response_body = json.loads(response.get("body").read()) + generation = response_body.get("generation") response_data = { "choices": [ @@ -96,6 +110,6 @@ def chat_completion_create(self, messages=None, model=None, temperature=0): "message": {"content": generation}, } ], - } + } return RecursiveNamespace.from_dict(response_data)