Skip to content

Commit

Permalink
Add another provider
Browse files Browse the repository at this point in the history
standsleeping committed Jul 2, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 93576df commit 7980dde
Showing 5 changed files with 60 additions and 3 deletions.
2 changes: 2 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""MultiFMClient manages a Chat across multiple provider interfaces."""

from ..providers.openai_interface import OpenAIInterface
from ..providers.groq_interface import GroqInterface
from .chat import Chat


@@ -28,6 +29,7 @@ def __init__(self):
self.all_interfaces = {}
self.all_factories = {
"openai": OpenAIInterface,
"groq": GroqInterface,
}

def get_provider_interface(self, model):
35 changes: 35 additions & 0 deletions aimodels/providers/groq_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""The interface to the Groq API."""

import os

from ..framework.provider_interface import ProviderInterface


class GroqInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Groq's APIs."""

def __init__(self):
"""Set up the Groq client using the API key obtained from the user's environment."""
import groq

self.groq_client = groq.Groq(api_key=os.getenv("GROQ_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Groq API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.groq_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
4 changes: 2 additions & 2 deletions aimodels/providers/openai_interface.py
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@ def chat_completion_create(self, messages=None, model=None, temperature=0):
Args:
----
messages (list of dict): A list of message objects in chat history.
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
@@ -30,6 +30,6 @@ def chat_completion_create(self, messages=None, model=None, temperature=0):
"""
return self.openai_client.chat.completions.create(
model=model,
temperature=temperature,
messages=messages,
temperature=temperature,
)
21 changes: 20 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@ readme = "README.md"
python = "^3.10"
python-dotenv = "^1.0.1"
openai = "^1.35.7"
groq = "^0.9.0"


[tool.poetry.group.dev.dependencies]

0 comments on commit 7980dde

Please sign in to comment.