Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

new provider together #10

Merged
merged 3 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
TOGETHER_API_KEY=""
2 changes: 2 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
OllamaInterface,
OpenAIInterface,
ReplicateInterface,
TogetherInterface,
)


Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self):
"ollama": OllamaInterface,
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
"together": TogetherInterface,
}

def get_provider_interface(self, model):
Expand Down
1 change: 1 addition & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
40 changes: 40 additions & 0 deletions aimodels/providers/together_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Together API."""

import os

from ..framework.provider_interface import ProviderInterface

_TOGETHER_BASE_URL = "https://api.together.xyz/v1"


class TogetherInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Together's APIs."""

def __init__(self):
"""Set up the Together client using the API key obtained from the user's environment."""
from openai import OpenAI

self.together_client = OpenAI(
api_key=os.getenv("TOGETHER_API_KEY"),
base_url=_TOGETHER_BASE_URL,
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Together API.

Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.

Returns:
-------
The API response with the completion result.

"""
return self.together_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
91 changes: 28 additions & 63 deletions examples/multi_fm_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,15 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:30:02.064319Z",
"start_time": "2024-07-04T15:30:02.051986Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
Expand All @@ -45,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
"metadata": {},
"outputs": [],
Expand All @@ -54,12 +43,13 @@
"\n",
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens"
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai/"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "4de3a24f",
"metadata": {
"ExecuteTime": {
Expand All @@ -79,34 +69,37 @@
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b3e6c41-070d-4041-9ed9-c8977790fe18",
"metadata": {},
"outputs": [],
"source": [
"together_llama3_8b = \"together:meta-llama/Llama-3-8b-chat-hf\"\n",
"#together_llama3_70b = \"together:meta-llama/Llama-3-70b-chat-hf\"\n",
"\n",
"response = client.chat.completions.create(model=together_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "668a6cfa-9011-480a-ae1b-6dbd6a51e716",
"metadata": {},
"outputs": [],
"source": [
"# !pip install fireworks-ai"
"#!pip install fireworks-ai"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "9900fdf3-a113-40fd-b42f-0e6d866838be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he was sick o' all the arrrr-guments! (get it? arguments, but with an \"arrr\" like a pirate says? aye, I thought it be a good one, matey!)\n"
]
}
],
"outputs": [],
"source": [
"fireworks_llama3_8b = \"fireworks:accounts/fireworks/models/llama-v3-8b-instruct\"\n",
"#fireworks_llama3_70b = \"fireworks:accounts/fireworks/models/llama-v3-70b-instruct\"\n",
Expand All @@ -118,24 +111,10 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "c9b2aad6-8603-4227-9566-778f714eb0b5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
"\n",
"Yer turn, matey! Got a joke to share?\n"
]
}
],
"outputs": [],
"source": [
"groq_llama3_8b = \"groq:llama3-8b-8192\"\n",
"# groq_llama3_70b = \"groq:llama3-70b-8192\"\n",
Expand All @@ -147,24 +126,10 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "6baf88b8-2ecb-4bdf-9263-4af949668d16",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
"\n",
"Yer turn, matey! Got a joke to share?\n"
]
}
],
"outputs": [],
"source": [
"replicate_llama3_8b = \"replicate:meta/meta-llama-3-8b-instruct\"\n",
"#replicate_llama3_70b = \"replicate:meta/meta-llama-3-70b-instruct\"\n",
Expand Down
Loading