Skip to content

Commit

Permalink
Merge pull request #7 from andrewyng/fireworks_replicate
Browse files Browse the repository at this point in the history
added providers fireworks and replicate with test code and outputs
  • Loading branch information
ksolo authored Jul 10, 2024
2 parents 2f89367 + 3bdd7fe commit 3d31443
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 52 deletions.
2 changes: 2 additions & 0 deletions .env.sample
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
ANTHROPIC_API_KEY=""
FIREWORKS_API_KEY=""
GROQ_API_KEY=""
MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
4 changes: 4 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from .chat import Chat
from ..providers import (
AnthropicInterface,
FireworksInterface,
GroqInterface,
MistralInterface,
OllamaInterface,
OpenAIInterface,
ReplicateInterface,
)


Expand Down Expand Up @@ -34,10 +36,12 @@ def __init__(self):
self.all_interfaces = {}
self.all_factories = {
"anthropic": AnthropicInterface,
"fireworks": FireworksInterface,
"groq": GroqInterface,
"mistral": MistralInterface,
"ollama": OllamaInterface,
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
}

def get_provider_interface(self, model):
Expand Down
2 changes: 2 additions & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Provides the individual provider interfaces for each FM provider."""

from .anthropic_interface import AnthropicInterface
from .fireworks_interface import FireworksInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
35 changes: 35 additions & 0 deletions aimodels/providers/fireworks_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""The interface to the Fireworks API."""

import os

from ..framework.provider_interface import ProviderInterface


class FireworksInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Fireworks's APIs."""

def __init__(self):
"""Set up the Fireworks client using the API key obtained from the user's environment."""
from fireworks.client import Fireworks

self.fireworks_client = Fireworks(api_key=os.getenv("FIREWORKS_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Fireworks API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.fireworks_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
40 changes: 40 additions & 0 deletions aimodels/providers/replicate_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Replicate API."""

import os

from ..framework.provider_interface import ProviderInterface

_REPLICATE_BASE_URL = "https://openai-proxy.replicate.com/v1"


class ReplicateInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Replicate's APIs."""

def __init__(self):
"""Set up the Replicate client using the API key obtained from the user's environment."""
from openai import OpenAI

self.replicate_client = OpenAI(
api_key=os.getenv("REPLICATE_API_KEY"),
base_url=_REPLICATE_BASE_URL,
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Replicate API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.replicate_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
Loading

0 comments on commit 3d31443

Please sign in to comment.