Merge pull request #1 from andrewyng/create-basic-classes
Create initial classes
standsleeping authored Jul 2, 2024
2 parents b7922f2 + c416d3f commit a8e6ab0
Showing 11 changed files with 549 additions and 1 deletion.
3 changes: 3 additions & 0 deletions aimodels/client/__init__.py
@@ -0,0 +1,3 @@
"""Provides the MultiFMClient for managing chats across many FM providers."""

from .multi_fm_client import MultiFMClient
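For orientation (not part of this diff): with this __init__.py in place, downstream code can import the client straight from aimodels.client. A minimal sketch, assuming the package is on the import path:

from aimodels.client import MultiFMClient

client = MultiFMClient()  # wires up client.chat and client.chat.completions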
18 changes: 18 additions & 0 deletions aimodels/client/chat.py
@@ -0,0 +1,18 @@
"""Chat is instantiated with a client and manages completions."""

from .completions import Completions


class Chat:
"""Manage chat sessions with multiple providers."""

def __init__(self, topmost_instance):
"""Initialize a new Chat instance.
Args:
----
topmost_instance: The chat session's client instance (MultiFMClient).
"""
self.topmost_instance = topmost_instance
self.completions = Completions(topmost_instance)
37 changes: 37 additions & 0 deletions aimodels/client/completions.py
@@ -0,0 +1,37 @@
"""Completions is instantiated with a client and manages completion requests in chat sessions."""


class Completions:
"""Manage completion requests in chat sessions."""

def __init__(self, topmost_instance):
"""Initialize a new Completions instance.
Args:
----
topmost_instance: The chat session's client instance (MultiFMClient).
"""
self.topmost_instance = topmost_instance

def create(self, model=None, temperature=0, messages=None):
"""Create a completion request using a specified provider/model combination.
Args:
----
model (str): The model identifier with format 'provider:model'.
temperature (float): The sampling temperature.
messages (list): A list of previous messages.
Returns:
-------
The resulting completion.
"""
interface, model_name = self.topmost_instance.get_provider_interface(model)

return interface.chat_completion_create(
messages,
model=model_name,
temperature=temperature,
)
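To illustrate how Completions.create is reached in practice, here is a minimal usage sketch (not part of this diff); the "openai:gpt-4o" identifier and the message content are illustrative, and a valid provider API key is assumed to be set in the environment:

from aimodels.client import MultiFMClient

client = MultiFMClient()
response = client.chat.completions.create(
    model="openai:gpt-4o",  # illustrative 'provider:model' identifier
    temperature=0,
    messages=[{"role": "user", "content": "Hello!"}],
)
# `response` is whatever the underlying provider SDK returns.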
70 changes: 70 additions & 0 deletions aimodels/client/multi_fm_client.py
@@ -0,0 +1,70 @@
"""MultiFMClient manages a Chat across multiple provider interfaces."""

from ..providers.openai_interface import OpenAIInterface
from ..providers.groq_interface import GroqInterface
from .chat import Chat


class MultiFMClient:
"""Manages multiple provider interfaces."""

_MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE = (
"Expected ':' in model identifier to specify provider:model. Got {model}."
)
_NO_FACTORY_ERROR_MESSAGE_TEMPLATE = (
"Could not find factory to create interface for provider '{provider}'."
)

def __init__(self):
"""Initialize the MultiFMClient instance.
Attributes
----------
chat (Chat): The chat session.
all_interfaces (dict): Stores interface instances by provider names.
all_factories (dict): Maps provider names to their corresponding interfaces.
"""
self.chat = Chat(self)
self.all_interfaces = {}
self.all_factories = {
"openai": OpenAIInterface,
"groq": GroqInterface,
}

def get_provider_interface(self, model):
"""Retrieve or create a provider interface based on a model identifier.
Args:
----
model (str): The model identifier in the format 'provider:model'.
Raises:
------
ValueError: If the model identifier does colon-separate provider and model.
Exception: If no factory is found from the supplied model.
Returns:
-------
The interface instance for the provider and the model name.
"""
if ":" not in model:
raise ValueError(
self._MODEL_FORMAT_ERROR_MESSAGE_TEMPLATE.format(model=model)
)

model_parts = model.split(":", maxsplit=1)
provider = model_parts[0]
model_name = model_parts[1]

if provider in self.all_interfaces:
return self.all_interfaces[provider]

if provider not in self.all_factories:
raise Exception(
self._NO_FACTORY_ERROR_MESSAGE_TEMPLATE.format(provider=provider)
)

self.all_interfaces[provider] = self.all_factories[provider]()
return self.all_interfaces[provider], model_name
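A sketch of how get_provider_interface resolves identifiers (not part of this diff); it assumes OPENAI_API_KEY is set so the OpenAI interface can be constructed, and the model name is illustrative:

from aimodels.client import MultiFMClient

client = MultiFMClient()

# Returns the (cached or newly built) interface plus the bare model name.
interface, model_name = client.get_provider_interface("openai:gpt-4o")

# A missing ':' raises ValueError; an unknown provider raises Exception.
try:
    client.get_provider_interface("gpt-4o")
except ValueError as error:
    print(error)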
3 changes: 3 additions & 0 deletions aimodels/framework/__init__.py
@@ -0,0 +1,3 @@
"""Provides the ProviderInterface for defining the interface that all FM providers must implement."""

from .provider_interface import ProviderInterface
25 changes: 25 additions & 0 deletions aimodels/framework/provider_interface.py
@@ -0,0 +1,25 @@
"""The shared interface for model providers."""


class ProviderInterface:
"""Defines the expected behavior for provider-specific interfaces."""

def chat_completion_create(self, messages=None, model=None, temperature=0) -> None:
"""Create a chat completion using the specified messages, model, and temperature.
This method must be implemented by subclasses to perform completions.
Args:
----
messages (list): The chat history.
model (str): The identifier of the model to be used in the completion.
temperature (float): The temperature to use in the completion.
Raises:
------
NotImplementedError: If this method has not been implemented by a subclass.
"""
raise NotImplementedError(
"Provider Interface has not implemented chat_completion_create()"
)
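To show how a new provider would plug into this interface, here is a hedged sketch (not part of this diff); the class name and the stub response are hypothetical:

from aimodels.framework import ProviderInterface


class ExampleProviderInterface(ProviderInterface):
    """Hypothetical provider implementation, for illustration only."""

    def chat_completion_create(self, messages=None, model=None, temperature=0):
        # A real implementation would call the provider's SDK here and
        # return its chat-completion response.
        return {"model": model, "messages": messages, "temperature": temperature}

A concrete subclass would also be registered in MultiFMClient.all_factories (for example under a key like "example") so that 'example:some-model' identifiers resolve to it.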
3 changes: 3 additions & 0 deletions aimodels/providers/__init__.py
@@ -0,0 +1,3 @@
"""Provides the individual provider interfaces for each FM provider."""

from .openai_interface import OpenAIInterface
35 changes: 35 additions & 0 deletions aimodels/providers/groq_interface.py
@@ -0,0 +1,35 @@
"""The interface to the Groq API."""

import os

from ..framework.provider_interface import ProviderInterface


class GroqInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Groq's APIs."""

def __init__(self):
"""Set up the Groq client using the API key obtained from the user's environment."""
import groq

self.groq_client = groq.Groq(api_key=os.getenv("GROQ_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Groq API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.groq_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
35 changes: 35 additions & 0 deletions aimodels/providers/openai_interface.py
@@ -0,0 +1,35 @@
"""The interface to the OpenAI API."""

import os

from ..framework.provider_interface import ProviderInterface


class OpenAIInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with OpenAI's APIs."""

def __init__(self):
"""Set up the OpenAI client using the API key obtained from the user's environment."""
import openai

self.openai_client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the OpenAI API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.openai_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
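The provider interfaces can also be exercised directly. A minimal sketch (not part of this diff), assuming OPENAI_API_KEY is exported, the openai package is installed, and the model name is illustrative:

import os

from aimodels.providers import OpenAIInterface

assert os.getenv("OPENAI_API_KEY"), "export OPENAI_API_KEY before running"

interface = OpenAIInterface()
response = interface.chat_completion_create(
    messages=[{"role": "user", "content": "Say hello."}],
    model="gpt-4o",  # illustrative model name
    temperature=0,
)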