Skip to content

Commit

Permalink
openai assistant refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ashpreetbedi committed Jan 9, 2024
1 parent 2d9da05 commit 18d2bef
Show file tree
Hide file tree
Showing 15 changed files with 59 additions and 59 deletions.
1 change: 0 additions & 1 deletion phi/assistant/__init__.py

This file was deleted.

1 change: 0 additions & 1 deletion phi/assistant/file/__init__.py

This file was deleted.

1 change: 1 addition & 0 deletions phi/assistant/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from phi.assistant.openai.assistant import OpenAiAssistant
56 changes: 28 additions & 28 deletions phi/assistant/assistant.py → phi/assistant/openai/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from pydantic import BaseModel, ConfigDict, field_validator, model_validator

from phi.assistant.file import File
from phi.assistant.exceptions import AssistantIdNotSet
from phi.assistant.openai.file import File
from phi.assistant.openai.exceptions import AssistantIdNotSet
from phi.tools import Tool, ToolRegistry
from phi.tools.function import Function
from phi.utils.log import logger, set_log_level_to_debug
Expand All @@ -18,13 +18,13 @@
raise


class Assistant(BaseModel):
class OpenAiAssistant(BaseModel):
# -*- LLM settings
model: str = "gpt-4-1106-preview"
openai: Optional[OpenAI] = None

# -*- Assistant settings
# Assistant id which can be referenced in API endpoints.
# -*- OpenAiAssistant settings
# OpenAiAssistant id which can be referenced in API endpoints.
id: Optional[str] = None
# The object type, populated by the API. Always assistant.
object: Optional[str] = None
Expand All @@ -35,30 +35,30 @@ class Assistant(BaseModel):
# The system instructions that the assistant uses. The maximum length is 32768 characters.
instructions: Optional[str] = None

# -*- Assistant Tools
# -*- OpenAiAssistant Tools
# A list of tools provided to the assistant. There can be a maximum of 128 tools per assistant.
# Tools can be of types code_interpreter, retrieval, or function.
tools: Optional[List[Union[Tool, ToolRegistry, Callable, Dict]]] = None
# -*- Functions available to the Assistant to call
# -*- Functions available to the OpenAiAssistant to call
# Functions extracted from the tools which can be executed locally by the assistant.
functions: Optional[Dict[str, Function]] = None

# -*- Assistant Files
# -*- OpenAiAssistant Files
# A list of file IDs attached to this assistant.
# There can be a maximum of 20 files attached to the assistant.
# Files are ordered by their creation date in ascending order.
file_ids: Optional[List[str]] = None
# Files attached to this assistant.
files: Optional[List[File]] = None

# -*- Assistant Storage
# -*- OpenAiAssistant Storage
# storage: Optional[AssistantStorage] = None
# Create table if it doesn't exist
# create_storage: bool = True
# AssistantRow from the database: DO NOT SET THIS MANUALLY
# database_row: Optional[AssistantRow] = None

# -*- Assistant Knowledge Base
# -*- OpenAiAssistant Knowledge Base
# knowledge_base: Optional[KnowledgeBase] = None

# Set of 16 key-value pairs that can be attached to an object.
Expand Down Expand Up @@ -92,18 +92,18 @@ def client(self) -> OpenAI:
return self.openai or OpenAI()

@model_validator(mode="after")
def extract_functions_from_tools(self) -> "Assistant":
def extract_functions_from_tools(self) -> "OpenAiAssistant":
if self.tools is not None:
for tool in self.tools:
if self.functions is None:
self.functions = {}
if isinstance(tool, ToolRegistry):
self.functions.update(tool.functions)
logger.debug(f"Functions from {tool.name} added to Assistant.")
logger.debug(f"Functions from {tool.name} added to OpenAiAssistant.")
elif callable(tool):
f = Function.from_callable(tool)
self.functions[f.name] = f
logger.debug(f"Function {f.name} added to Assistant")
logger.debug(f"Function {f.name} added to OpenAiAssistant")
return self

def __enter__(self):
Expand Down Expand Up @@ -137,7 +137,7 @@ def get_tools_for_api(self) -> Optional[List[Dict[str, Any]]]:
tools_for_api.append({"type": "function", "function": _f.to_dict()})
return tools_for_api

def create(self) -> "Assistant":
def create(self) -> "OpenAiAssistant":
request_body: Dict[str, Any] = {}
if self.name is not None:
request_body["name"] = self.name
Expand All @@ -163,7 +163,7 @@ def create(self) -> "Assistant":
**request_body,
)
self.load_from_openai(self.openai_assistant)
logger.debug(f"Assistant created: {self.id}")
logger.debug(f"OpenAiAssistant created: {self.id}")
return self

