From 75609fff10284380b8884d7475dfb4ab7445bce8 Mon Sep 17 00:00:00 2001
From: Jeff Tang
Date: Fri, 12 Jul 2024 11:53:18 -0700
Subject: [PATCH 1/3] new provider together

---
 .env.sample                              |  1 +
 aimodels/client/multi_fm_client.py       |  2 +
 aimodels/providers/__init__.py           |  1 +
 aimodels/providers/together_interface.py | 38 ++++++++++
 examples/multi_fm_client.ipynb           | 91 ++++++++----------------
 5 files changed, 70 insertions(+), 63 deletions(-)
 create mode 100644 aimodels/providers/together_interface.py

diff --git a/.env.sample b/.env.sample
index ebd89f42..cfaf8080 100644
--- a/.env.sample
+++ b/.env.sample
@@ -5,3 +5,4 @@ MISTRAL_API_KEY=""
 OPENAI_API_KEY=""
 OLLAMA_API_URL="http://localhost:11434"
 REPLICATE_API_KEY=""
+TOGETHER_API_KEY=""
\ No newline at end of file
diff --git a/aimodels/client/multi_fm_client.py b/aimodels/client/multi_fm_client.py
index f3fd7521..14e9915c 100644
--- a/aimodels/client/multi_fm_client.py
+++ b/aimodels/client/multi_fm_client.py
@@ -9,6 +9,7 @@
     OllamaInterface,
     OpenAIInterface,
     ReplicateInterface,
+    TogetherInterface,
 )


@@ -42,6 +43,7 @@ def __init__(self):
            "ollama": OllamaInterface,
            "openai": OpenAIInterface,
            "replicate": ReplicateInterface,
+           "together": TogetherInterface,
        }

    def get_provider_interface(self, model):
diff --git a/aimodels/providers/__init__.py b/aimodels/providers/__init__.py
index f32c433a..3cc97943 100644
--- a/aimodels/providers/__init__.py
+++ b/aimodels/providers/__init__.py
@@ -7,3 +7,4 @@
 from .ollama_interface import OllamaInterface
 from .openai_interface import OpenAIInterface
 from .replicate_interface import ReplicateInterface
+from .together_interface import TogetherInterface
\ No newline at end of file
diff --git a/aimodels/providers/together_interface.py b/aimodels/providers/together_interface.py
new file mode 100644
index 00000000..8a590edf
--- /dev/null
+++ b/aimodels/providers/together_interface.py
@@ -0,0 +1,38 @@
+"""The interface to the Groq API."""
+
+import os
+
+from ..framework.provider_interface import ProviderInterface
+
+_TOGETHER_BASE_URL = "https://api.together.xyz/v1"
+
+class TogetherInterface(ProviderInterface):
+    """Implements the ProviderInterface for interacting with Together's APIs."""
+
+    def __init__(self):
+        """Set up the Together client using the API key obtained from the user's environment."""
+        from openai import OpenAI
+
+        self.together_client = OpenAI(
+            api_key=os.getenv("TOGETHER_API_KEY"),
+            base_url=_TOGETHER_BASE_URL,
+        )
+    def chat_completion_create(self, messages=None, model=None, temperature=0):
+        """Request chat completions from the Together API.
+
+        Args:
+        ----
+            model (str): Identifies the specific provider/model to use.
+            messages (list of dict): A list of message objects in chat history.
+            temperature (float): The temperature to use in the completion.
+
+        Returns:
+        -------
+            The API response with the completion result.
+
+        """
+        return self.together_client.chat.completions.create(
+            model=model,
+            messages=messages,
+            temperature=temperature,
+        )
diff --git a/examples/multi_fm_client.ipynb b/examples/multi_fm_client.ipynb
index d7f4f89a..31a6edef 100644
--- a/examples/multi_fm_client.ipynb
+++ b/examples/multi_fm_client.ipynb
@@ -14,7 +14,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "id": "initial_id",
    "metadata": {
    "ExecuteTime": {
@@ -22,18 +22,7 @@
     "start_time": "2024-07-04T15:30:02.051986Z"
    }
   },
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "True"
-      ]
-     },
-     "execution_count": 1,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
+   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../aimodels')\n",
@@ -45,7 +34,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
    "metadata": {},
    "outputs": [],
@@ -54,12 +43,13 @@
    "\n",
    "os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
    "os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
-    "os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens"
+    "os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
+    "os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai/"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
   "id": "4de3a24f",
   "metadata": {
    "ExecuteTime": {
@@ -79,6 +69,21 @@
    "]"
   ]
  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4b3e6c41-070d-4041-9ed9-c8977790fe18",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "together_llama3_8b = \"together:meta-llama/Llama-3-8b-chat-hf\"\n",
+    "#together_llama3_70b = \"together:meta-llama/Llama-3-70b-chat-hf\"\n",
+    "\n",
+    "response = client.chat.completions.create(model=together_llama3_8b, messages=messages)\n",
+    "\n",
+    "print(response.choices[0].message.content)"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -86,27 +91,15 @@
   "metadata": {},
   "outputs": [],
   "source": [
-    "# !pip install fireworks-ai"
+    "#!pip install fireworks-ai"
   ]
  },
  {
   "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
   "id": "9900fdf3-a113-40fd-b42f-0e6d866838be",
   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Arrrr, listen close me hearty! Here be a joke fer ye:\n",
-      "\n",
-      "Why did the pirate quit his job?\n",
-      "\n",
-      "Because he was sick o' all the arrrr-guments! (get it? arguments, but with an \"arrr\" like a pirate says? aye, I thought it be a good one, matey!)\n"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
    "fireworks_llama3_8b = \"fireworks:accounts/fireworks/models/llama-v3-8b-instruct\"\n",
    "#fireworks_llama3_70b = \"fireworks:accounts/fireworks/models/llama-v3-70b-instruct\"\n",
@@ -118,24 +111,10 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": null,
   "id": "c9b2aad6-8603-4227-9566-778f714eb0b5",
   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Arrrr, listen close me hearty! Here be a joke fer ye:\n",
-      "\n",
-      "Why did the pirate quit his job?\n",
-      "\n",
-      "Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
-      "\n",
-      "Yer turn, matey! Got a joke to share?\n"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
    "groq_llama3_8b = \"groq:llama3-8b-8192\"\n",
    "# groq_llama3_70b = \"groq:llama3-70b-8192\"\n",
@@ -147,24 +126,10 @@
  },
  {
   "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": null,
   "id": "6baf88b8-2ecb-4bdf-9263-4af949668d16",
   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Arrrr, listen close me hearty! Here be a joke fer ye:\n",
-      "\n",
-      "Why did the pirate quit his job?\n",
-      "\n",
-      "Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
-      "\n",
-      "Yer turn, matey! Got a joke to share?\n"
-     ]
-    }
-   ],
+   "outputs": [],
   "source": [
    "replicate_llama3_8b = \"replicate:meta/meta-llama-3-8b-instruct\"\n",
    "#replicate_llama3_70b = \"replicate:meta/meta-llama-3-70b-instruct\"\n",

From b848645d8f3d62038e4d570494646aa6a8c029a2 Mon Sep 17 00:00:00 2001
From: Jeff Tang
Date: Fri, 12 Jul 2024 16:09:34 -0700
Subject: [PATCH 2/3] formatting fixes

---
 aimodels/providers/__init__.py           | 2 +-
 aimodels/providers/together_interface.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/aimodels/providers/__init__.py b/aimodels/providers/__init__.py
index 3cc97943..d497dbc1 100644
--- a/aimodels/providers/__init__.py
+++ b/aimodels/providers/__init__.py
@@ -7,4 +7,4 @@
 from .ollama_interface import OllamaInterface
 from .openai_interface import OpenAIInterface
 from .replicate_interface import ReplicateInterface
-from .together_interface import TogetherInterface
\ No newline at end of file
+from .together_interface import TogetherInterface
diff --git a/aimodels/providers/together_interface.py b/aimodels/providers/together_interface.py
index 8a590edf..c005cbf4 100644
--- a/aimodels/providers/together_interface.py
+++ b/aimodels/providers/together_interface.py
@@ -6,6 +6,7 @@

 _TOGETHER_BASE_URL = "https://api.together.xyz/v1"

+
 class TogetherInterface(ProviderInterface):
     """Implements the ProviderInterface for interacting with Together's APIs."""

@@ -17,6 +18,7 @@ def __init__(self):
             api_key=os.getenv("TOGETHER_API_KEY"),
             base_url=_TOGETHER_BASE_URL,
         )
+
     def chat_completion_create(self, messages=None, model=None, temperature=0):
         """Request chat completions from the Together API.

From 93e6bff6effd5a961bf5f99b77d2be52d8e89034 Mon Sep 17 00:00:00 2001
From: Kevin Solorio <103829+ksolo@users.noreply.github.com>
Date: Fri, 12 Jul 2024 18:14:33 -0500
Subject: [PATCH 3/3] Update together_interface.py

---
 aimodels/providers/together_interface.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aimodels/providers/together_interface.py b/aimodels/providers/together_interface.py
index c005cbf4..039eae4c 100644
--- a/aimodels/providers/together_interface.py
+++ b/aimodels/providers/together_interface.py
@@ -1,4 +1,4 @@
-"""The interface to the Groq API."""
+"""The interface to the Together API."""

 import os

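
Note for anyone trying the series locally: the sketch below shows how the new "together" provider is exercised once all three patches are applied. The model string, the TOGETHER_API_KEY placeholder, the client.chat.completions.create(...) call, and the response handling are taken from the notebook cell added in PATCH 1/3; the MultiFMClient import/constructor and the pirate-themed messages are assumptions based on the surrounding example notebook, not something this series changes.

import os

from aimodels.client import MultiFMClient

# Placeholder key, as in the notebook; get a real one at https://api.together.ai/
os.environ["TOGETHER_API_KEY"] = "xxx"

client = MultiFMClient()

# Illustrative chat history; any OpenAI-style messages list works here.
messages = [
    {"role": "system", "content": "Respond in Pirate English."},
    {"role": "user", "content": "Tell me a joke."},
]

# The "together:" prefix selects the new TogetherInterface, which forwards the
# request to https://api.together.xyz/v1 through the OpenAI client.
together_llama3_8b = "together:meta-llama/Llama-3-8b-chat-hf"
response = client.chat.completions.create(model=together_llama3_8b, messages=messages)
print(response.choices[0].message.content)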