
Commit

added providers fireworks and replicate with test code and outputs in multi_fm_client.ipynb
jeffxtang committed Jul 6, 2024
1 parent 43f8a8a commit 221b04b
Showing 5 changed files with 202 additions and 52 deletions.
4 changes: 4 additions & 0 deletions aimodels/client/multi_fm_client.py
@@ -7,6 +7,8 @@
MistralInterface,
OllamaInterface,
OpenAIInterface,
FireworksInterface,
ReplicateInterface,
)


@@ -38,6 +40,8 @@ def __init__(self):
"mistral": MistralInterface,
"ollama": OllamaInterface,
"openai": OpenAIInterface,
"fireworks": FireworksInterface,
"replicate": ReplicateInterface,
}

def get_provider_interface(self, model):
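For context, a minimal sketch of how a "provider:model" string could be resolved against the interface map registered above. The body of get_provider_interface is collapsed in this diff, so the parsing logic and the attribute name self.all_interfaces below are assumptions, not the committed code:

    # Hypothetical sketch only -- the real get_provider_interface body is
    # collapsed in this diff; the name self.all_interfaces is assumed.
    def get_provider_interface(self, model):
        if ":" not in model:
            raise ValueError(f"Expected 'provider:model-name', got {model!r}")
        # Split once, since model IDs such as
        # "fireworks:accounts/fireworks/models/llama-v3-8b-instruct"
        # contain slashes after the provider prefix.
        provider_key, model_name = model.split(":", 1)
        interface_class = self.all_interfaces.get(provider_key)
        if interface_class is None:
            raise ValueError(f"Unknown provider {provider_key!r}")
        return interface_class(), model_name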
2 changes: 2 additions & 0 deletions aimodels/providers/__init__.py
@@ -5,3 +5,5 @@
from .mistral_interface import MistralInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .fireworks_interface import FireworksInterface
from .replicate_interface import ReplicateInterface
34 changes: 34 additions & 0 deletions aimodels/providers/fireworks_interface.py
@@ -0,0 +1,34 @@
"""The interface to the Fireworks API."""

import os

from ..framework.provider_interface import ProviderInterface

class FireworksInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Fireworks's APIs."""

def __init__(self):
"""Set up the Fireworks client using the API key obtained from the user's environment."""
from fireworks.client import Fireworks

self.fireworks_client = Fireworks(api_key=os.getenv("FIREWORKS_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Fireworks API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.fireworks_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
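A quick smoke test of the interface above, assuming fireworks-ai is installed (pip install fireworks-ai) and FIREWORKS_API_KEY is set; the model ID is taken from the notebook later in this commit:

    # Sketch of direct usage of the new interface; assumes a valid
    # FIREWORKS_API_KEY in the environment.
    from aimodels.providers import FireworksInterface

    fireworks = FireworksInterface()
    response = fireworks.chat_completion_create(
        messages=[{"role": "user", "content": "Tell me a joke"}],
        model="accounts/fireworks/models/llama-v3-8b-instruct",
        temperature=0.7,
    )
    print(response.choices[0].message.content)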
34 changes: 34 additions & 0 deletions aimodels/providers/replicate_interface.py
@@ -0,0 +1,34 @@
"""The interface to the Replicate API."""

import os

from ..framework.provider_interface import ProviderInterface

class ReplicateInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Replicate's APIs."""

def __init__(self):
"""Set up the Replicate client using the API key obtained from the user's environment."""
from openai import OpenAI

self.replicate_client = OpenAI(api_key=os.getenv("REPLICATE_API_KEY"), base_url="https://openai-proxy.replicate.com/v1")

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Replicate API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.replicate_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
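Note the design choice here: instead of the replicate Python package, the interface reuses the openai client pointed at Replicate's OpenAI-compatible proxy, so the call shape stays identical to every other provider. The equivalent raw call, with the same base URL and key handling as the code above:

    # Equivalent call without the wrapper class, using the same proxy
    # endpoint; assumes REPLICATE_API_KEY is set and openai is installed.
    import os

    from openai import OpenAI

    client = OpenAI(
        api_key=os.getenv("REPLICATE_API_KEY"),
        base_url="https://openai-proxy.replicate.com/v1",
    )
    response = client.chat.completions.create(
        model="meta/meta-llama-3-8b-instruct",  # model ID from the notebook
        messages=[{"role": "user", "content": "Tell me a joke"}],
        temperature=0,
    )
    print(response.choices[0].message.content)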
180 changes: 128 additions & 52 deletions examples/multi_fm_client.ipynb
@@ -1,35 +1,27 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "60c7fb39",
"cell_type": "raw",
"id": "16c03c35-b679-43d4-971b-4ce19e619d51",
"metadata": {},
"source": [
"# MultiFMClient\n",
"\n",
"MultiFMClient provides a uniform interface for interacting with LLMs from various providers. It adapts the official python libraries from providers such as Mistral, OpenAI, Meta, Anthropic, etc. to conform to the OpenAI chat completion interface.\n",
"MultiFMClient provides a uniform interface for interacting with LLMs from various providers. It adapts the official python libraries from providers such as Mistral, OpenAI, Groq, Anthropic, Fireworks, Replicate, etc. to conform to the OpenAI chat completion interface.\n",
"\n",
"Below are some examples of how to use MultiFMClient to interact with different LLMs."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-07-04T15:30:02.064319Z",
"start_time": "2024-07-04T15:30:02.051986Z"
}
},
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
"load_dotenv(find_dotenv())"
],
"outputs": [
{
"data": {
@@ -42,17 +34,40 @@
"output_type": "execute_result"
}
],
"execution_count": 1
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
"load_dotenv(find_dotenv())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
"metadata": {},
"outputs": [],
"source": [
"import os \n",
"\n",
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4de3a24f",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:12.914321Z",
"start_time": "2024-07-04T15:31:12.796445Z"
}
},
"outputs": [],
"source": [
"from aimodels.client import MultiFMClient\n",
"\n",
Expand All @@ -62,59 +77,128 @@
" {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
"]"
],
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "668a6cfa-9011-480a-ae1b-6dbd6a51e716",
"metadata": {},
"outputs": [],
"execution_count": 3
"source": [
"# !pip install fireworks-ai"
]
},
{
"cell_type": "code",
"id": "adebd2f0b578a909",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:25.060689Z",
"start_time": "2024-07-04T15:31:16.131205Z"
"execution_count": 13,
"id": "9900fdf3-a113-40fd-b42f-0e6d866838be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he was sick o' all the arrrr-guments! (get it? arguments, but with an \"arrr\" like a pirate says? aye, I thought it be a good one, matey!)\n"
]
}
},
],
"source": [
"anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
"fireworks_llama3_8b = \"fireworks:accounts/fireworks/models/llama-v3-8b-instruct\"\n",
"#fireworks_llama3_70b = \"fireworks:accounts/fireworks/models/llama-v3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
"response = client.chat.completions.create(model=fireworks_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
],
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c9b2aad6-8603-4227-9566-778f714eb0b5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
"Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
"\n",
"Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
"Yer turn, matey! Got a joke to share?\n"
]
}
],
"execution_count": 4
"source": [
"groq_llama3_8b = \"groq:llama3-8b-8192\"\n",
"# groq_llama3_70b = \"groq:llama3-70b-8192\"\n",
"\n",
"response = client.chat.completions.create(model=groq_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6819ac17",
"execution_count": 12,
"id": "6baf88b8-2ecb-4bdf-9263-4af949668d16",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, here be a joke fer ye!\n",
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Why did the pirate take a parrot on his ship?\n",
"Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
"\n",
"Because it were a hootin' good bird to have around, savvy? Aye, and it kept 'im company while he were swabbin' the decks! Arrrgh, I hope that made ye laugh, matey!\n"
"Yer turn, matey! Got a joke to share?\n"
]
}
],
"source": [
"replicate_llama3_8b = \"replicate:meta/meta-llama-3-8b-instruct\"\n",
"#replicate_llama3_70b = \"replicate:meta/meta-llama-3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=replicate_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adebd2f0b578a909",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:25.060689Z",
"start_time": "2024-07-04T15:31:16.131205Z"
}
},
"outputs": [],
"source": [
"anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
"\n",
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6819ac17",
"metadata": {},
"outputs": [],
"source": [
"ollama_llama3 = \"ollama:llama3\"\n",
"\n",
@@ -124,44 +208,36 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a94961b2bddedbb",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:39.472675Z",
"start_time": "2024-07-04T15:31:38.283368Z"
}
},
"cell_type": "code",
"outputs": [],
"source": [
"mistral_7b = \"mistral:open-mistral-7b\"\n",
"\n",
"response = client.chat.completions.create(model=mistral_7b, messages=messages, temperature=0.2)\n",
"\n",
"print(response.choices[0].message.content)"
],
"id": "4a94961b2bddedbb",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arr matey, I've got a jest fer ye, if ye be ready for a laugh! Why did the pirate bring a clock to the island? Because he wanted to catch the time! Aye, that be a good one, I be thinkin'. Arrr!\n"
]
}
],
"execution_count": 5
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
"id": "611210a4dc92845f"
"id": "611210a4dc92845f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -175,7 +251,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.10.14"
}
},
"nbformat": 4,
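Putting it together, the end-to-end flow the notebook demonstrates with the two new providers. API keys are set as in the cells above; the MultiFMClient() construction cell is collapsed in this diff, so that line is inferred:

    # End-to-end sketch mirroring the notebook cells above; assumes
    # FIREWORKS_API_KEY and REPLICATE_API_KEY are set in the environment.
    from aimodels.client import MultiFMClient

    client = MultiFMClient()  # inferred; the construction cell is collapsed
    messages = [
        {"role": "system", "content": "Respond in Pirate English."},
        {"role": "user", "content": "Tell me a joke"},
    ]

    # The "provider:" prefix selects the interface registered in
    # multi_fm_client.py.
    for model in (
        "fireworks:accounts/fireworks/models/llama-v3-8b-instruct",
        "replicate:meta/meta-llama-3-8b-instruct",
    ):
        response = client.chat.completions.create(model=model, messages=messages)
        print(response.choices[0].message.content)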