def get_id(self) -> Optional[str]:
Expand All @@ -172,28 +172,28 @@ def get_id(self) -> Optional[str]:
def get_from_openai(self) -> OpenAIAssistantType:
_assistant_id = self.get_id()
if _assistant_id is None:
raise AssistantIdNotSet("Assistant.id not set")
raise AssistantIdNotSet("OpenAiAssistant.id not set")

self.openai_assistant = self.client.beta.assistants.retrieve(
assistant_id=_assistant_id,
)
self.load_from_openai(self.openai_assistant)
return self.openai_assistant

def get(self, use_cache: bool = True) -> "Assistant":
def get(self, use_cache: bool = True) -> "OpenAiAssistant":
if self.openai_assistant is not None and use_cache:
return self

self.get_from_openai()
return self

def get_or_create(self, use_cache: bool = True) -> "Assistant":
def get_or_create(self, use_cache: bool = True) -> "OpenAiAssistant":
try:
return self.get(use_cache=use_cache)
except AssistantIdNotSet:
return self.create()

def update(self) -> "Assistant":
def update(self) -> "OpenAiAssistant":
try:
assistant_to_update = self.get_from_openai()
if assistant_to_update is not None:
Expand Down Expand Up @@ -227,11 +227,11 @@ def update(self) -> "Assistant":
**request_body,
)
self.load_from_openai(self.openai_assistant)
logger.debug(f"Assistant updated: {self.id}")
logger.debug(f"OpenAiAssistant updated: {self.id}")
return self
raise ValueError("Assistant not available")
raise ValueError("OpenAiAssistant not available")
except AssistantIdNotSet:
logger.warning("Assistant not available")
logger.warning("OpenAiAssistant not available")
raise

def delete(self) -> OpenAIAssistantDeleted:
Expand All @@ -241,10 +241,10 @@ def delete(self) -> OpenAIAssistantDeleted:
deletion_status = self.client.beta.assistants.delete(
assistant_id=assistant_to_delete.id,
)
logger.debug(f"Assistant deleted: {deletion_status.id}")
logger.debug(f"OpenAiAssistant deleted: {deletion_status.id}")
return deletion_status
except AssistantIdNotSet:
logger.warning("Assistant not available")
logger.warning("OpenAiAssistant not available")
raise

def to_dict(self) -> Dict[str, Any]:
Expand Down Expand Up @@ -275,18 +275,18 @@ def __str__(self) -> str:
return json.dumps(self.to_dict(), indent=4)

def __repr__(self) -> str:
return f"<Assistant name={self.name} id={self.id}>"
return f"<OpenAiAssistant name={self.name} id={self.id}>"

#
# def run(self, thread: Optional["Thread"]) -> "Thread":
# from phi.assistant.thread import Thread
# from phi.assistant.openai.thread import Thread
#
# return Thread(assistant=self, thread=thread).run()

def print_response(self, message: str, markdown: bool = False) -> None:
"""Print a response from the assistant"""

from phi.assistant.thread import Thread
from phi.assistant.openai.thread import Thread

thread = Thread()
thread.print_response(message=message, assistant=self, markdown=markdown)
Expand All @@ -300,7 +300,7 @@ def cli_app(
exit_on: Tuple[str, ...] = ("exit", "bye"),
) -> None:
from rich.prompt import Prompt
from phi.assistant.thread import Thread
from phi.assistant.openai.thread import Thread

thread = Thread()
while True:
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions phi/assistant/openai/file/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from phi.assistant.openai.file.file import File
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel, ConfigDict

from phi.assistant.exceptions import FileIdNotSet
from phi.assistant.openai.exceptions import FileIdNotSet
from phi.utils.log import logger

try:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path
from typing import Any, Optional

from phi.assistant.file import File
from phi.assistant.openai.file import File
from phi.utils.log import logger


Expand Down
4 changes: 2 additions & 2 deletions phi/assistant/message.py → phi/assistant/openai/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from pydantic import BaseModel, ConfigDict

from phi.assistant.file import File
from phi.assistant.exceptions import ThreadIdNotSet, MessageIdNotSet
from phi.assistant.openai.file import File
from phi.assistant.openai.exceptions import ThreadIdNotSet, MessageIdNotSet
from phi.utils.log import logger

try:
Expand Down
8 changes: 4 additions & 4 deletions phi/assistant/row.py → phi/assistant/openai/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@


