diff --git a/.env.sample b/.env.sample
index 906f1b0d..ebd89f42 100644
--- a/.env.sample
+++ b/.env.sample
@@ -1,5 +1,7 @@
 ANTHROPIC_API_KEY=""
+FIREWORKS_API_KEY=""
 GROQ_API_KEY=""
 MISTRAL_API_KEY=""
 OPENAI_API_KEY=""
 OLLAMA_API_URL="http://localhost:11434"
+REPLICATE_API_KEY=""
diff --git a/aimodels/client/multi_fm_client.py b/aimodels/client/multi_fm_client.py
index f036e868..f3fd7521 100644
--- a/aimodels/client/multi_fm_client.py
+++ b/aimodels/client/multi_fm_client.py
@@ -3,10 +3,12 @@
 from .chat import Chat
 
 from ..providers import (
     AnthropicInterface,
+    FireworksInterface,
     GroqInterface,
     MistralInterface,
     OllamaInterface,
     OpenAIInterface,
+    ReplicateInterface,
 )
 
@@ -34,10 +36,12 @@ def __init__(self):
         self.all_interfaces = {}
         self.all_factories = {
             "anthropic": AnthropicInterface,
+            "fireworks": FireworksInterface,
             "groq": GroqInterface,
             "mistral": MistralInterface,
             "ollama": OllamaInterface,
             "openai": OpenAIInterface,
+            "replicate": ReplicateInterface,
         }
 
     def get_provider_interface(self, model):
diff --git a/aimodels/providers/__init__.py b/aimodels/providers/__init__.py
index 08aff090..f32c433a 100644
--- a/aimodels/providers/__init__.py
+++ b/aimodels/providers/__init__.py
@@ -1,7 +1,9 @@
 """Provides the individual provider interfaces for each FM provider."""
 
 from .anthropic_interface import AnthropicInterface
+from .fireworks_interface import FireworksInterface
 from .groq_interface import GroqInterface
 from .mistral_interface import MistralInterface
 from .ollama_interface import OllamaInterface
 from .openai_interface import OpenAIInterface
+from .replicate_interface import ReplicateInterface
diff --git a/aimodels/providers/fireworks_interface.py b/aimodels/providers/fireworks_interface.py
new file mode 100644
index 00000000..777a45cf
--- /dev/null
+++ b/aimodels/providers/fireworks_interface.py
@@ -0,0 +1,45 @@
+"""The interface to the Fireworks API."""
+
+import os
+
+from ..framework.provider_interface import ProviderInterface
+
+
+class FireworksInterface(ProviderInterface):
+    """Implements the ProviderInterface for interacting with Fireworks' APIs."""
+
+    def __init__(self):
+        """Set up the Fireworks client using the API key obtained from the user's environment."""
+        from fireworks.client import Fireworks
+
+        self.fireworks_client = Fireworks(api_key=os.getenv("FIREWORKS_API_KEY"))
+
+    def chat_completion_create(self, messages=None, model=None, temperature=0):
+        """Request chat completions from the Fireworks API.
+
+        Args:
+        ----
+            model (str): Identifies the specific provider/model to use.
+            messages (list of dict): A list of message objects in chat history.
+            temperature (float): The temperature to use in the completion.
+
+        Returns:
+        -------
+            The API response with the completion result.
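+
+        Example:
+        -------
+            A representative call (the model name here is illustrative; any
+            chat model id accepted by the Fireworks API can be passed in):
+
+            FireworksInterface().chat_completion_create(
+                messages=[{"role": "user", "content": "Tell me a joke"}],
+                model="accounts/fireworks/models/llama-v3-8b-instruct",
+            )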
+ + """ + return self.fireworks_client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + ) diff --git a/aimodels/providers/replicate_interface.py b/aimodels/providers/replicate_interface.py new file mode 100644 index 00000000..99307d17 --- /dev/null +++ b/aimodels/providers/replicate_interface.py @@ -0,0 +1,40 @@ +"""The interface to the Replicate API.""" + +import os + +from ..framework.provider_interface import ProviderInterface + +_REPLICATE_BASE_URL = "https://openai-proxy.replicate.com/v1" + + +class ReplicateInterface(ProviderInterface): + """Implements the ProviderInterface for interacting with Replicate's APIs.""" + + def __init__(self): + """Set up the Replicate client using the API key obtained from the user's environment.""" + from openai import OpenAI + + self.replicate_client = OpenAI( + api_key=os.getenv("REPLICATE_API_KEY"), + base_url=_REPLICATE_BASE_URL, + ) + + def chat_completion_create(self, messages=None, model=None, temperature=0): + """Request chat completions from the Replicate API. + + Args: + ---- + model (str): Identifies the specific provider/model to use. + messages (list of dict): A list of message objects in chat history. + temperature (float): The temperature to use in the completion. + + Returns: + ------- + The API response with the completion result. + + """ + return self.replicate_client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + ) diff --git a/examples/multi_fm_client.ipynb b/examples/multi_fm_client.ipynb index 08b3d31d..d7f4f89a 100644 --- a/examples/multi_fm_client.ipynb +++ b/examples/multi_fm_client.ipynb @@ -1,35 +1,27 @@ { "cells": [ { - "cell_type": "markdown", - "id": "60c7fb39", + "cell_type": "raw", + "id": "16c03c35-b679-43d4-971b-4ce19e619d51", "metadata": {}, "source": [ "# MultiFMClient\n", "\n", - "MultiFMClient provides a uniform interface for interacting with LLMs from various providers. It adapts the official python libraries from providers such as Mistral, OpenAI, Meta, Anthropic, etc. to conform to the OpenAI chat completion interface.\n", + "MultiFMClient provides a uniform interface for interacting with LLMs from various providers. It adapts the official python libraries from providers such as Mistral, OpenAI, Groq, Anthropic, Fireworks, Replicate, etc. to conform to the OpenAI chat completion interface.\n", "\n", "Below are some examples of how to use MultiFMClient to interact with different LLMs." 
] }, { "cell_type": "code", + "execution_count": 1, "id": "initial_id", "metadata": { - "collapsed": true, "ExecuteTime": { "end_time": "2024-07-04T15:30:02.064319Z", "start_time": "2024-07-04T15:30:02.051986Z" } }, - "source": [ - "import sys\n", - "sys.path.append('../aimodels')\n", - "\n", - "from dotenv import load_dotenv, find_dotenv\n", - "\n", - "load_dotenv(find_dotenv())" - ], "outputs": [ { "data": { @@ -42,10 +34,32 @@ "output_type": "execute_result" } ], - "execution_count": 1 + "source": [ + "import sys\n", + "sys.path.append('../aimodels')\n", + "\n", + "from dotenv import load_dotenv, find_dotenv\n", + "\n", + "load_dotenv(find_dotenv())" + ] }, { "cell_type": "code", + "execution_count": 4, + "id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4", + "metadata": {}, + "outputs": [], + "source": [ + "import os \n", + "\n", + "os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n", + "os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n", + "os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "id": "4de3a24f", "metadata": { "ExecuteTime": { @@ -53,6 +67,7 @@ "start_time": "2024-07-04T15:31:12.796445Z" } }, + "outputs": [], "source": [ "from aimodels.client import MultiFMClient\n", "\n", @@ -62,59 +77,128 @@ " {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n", " {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n", "]" - ], + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "668a6cfa-9011-480a-ae1b-6dbd6a51e716", + "metadata": {}, "outputs": [], - "execution_count": 3 + "source": [ + "# !pip install fireworks-ai" + ] }, { "cell_type": "code", - "id": "adebd2f0b578a909", - "metadata": { - "ExecuteTime": { - "end_time": "2024-07-04T15:31:25.060689Z", - "start_time": "2024-07-04T15:31:16.131205Z" + "execution_count": 13, + "id": "9900fdf3-a113-40fd-b42f-0e6d866838be", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Arrrr, listen close me hearty! Here be a joke fer ye:\n", + "\n", + "Why did the pirate quit his job?\n", + "\n", + "Because he was sick o' all the arrrr-guments! (get it? arguments, but with an \"arrr\" like a pirate says? aye, I thought it be a good one, matey!)\n" + ] } - }, + ], "source": [ - "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n", + "fireworks_llama3_8b = \"fireworks:accounts/fireworks/models/llama-v3-8b-instruct\"\n", + "#fireworks_llama3_70b = \"fireworks:accounts/fireworks/models/llama-v3-70b-instruct\"\n", "\n", - "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", + "response = client.chat.completions.create(model=fireworks_llama3_8b, messages=messages)\n", "\n", "print(response.choices[0].message.content)" - ], + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c9b2aad6-8603-4227-9566-778f714eb0b5", + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Arrr, me bucko, 'ere be a jolly jest fer ye!\n", + "Arrrr, listen close me hearty! Here be a joke fer ye:\n", + "\n", + "Why did the pirate quit his job?\n", "\n", - "What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n", + "Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! 
ahhahahah!)\n", "\n", - "Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n" + "Yer turn, matey! Got a joke to share?\n" ] } ], - "execution_count": 4 + "source": [ + "groq_llama3_8b = \"groq:llama3-8b-8192\"\n", + "# groq_llama3_70b = \"groq:llama3-70b-8192\"\n", + "\n", + "response = client.chat.completions.create(model=groq_llama3_8b, messages=messages)\n", + "\n", + "print(response.choices[0].message.content)" + ] }, { "cell_type": "code", - "execution_count": 4, - "id": "6819ac17", + "execution_count": 12, + "id": "6baf88b8-2ecb-4bdf-9263-4af949668d16", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Arrrr, here be a joke fer ye!\n", + "Arrrr, listen close me hearty! Here be a joke fer ye:\n", + "\n", + "Why did the pirate quit his job?\n", "\n", - "Why did the pirate take a parrot on his ship?\n", + "Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n", "\n", - "Because it were a hootin' good bird to have around, savvy? Aye, and it kept 'im company while he were swabbin' the decks! Arrrgh, I hope that made ye laugh, matey!\n" + "Yer turn, matey! Got a joke to share?\n" ] } ], + "source": [ + "replicate_llama3_8b = \"replicate:meta/meta-llama-3-8b-instruct\"\n", + "#replicate_llama3_70b = \"replicate:meta/meta-llama-3-70b-instruct\"\n", + "\n", + "response = client.chat.completions.create(model=replicate_llama3_8b, messages=messages)\n", + "\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adebd2f0b578a909", + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-04T15:31:25.060689Z", + "start_time": "2024-07-04T15:31:16.131205Z" + } + }, + "outputs": [], + "source": [ + "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n", + "\n", + "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", + "\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6819ac17", + "metadata": {}, + "outputs": [], "source": [ "ollama_llama3 = \"ollama:llama3\"\n", "\n", @@ -124,44 +208,36 @@ ] }, { + "cell_type": "code", + "execution_count": null, + "id": "4a94961b2bddedbb", "metadata": { "ExecuteTime": { "end_time": "2024-07-04T15:31:39.472675Z", "start_time": "2024-07-04T15:31:38.283368Z" } }, - "cell_type": "code", + "outputs": [], "source": [ "mistral_7b = \"mistral:open-mistral-7b\"\n", "\n", "response = client.chat.completions.create(model=mistral_7b, messages=messages, temperature=0.2)\n", "\n", "print(response.choices[0].message.content)" - ], - "id": "4a94961b2bddedbb", - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Arr matey, I've got a jest fer ye, if ye be ready for a laugh! Why did the pirate bring a clock to the island? Because he wanted to catch the time! Aye, that be a good one, I be thinkin'. 
Arrr!\n" - ] - } - ], - "execution_count": 5 + ] }, { - "metadata": {}, "cell_type": "code", - "outputs": [], "execution_count": null, - "source": "", - "id": "611210a4dc92845f" + "id": "611210a4dc92845f", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -175,7 +251,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.10.14" } }, "nbformat": 4,