-
Notifications
You must be signed in to change notification settings - Fork 896
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from andrewyng/create-basic-classes
Create initial classes
- Loading branch information
Showing
11 changed files
with
549 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Provides the MultiFMClient for managing chats across many FM providers.""" | ||
|
||
from .multi_fm_client import MultiFMClient |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
"""Chat is instantiated with a client and manages completions.""" | ||
|
||
from .completions import Completions | ||
|
||
|
||
class Chat: | ||
"""Manage chat sessions with multiple providers.""" | ||
|
||
def __init__(self, topmost_instance): | ||
"""Initialize a new Chat instance. | ||
Args: | ||
---- | ||
topmost_instance: The chat session's client instance (MultiFMClient). | ||
""" | ||
self.topmost_instance = topmost_instance | ||
self.completions = Completions(topmost_instance) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Completions is instantiated with a client and manages completion requests in chat sessions.""" | ||
|
||
|
||
class Completions: | ||
"""Manage completion requests in chat sessions.""" | ||
|
||
def __init__(self, topmost_instance): | ||
"""Initialize a new Completions instance. | ||
Args: | ||
---- | ||
topmost_instance: The chat session's client instance (MultiFMClient). | ||
""" | ||
self.topmost_instance = topmost_instance | ||
|
||
def create(self, model=None, temperature=0, messages=None): | ||
"""Create a completion request using a specified provider/model combination. | ||
Args: | ||
---- | ||
model (str): The model identifier with format 'provider:model'. | ||
temperature (float): The sampling temperature. | ||
messages (list): A list of previous messages. | ||
Returns: | ||
------- | ||
The resulting completion. | ||
""" | ||
interface, model_name = self.topmost_instance.get_provider_interface(model) | ||
|
||
return interface.chat_completion_create( | ||
messages, | ||
model=model_name, | ||
temperature=temperature, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""MultiFMClient manages a Chat across multiple provider interfaces.""" | ||
|
||
from ..providers.openai_interface import OpenAIInterface | ||
from ..providers.groq_interface import GroqInterface | ||
from .chat import Chat | ||
|
||
|
||
class MultiFMClient: | ||
"""Manages multiple provider interfaces.""" | ||
|
||
_MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE = ( | ||
"Expected ':' in model identifier to specify provider:model. Got {model}." | ||
) | ||
_NO_FACTORY_ERROR_MESSAGE_TEMPLATE = ( | ||
"Could not find factory to create interface for provider '{provider}'." | ||
) | ||
|
||
def __init__(self): | ||
"""Initialize the MultiFMClient instance. | ||
Attributes | ||
---------- | ||
chat (Chat): The chat session. | ||
all_interfaces (dict): Stores interface instances by provider names. | ||
all_factories (dict): Maps provider names to their corresponding interfaces. | ||
""" | ||
self.chat = Chat(self) | ||
self.all_interfaces = {} | ||
self.all_factories = { | ||
"openai": OpenAIInterface, | ||
"groq": GroqInterface, | ||
} | ||
|
||
def get_provider_interface(self, model): | ||
"""Retrieve or create a provider interface based on a model identifier. | ||
Args: | ||
---- | ||
model (str): The model identifier in the format 'provider:model'. | ||
Raises: | ||
------ | ||
ValueError: If the model identifier does colon-separate provider and model. | ||
Exception: If no factory is found from the supplied model. | ||
Returns: | ||
------- | ||
The interface instance for the provider and the model name. | ||
""" | ||
if ":" not in model: | ||
raise ValueError( | ||
self._MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE.format(model=model) | ||
) | ||
|
||
model_parts = model.split(":", maxsplit=1) | ||
provider = model_parts[0] | ||
model_name = model_parts[1] | ||
|
||
if provider in self.all_interfaces: | ||
return self.all_interfaces[provider] | ||
|
||
if provider not in self.all_factories: | ||
raise Exception( | ||
self._NO_FACTORY_ERROR_MESSAGE_TEMPLATE.format(provider=provider) | ||
) | ||
|
||
self.all_interfaces[provider] = self.all_factories[provider]() | ||
return self.all_interfaces[provider], model_name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Provides the ProviderInterface for defining the interface that all FM providers must implement.""" | ||
|
||
from .provider_interface import ProviderInterface |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
"""The shared interface for model providers.""" | ||
|
||
|
||
class ProviderInterface: | ||
"""Defines the expected behavior for provider-specific interfaces.""" | ||
|
||
def chat_completion_create(self, messages=None, model=None, temperature=0) -> None: | ||
"""Create a chat completion using the specified messages, model, and temperature. | ||
This method must be implemented by subclasses to perform completions. | ||
Args: | ||
---- | ||
messages (list): The chat history. | ||
model (str): The identifier of the model to be used in the completion. | ||
temperature (float): The temperature to use in the completion. | ||
Raises: | ||
------ | ||
NotImplementedError: If this method has not been implemented by a subclass. | ||
""" | ||
raise NotImplementedError( | ||
"Provider Interface has not implemented chat_completion_create()" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Provides the individual provider interfaces for each FM provider.""" | ||
|
||
from .openai_interface import OpenAIInterface |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""The interface to the Groq API.""" | ||
|
||
import os | ||
|
||
from ..framework.provider_interface import ProviderInterface | ||
|
||
|
||
class GroqInterface(ProviderInterface): | ||
"""Implements the ProviderInterface for interacting with Groq's APIs.""" | ||
|
||
def __init__(self): | ||
"""Set up the Groq client using the API key obtained from the user's environment.""" | ||
import groq | ||
|
||
self.groq_client = groq.Groq(api_key=os.getenv("GROQ_API_KEY")) | ||
|
||
def chat_completion_create(self, messages=None, model=None, temperature=0): | ||
"""Request chat completions from the Groq API. | ||
Args: | ||
---- | ||
model (str): Identifies the specific provider/model to use. | ||
messages (list of dict): A list of message objects in chat history. | ||
temperature (float): The temperature to use in the completion. | ||
Returns: | ||
------- | ||
The API response with the completion result. | ||
""" | ||
return self.groq_client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
temperature=temperature, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""The interface to the OpenAI API.""" | ||
|
||
import os | ||
|
||
from ..framework.provider_interface import ProviderInterface | ||
|
||
|
||
class OpenAIInterface(ProviderInterface): | ||
"""Implements the ProviderInterface for interacting with OpenAI's APIs.""" | ||
|
||
def __init__(self): | ||
"""Set up the OpenAI client using the API key obtained from the user's environment.""" | ||
import openai | ||
|
||
self.openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) | ||
|
||
def chat_completion_create(self, messages=None, model=None, temperature=0): | ||
"""Request chat completions from the OpenAI API. | ||
Args: | ||
---- | ||
model (str): Identifies the specific provider/model to use. | ||
messages (list of dict): A list of message objects in chat history. | ||
temperature (float): The temperature to use in the completion. | ||
Returns: | ||
------- | ||
The API response with the completion result. | ||
""" | ||
return self.openai_client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
temperature=temperature, | ||
) |
Oops, something went wrong.