class AssistantRow(BaseModel):
"""Interface between Assistant class and the database"""
"""Interface between OpenAiAssistant class and the database"""

# Assistant id which can be referenced in API endpoints.
# OpenAiAssistant id which can be referenced in API endpoints.
id: str
# The object type, which is always assistant.
object: str
Expand All @@ -18,13 +18,13 @@ class AssistantRow(BaseModel):
instructions: Optional[str] = None
# LLM data (name, model, etc.)
llm: Optional[Dict[str, Any]] = None
# Assistant Tools
# OpenAiAssistant Tools
tools: Optional[List[Dict[str, Any]]] = None
# Files attached to this assistant.
files: Optional[List[Dict[str, Any]]] = None
# Metadata attached to this assistant.
metadata: Optional[Dict[str, Any]] = None
# Assistant Memory
# OpenAiAssistant Memory
memory: Optional[Dict[str, Any]] = None
# True if this assistant is active
is_active: Optional[bool] = None
Expand Down
22 changes: 11 additions & 11 deletions phi/assistant/run.py → phi/assistant/openai/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from pydantic import BaseModel, ConfigDict, model_validator

from phi.assistant.assistant import Assistant
from phi.assistant.exceptions import ThreadIdNotSet, AssistantIdNotSet, RunIdNotSet
from phi.assistant.openai.assistant import OpenAiAssistant
from phi.assistant.openai.exceptions import ThreadIdNotSet, AssistantIdNotSet, RunIdNotSet
from phi.tools import Tool, ToolRegistry
from phi.tools.function import Function
from phi.utils.functions import get_function_call
Expand Down Expand Up @@ -33,8 +33,8 @@ class Run(BaseModel):

# The ID of the thread that was executed on as a part of this run.
thread_id: Optional[str] = None
# Assistant used for this run
assistant: Optional[Assistant] = None
# OpenAiAssistant used for this run
assistant: Optional[OpenAiAssistant] = None
# The ID of the assistant used for execution of this run.
assistant_id: Optional[str] = None

Expand Down Expand Up @@ -106,11 +106,11 @@ def extract_functions_from_tools(self) -> "Run":
self.functions = {}
if isinstance(tool, ToolRegistry):
self.functions.update(tool.functions)
logger.debug(f"Functions from {tool.name} added to Assistant.")
logger.debug(f"Functions from {tool.name} added to OpenAiAssistant.")
elif callable(tool):
f = Function.from_callable(tool)
self.functions[f.name] = f
logger.debug(f"Function {f.name} added to Assistant")
logger.debug(f"Function {f.name} added to OpenAiAssistant")
return self

def load_from_openai(self, openai_run: OpenAIRun):
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_tools_for_api(self) -> Optional[List[Dict[str, Any]]]:
return tools_for_api

def create(
self, thread_id: Optional[str] = None, assistant: Optional[Assistant] = None, assistant_id: Optional[str] = None
self, thread_id: Optional[str] = None, assistant: Optional[OpenAiAssistant] = None, assistant_id: Optional[str] = None
) -> "Run":
_thread_id = thread_id or self.thread_id
if _thread_id is None:
Expand All @@ -157,7 +157,7 @@ def create(
if _assistant_id is None:
_assistant_id = self.assistant.get_id() if self.assistant is not None else self.assistant_id
if _assistant_id is None:
raise AssistantIdNotSet("Assistant.id not set")
raise AssistantIdNotSet("OpenAiAssistant.id not set")

request_body: Dict[str, Any] = {}
if self.model is not None:
Expand Down Expand Up @@ -206,7 +206,7 @@ def get_or_create(
self,
use_cache: bool = True,
thread_id: Optional[str] = None,
assistant: Optional[Assistant] = None,
assistant: Optional[OpenAiAssistant] = None,
assistant_id: Optional[str] = None,
) -> "Run":
try:
Expand Down Expand Up @@ -264,7 +264,7 @@ def wait(
def run(
self,
thread_id: Optional[str] = None,
assistant: Optional[Assistant] = None,
assistant: Optional[OpenAiAssistant] = None,
assistant_id: Optional[str] = None,
wait: bool = True,
callback: Optional[Callable[[OpenAIRun], None]] = None,
Expand All @@ -284,7 +284,7 @@ def run(
# -*- Check if run requires action
if self.status == "requires_action":
if self.assistant is None:
logger.warning("Assistant not available to complete required_action")
logger.warning("OpenAiAssistant not available to complete required_action")
return self
if self.required_action is not None:
if self.required_action.type == "submit_tool_outputs":
Expand Down
Loading

0 comments on commit 18d2bef

Please sign in to comment.