Skip to content

Commit

Permalink
Merge pull request #13 from andrewyng/octo
Browse files Browse the repository at this point in the history
added Octo.ai provider
  • Loading branch information
ksolo authored Jul 19, 2024
2 parents 5521ac6 + a0e80ca commit d48b5da
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 4 deletions.
5 changes: 4 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
TOGETHER_API_KEY=""
TOGETHER_API_KEY=""
OCTO_API_KEY=""
AWS_ACCESS_KEY_ID=""
AWS_SECRET_ACCESS_KEY=""
4 changes: 4 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
OpenAIInterface,
ReplicateInterface,
TogetherInterface,
OctoInterface,
AWSBedrockInterface,
)


Expand Down Expand Up @@ -44,6 +46,8 @@ def __init__(self):
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
"together": TogetherInterface,
"octo": OctoInterface,
"aws": AWSBedrockInterface,
}

def get_provider_interface(self, model):
Expand Down
2 changes: 2 additions & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .octo_interface import OctoInterface
from .aws_bedrock_interface import AWSBedrockInterface
115 changes: 115 additions & 0 deletions aimodels/providers/aws_bedrock_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""The interface to the Together API."""

import os
from urllib.request import urlopen
import boto3
import json

from ..framework.provider_interface import ProviderInterface


def convert_messages_to_llama3_prompt(messages):
"""
Convert a list of messages to a prompt in Llama 3 instruction format.
Args:
messages (list of dict): List of messages where each message is a dictionary
with 'role' ('system', 'user', 'assistant') and 'content'.
Returns:
str: Formatted prompt for Llama 3.
"""
prompt = "<|begin_of_text|>"
for message in messages:
prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>{message['content']}<|eot_id|>\n"

prompt += "<|start_header_id|>assistant<|end_header_id|>"

return prompt


class RecursiveNamespace:
"""
Convert dictionaries to objects with attribute access, including nested dictionaries.
This class is used to simulate the OpenAI chat.completions.create's return type, so
response.choices[0].message.content works consistenly for AWS Bedrock's LLM return of a string.
"""

def __init__(self, data):
for key, value in data.items():
if isinstance(value, dict):
value = RecursiveNamespace(value)
elif isinstance(value, list):
value = [
RecursiveNamespace(item) if isinstance(item, dict) else item
for item in value
]
setattr(self, key, value)

@classmethod
def from_dict(cls, data):
return cls(data)

def to_dict(self):
result = {}
for key, value in self.__dict__.items():
if isinstance(value, RecursiveNamespace):
value = value.to_dict()
elif isinstance(value, list):
value = [
item.to_dict() if isinstance(item, RecursiveNamespace) else item
for item in value
]
result[key] = value
return result


class AWSBedrockInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with AWS Bedrock's APIs."""

def __init__(self):
"""Set up the AWS Bedrock client using the AWS access key id and secret access key obtained from the user's environment."""
self.aws_bedrock_client = boto3.client(
service_name="bedrock-runtime",
region_name="us-west-2",
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the AWS Bedrock API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
body = json.dumps(
{
"prompt": convert_messages_to_llama3_prompt(messages),
"temperature": temperature,
}
)
accept = "application/json"
content_type = "application/json"
response = self.aws_bedrock_client.invoke_model(
body=body, modelId=model, accept=accept, contentType=content_type
)
response_body = json.loads(response.get("body").read())
generation = response_body.get("generation")

response_data = {
"choices": [
{
"message": {"content": generation},
}
],
}

return RecursiveNamespace.from_dict(response_data)
40 changes: 40 additions & 0 deletions aimodels/providers/octo_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Octo API."""

import os

from ..framework.provider_interface import ProviderInterface

_OCTO_BASE_URL = "https://text.octoai.run/v1"


class OctoInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Octo's APIs."""

def __init__(self):
"""Set up the Octo client using the API key obtained from the user's environment."""
from openai import OpenAI

self.octo_client = OpenAI(
api_key=os.getenv("OCTO_API_KEY"),
base_url=_OCTO_BASE_URL,
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Together API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.octo_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
47 changes: 45 additions & 2 deletions examples/multi_fm_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"sys.path.append('../../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
Expand All @@ -44,7 +44,10 @@
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai/"
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
"os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
"os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
]
},
{
Expand All @@ -69,6 +72,46 @@
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
"metadata": {},
"outputs": [],
"source": [
"#!pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9893c7e4-799a-42c9-84de-f9e643044462",
"metadata": {},
"outputs": [],
"source": [
"aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
"#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
"\n",
"response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5388efc4-3fd2-4dc6-ab58-7b179ce07943",
"metadata": {},
"outputs": [],
"source": [
"octo_llama3_8b = \"octo:meta-llama-3-8b-instruct\"\n",
"#octo_llama3_70b = \"octo:meta-llama-3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=octo_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ anthropic = "^0.30.1"
notebook = "^7.2.1"
ollama = "^0.2.1"
mistralai = "^0.4.2"
boto3 = "^1.34.144"

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit d48b5da

Please sign in to comment.