diff --git a/phi/assistant/assistant.py b/phi/assistant/assistant.py
index 4e7db1c3e..ba9672e1c 100644
--- a/phi/assistant/assistant.py
+++ b/phi/assistant/assistant.py
@@ -1,10 +1,12 @@
-from typing import List, Any, Optional, Dict
+from typing import List, Any, Optional, Dict, Union, Callable

-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, ConfigDict, field_validator, model_validator

 from phi.assistant.file import File
 from phi.assistant.tool import Tool
+from phi.assistant.tool.registry import ToolRegistry
 from phi.assistant.row import AssistantRow
+from phi.assistant.function import Function, FunctionCall
 from phi.assistant.storage import AssistantStorage
 from phi.assistant.exceptions import AssistantIdNotSet
 from phi.knowledge.base import KnowledgeBase
@@ -40,7 +42,9 @@ class Assistant(BaseModel):
     # -*- Assistant Tools
     # A list of tools enabled on the assistant. There can be a maximum of 128 tools per assistant.
     # Tools can be of types code_interpreter, retrieval, or function.
-    tools: Optional[List[Tool | Dict]] = None
+    tools: Optional[List[Union[Tool, Dict, Callable, ToolRegistry]]] = None
+    # Functions the Assistant may call.
+    _function_map: Optional[Dict[str, Function]] = None

     # -*- Assistant Files
     # A list of file IDs attached to this assistant.
@@ -90,6 +94,26 @@ def set_log_level(cls, v: bool) -> bool:
     def client(self) -> OpenAI:
         return self.openai or OpenAI()

+    def add_function(self, f: Function) -> None:
+        if self._function_map is None:
+            self._function_map = {}
+        self._function_map[f.name] = f
+        logger.debug(f"Added function {f.name} to Assistant")
+
+    @model_validator(mode="after")
+    def add_functions_to_assistant(self) -> "Assistant":
+        if self.tools is not None:
+            for tool in self.tools:
+                if callable(tool):
+                    f = Function.from_callable(tool)
+                    self.add_function(f)
+                elif isinstance(tool, ToolRegistry):
+                    if self._function_map is None:
+                        self._function_map = {}
+                    self._function_map.update(tool.functions)
+                    logger.debug(f"Tools from {tool.name} added to Assistant.")
+        return self
+
     def load_from_storage(self):
         pass

@@ -98,6 +122,7 @@ def load_from_openai(self, openai_assistant: OpenAIAssistant):
         self.object = openai_assistant.object
         self.created_at = openai_assistant.created_at
         self.file_ids = openai_assistant.file_ids
+        self.openai_assistant = openai_assistant

     def create(self) -> "Assistant":
         request_body: Dict[str, Any] = {}
@@ -112,8 +137,14 @@ def create(self) -> "Assistant":
             for _tool in self.tools:
                 if isinstance(_tool, Tool):
                     _tools.append(_tool.to_dict())
-                else:
+                elif isinstance(_tool, dict):
                     _tools.append(_tool)
+                elif callable(_tool):
+                    func = Function.from_callable(_tool)
+                    _tools.append({"type": "function", "function": func.to_dict()})
+                elif isinstance(_tool, ToolRegistry):
+                    for _f in _tool.functions.values():
+                        _tools.append({"type": "function", "function": _f.to_dict()})
             request_body["tools"] = _tools
         if self.file_ids is not None or self.files is not None:
             _file_ids = self.file_ids or []
@@ -139,10 +170,7 @@ def get_id(self) -> Optional[str]:
             _id = self.id
         return _id

-    def get(self, use_cache: bool = True) -> "Assistant":
-        if self.openai_assistant is not None and use_cache:
-            return self
-
+    def get_from_openai(self) -> OpenAIAssistant:
         _assistant_id = self.get_id()
         if _assistant_id is None:
             raise AssistantIdNotSet("Assistant.id not set")
@@ -151,6 +179,13 @@ def get(self, use_cache: bool = True) -> "Assistant":
         self.openai_assistant = self.client.beta.assistants.retrieve(
             assistant_id=_assistant_id,
         )
         self.load_from_openai(self.openai_assistant)
+        return self.openai_assistant
+
+    def get(self, use_cache: bool = True) -> "Assistant":
+        if self.openai_assistant is not None and use_cache:
+            return self
+
+        self.get_from_openai()
         return self

     def get_or_create(self, use_cache: bool = True) -> "Assistant":
@@ -161,7 +196,7 @@

     def update(self) -> "Assistant":
         try:
-            assistant_to_update = self.get()
+            assistant_to_update = self.get_from_openai()
             if assistant_to_update is not None:
                 request_body: Dict[str, Any] = {}
                 if self.name is not None:
@@ -175,8 +210,14 @@
                 for _tool in self.tools:
                     if isinstance(_tool, Tool):
                         _tools.append(_tool.to_dict())
-                    else:
+                    elif isinstance(_tool, dict):
                         _tools.append(_tool)
+                    elif callable(_tool):
+                        func = Function.from_callable(_tool)
+                        _tools.append({"type": "function", "function": func.to_dict()})
+                    elif isinstance(_tool, ToolRegistry):
+                        for _f in _tool.functions.values():
+                            _tools.append({"type": "function", "function": _f.to_dict()})
                 request_body["tools"] = _tools
             if self.file_ids is not None or self.files is not None:
                 _file_ids = self.file_ids or []
@@ -201,7 +242,7 @@

     def delete(self) -> OpenAIAssistantDeleted:
         try:
-            assistant_to_delete = self.get()
+            assistant_to_delete = self.get_from_openai()
             if assistant_to_delete is not None:
                 deletion_status = self.client.beta.assistants.delete(
                     assistant_id=assistant_to_delete.id,
@@ -239,3 +280,58 @@ def __str__(self) -> str:
         import json

         return json.dumps(self.to_dict(), indent=4)
+
+    def get_function_call(self, name: str, arguments: Optional[str] = None) -> Optional[FunctionCall]:
+        import json
+
+        logger.debug(f"Getting function {name}. Args: {arguments}")
+        if self._function_map is None:
+            return None
+
+        function_to_call: Optional[Function] = None
+        if name in self._function_map:
+            function_to_call = self._function_map[name]
+        if function_to_call is None:
+            logger.error(f"Function {name} not found")
+            return None
+
+        function_call = FunctionCall(function=function_to_call)
+        if arguments is not None and arguments != "":
+            try:
+                if "None" in arguments:
+                    arguments = arguments.replace("None", "null")
+                if "True" in arguments:
+                    arguments = arguments.replace("True", "true")
+                if "False" in arguments:
+                    arguments = arguments.replace("False", "false")
+                _arguments = json.loads(arguments)
+            except Exception as e:
+                logger.error(f"Unable to decode function arguments {arguments}: {e}")
+                return None
+
+            if not isinstance(_arguments, dict):
+                logger.error(f"Function arguments {arguments} is not a valid JSON object")
+                return None
+
+            try:
+                clean_arguments: Dict[str, Any] = {}
+                for k, v in _arguments.items():
+                    if isinstance(v, str):
+                        _v = v.strip().lower()
+                        if _v in ("none", "null"):
+                            clean_arguments[k] = None
+                        elif _v == "true":
+                            clean_arguments[k] = True
+                        elif _v == "false":
+                            clean_arguments[k] = False
+                        else:
+                            clean_arguments[k] = v.strip()
+                    else:
+                        clean_arguments[k] = v
+
+                function_call.arguments = clean_arguments
+            except Exception as e:
+                logger.error(f"Unable to parse function arguments {arguments}: {e}")
+                return None
+
+        return function_call
diff --git a/phi/assistant/file/__init__.py b/phi/assistant/file/__init__.py
new file mode 100644
index 000000000..68f5c3916
--- /dev/null
+++ b/phi/assistant/file/__init__.py
@@ -0,0 +1 @@
+from phi.assistant.file.file import File
diff --git a/phi/assistant/file.py b/phi/assistant/file/file.py
similarity index 100%
rename from phi/assistant/file.py
rename to phi/assistant/file/file.py
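Reviewer note: a minimal sketch of how the reworked `tools` field and `get_function_call` are meant to compose, based only on this patch. `get_weather` is a hypothetical example function, not part of the diff.

```python
from phi.assistant.assistant import Assistant


def get_weather(city: str) -> str:
    """Returns a (fake) weather report for a city."""
    return f"It is always sunny in {city}"


# Plain callables are picked up by the add_functions_to_assistant
# model validator and also serialized to
# {"type": "function", "function": {...}} dicts in create()/update().
assistant = Assistant(name="weather-assistant", tools=[get_weather])

# The Assistants API returns tool-call arguments as a JSON string;
# get_function_call() resolves the name and parses/cleans the arguments.
fc = assistant.get_function_call(name="get_weather", arguments='{"city": "Tokyo"}')
if fc is not None and fc.execute():
    print(fc.result)  # -> "It is always sunny in Tokyo"
```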
diff --git a/phi/assistant/function.py b/phi/assistant/function.py
new file mode 100644
index 000000000..ffb90a4e8
--- /dev/null
+++ b/phi/assistant/function.py
@@ -0,0 +1,97 @@
+from typing import Any, Dict, Optional, Callable, get_type_hints
+from pydantic import BaseModel, validate_call
+
+from phi.utils.log import logger
+
+
+class Tool(BaseModel):
+    """Model for Assistant Tools"""
+
+    # The type of tool
+    type: str
+    function: Optional[Dict[str, Any]] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        return self.model_dump(exclude_none=True)
+
+
+class Function(BaseModel):
+    """Model for Assistant functions"""
+
+    # The name of the function to be called.
+    # Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64.
+    name: str
+    # A description of what the function does, used by the model to choose when and how to call the function.
+    description: Optional[str] = None
+    # The parameters the function accepts, described as a JSON Schema object.
+    # To describe a function that accepts no parameters, provide the value {"type": "object", "properties": {}}.
+    parameters: Dict[str, Any] = {"type": "object", "properties": {}}
+    entrypoint: Optional[Callable] = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        return self.model_dump(exclude_none=True, exclude={"entrypoint"})
+
+    @classmethod
+    def from_callable(cls, c: Callable) -> "Function":
+        from inspect import getdoc
+        from phi.utils.json_schema import get_json_schema
+
+        parameters = {"type": "object", "properties": {}}
+        try:
+            type_hints = get_type_hints(c)
+            parameters = get_json_schema(type_hints)
+            # logger.debug(f"Type hints for {c.__name__}: {type_hints}")
+        except Exception as e:
+            logger.warning(f"Could not parse args for {c.__name__}: {e}")
+
+        return cls(
+            name=c.__name__,
+            description=getdoc(c),
+            parameters=parameters,
+            entrypoint=validate_call(c),
+        )
+
+
+class FunctionCall(BaseModel):
+    """Model for Assistant function calls"""
+
+    # The function to be called.
+    function: Function
+    # The arguments to call the function with.
+    arguments: Optional[Dict[str, Any]] = None
+    # The result of the function call.
+    result: Optional[Any] = None
+
+    def get_call_str(self) -> str:
+        """Returns a string representation of the function call."""
+        if self.arguments is None:
+            return f"{self.function.name}()"
+        return f"{self.function.name}({', '.join([f'{k}={v}' for k, v in self.arguments.items()])})"
+
+    def execute(self) -> bool:
+        """Runs the function call.
+
+        @return: True if the function call was successful, False otherwise.
+        """
+        if self.function.entrypoint is None:
+            return False
+
+        logger.debug(f"Running: {self.get_call_str()}")
+
+        # Call the function with no arguments if none are provided.
+        if self.arguments is None:
+            try:
+                self.result = self.function.entrypoint()
+                return True
+            except Exception as e:
+                logger.warning(f"Could not run function {self.get_call_str()}")
+                logger.error(e)
+                return False
+
+        try:
+            self.result = self.function.entrypoint(**self.arguments)
+            return True
+        except Exception as e:
+            logger.warning(f"Could not run function {self.get_call_str()}")
+            logger.error(e)
+            return False
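For context, roughly what `Function.from_callable` produces for a type-hinted callable. The exact JSON Schema shape depends on `phi.utils.json_schema.get_json_schema`, which is not part of this diff, so the `parameters` value shown in the comment is an assumption; `add` is a made-up example.

```python
from phi.assistant.function import Function, FunctionCall


def add(a: int, b: int = 1) -> int:
    """Adds two integers."""
    return a + b


f = Function.from_callable(add)
# f.name        -> "add"
# f.description -> "Adds two integers."
# f.parameters  -> a JSON Schema object built from the type hints, e.g.
#                  {"type": "object", "properties": {"a": {...}, "b": {...}}}
#                  (exact shape depends on get_json_schema)

# to_dict() drops the entrypoint, so the payload stays JSON-serializable.
print(f.to_dict())

# FunctionCall pairs the function with parsed arguments and captures the result.
call = FunctionCall(function=f, arguments={"a": 1, "b": 2})
if call.execute():
    print(call.result)  # -> 3
```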
diff --git a/phi/assistant/message.py b/phi/assistant/message.py
index 5b984a992..9ac32d2f7 100644
--- a/phi/assistant/message.py
+++ b/phi/assistant/message.py
@@ -1,4 +1,5 @@
-from typing import List, Any, Optional, Dict
+from typing import List, Any, Optional, Dict, Union
+from typing_extensions import Literal

 from pydantic import BaseModel, ConfigDict

@@ -22,9 +23,9 @@ class Message(BaseModel):
     object: Optional[str] = None

     # The entity that produced the message. One of user or assistant.
-    role: Optional[str] = None
+    role: Optional[Literal["user", "assistant"]] = None
     # The content of the message in an array of text and/or images.
-    content: List[Any | Content] | str
+    content: Optional[Union[List[Content], str]] = None

     # The thread ID that this message belongs to.
     # Required to create/get a message.
@@ -57,17 +58,27 @@ class Message(BaseModel):
     def client(self) -> OpenAI:
         return self.openai or OpenAI()

+    @classmethod
+    def from_openai(cls, message: OpenAIThreadMessage) -> "Message":
+        _message = cls()
+        _message.load_from_openai(message)
+        return _message
+
     def load_from_openai(self, openai_message: OpenAIThreadMessage):
         self.id = openai_message.id
-        self.object = openai_message.object
-        self.role = openai_message.role
+        self.assistant_id = openai_message.assistant_id
         self.content = openai_message.content
         self.created_at = openai_message.created_at
+        self.file_ids = openai_message.file_ids
+        self.object = openai_message.object
+        self.role = openai_message.role
         self.run_id = openai_message.run_id
         self.thread_id = openai_message.thread_id
+        self.openai_message = openai_message

     def create(self, thread_id: Optional[str] = None) -> "Message":
-        if thread_id is None and self.thread_id is None:
+        _thread_id = thread_id or self.thread_id
+        if _thread_id is None:
             raise ThreadIdNotSet("Thread.id not set")

         request_body: Dict[str, Any] = {}
@@ -84,7 +95,7 @@ def create(self, thread_id: Optional[str] = None) -> "Message":
             raise TypeError("Message.content must be a string for create()")

         self.openai_message = self.client.beta.threads.messages.create(
-            thread_id=self.thread_id, role="user", content=self.content, **request_body
+            thread_id=_thread_id, role="user", content=self.content, **request_body
         )
         self.load_from_openai(self.openai_message)
         logger.debug(f"Message created: {self.id}")
@@ -93,11 +104,9 @@ def get_id(self) -> Optional[str]:
         return self.id or self.openai_message.id if self.openai_message else None

-    def get(self, use_cache: bool = True, thread_id: Optional[str] = None) -> "Message":
-        if self.openai_message is not None and use_cache:
-            return self
-
-        if thread_id is None and self.thread_id is None:
+    def get_from_openai(self, thread_id: Optional[str] = None) -> OpenAIThreadMessage:
+        _thread_id = thread_id or self.thread_id
+        if _thread_id is None:
             raise ThreadIdNotSet("Thread.id not set")

         _message_id = self.get_id()
@@ -105,10 +114,17 @@ def get(self, use_cache: bool = True, thread_id: Optional[str] = None) -> "Message":
             raise MessageIdNotSet("Message.id not set")

         self.openai_message = self.client.beta.threads.messages.retrieve(
-            thread_id=self.thread_id,
+            thread_id=_thread_id,
             message_id=_message_id,
         )
         self.load_from_openai(self.openai_message)
+        return self.openai_message
+
+    def get(self, use_cache: bool = True, thread_id: Optional[str] = None) -> "Message":
+        if self.openai_message is not None and use_cache:
+            return self
+
+        self.get_from_openai(thread_id=thread_id)
         return self

     def get_or_create(self, use_cache: bool = True, thread_id: Optional[str] = None) -> "Message":
@@ -119,12 +135,18 @@

     def update(self, thread_id: Optional[str] = None) -> "Message":
         try:
-            message_to_update = self.get(thread_id=thread_id)
+            message_to_update = self.get_from_openai(thread_id=thread_id)
             if message_to_update is not None:
                 request_body: Dict[str, Any] = {}
                 if self.metadata is not None:
                     request_body["metadata"] = self.metadata

+                if message_to_update.id is None:
+                    raise MessageIdNotSet("Message.id not set")
+
+                if message_to_update.thread_id is None:
+                    raise ThreadIdNotSet("Thread.id not set")
+
                 self.openai_message = self.client.beta.threads.messages.update(
                     thread_id=message_to_update.thread_id,
                     message_id=message_to_update.id,
@@ -139,7 +161,20 @@

     def to_dict(self) -> Dict[str, Any]:
         return self.model_dump(
-            exclude_none=True, include={"id", "object", "role", "content", "file_ids", "files", "metadata"}
+            exclude_none=True,
+            include={
+                "id",
+                "object",
+                "role",
+                "content",
+                "file_ids",
+                "files",
+                "metadata",
+                "created_at",
+                "thread_id",
+                "assistant_id",
+                "run_id",
+            },
         )

     def pprint(self):
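The `get()`/`get_from_openai()` split applied here (and mirrored in Assistant, Run, and Thread) separates cache-aware reads from forced API reads. A sketch of the intended call pattern, with placeholder IDs:

```python
from phi.assistant.message import Message

# Placeholder IDs for illustration only.
message = Message(id="msg_123", thread_id="thread_456")

# get() is cache-aware: it returns self immediately when an
# openai_message is already loaded and use_cache=True.
message = message.get(use_cache=True)

# get_from_openai() always calls client.beta.threads.messages.retrieve(),
# refreshing the local fields and returning the raw OpenAIThreadMessage.
openai_message = message.get_from_openai()
print(openai_message.id)
```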
diff --git a/phi/assistant/run.py b/phi/assistant/run.py
index ff915408e..3138dc922 100644
--- a/phi/assistant/run.py
+++ b/phi/assistant/run.py
@@ -1,16 +1,24 @@
-from typing import Any, Optional, Dict, List
+from typing import Any, Optional, Dict, List, Union, Callable, cast
 from typing_extensions import Literal

-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator

 from phi.assistant.tool import Tool
+from phi.assistant.tool.registry import ToolRegistry
+from phi.assistant.function import Function
 from phi.assistant.assistant import Assistant
 from phi.assistant.exceptions import ThreadIdNotSet, AssistantIdNotSet, RunIdNotSet
 from phi.utils.log import logger

 try:
     from openai import OpenAI
-    from openai.types.beta.threads.run import Run as OpenAIRun
+    from openai.types.beta.threads.run import (
+        Run as OpenAIRun,
+        RequiredAction,
+        LastError,
+    )
+    from openai.types.beta.threads.required_action_function_tool_call import RequiredActionFunctionToolCall
+    from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
 except ImportError:
     logger.error("`openai` not installed")
     raise
@@ -33,17 +41,12 @@ class Run(BaseModel):
     # The status of the run, which can be either
     # queued, in_progress, requires_action, cancelling, cancelled, failed, completed, or expired.
     status: Optional[
-        str
-        | Literal[
-            "queued", "in_progress", "requires_action", "cancelling", "cancelled", "failed", "completed", "expired"
-        ]
+        Literal["queued", "in_progress", "requires_action", "cancelling", "cancelled", "failed", "completed", "expired"]
     ] = None

     # Details on the action required to continue the run. Will be null if no action is required.
-    required_action: Optional[Dict[str, Any]] = None
+    required_action: Optional[RequiredAction] = None

-    # True if this run is active
-    is_active: bool = True
     # The Unix timestamp (in seconds) for when the run was created.
     created_at: Optional[int] = None
     # The Unix timestamp (in seconds) for when the run was started.
@@ -69,10 +72,12 @@
     instructions: Optional[str] = None
     # Override the tools the assistant can use for this run.
     # This is useful for modifying the behavior on a per-run basis.
-    tools: Optional[List[Tool | Dict]] = None
+    tools: Optional[List[Union[Tool, Dict, Callable, ToolRegistry]]] = None
+    # Functions the Run may call.
+    _function_map: Optional[Dict[str, Function]] = None

     # The last error associated with this run. Will be null if there are no errors.
-    last_error: Optional[Dict[str, Any]] = None
+    last_error: Optional[LastError] = None

     # Set of 16 key-value pairs that can be attached to an object.
     # This can be useful for storing additional information about the object in a structured format.
@@ -93,6 +98,26 @@
     def client(self) -> OpenAI:
         return self.openai or OpenAI()

+    def add_function(self, f: Function) -> None:
+        if self._function_map is None:
+            self._function_map = {}
+        self._function_map[f.name] = f
+        logger.debug(f"Added function {f.name} to Run")
+
+    @model_validator(mode="after")
+    def add_functions_to_assistant(self) -> "Run":
+        if self.tools is not None:
+            for tool in self.tools:
+                if callable(tool):
+                    f = Function.from_callable(tool)
+                    self.add_function(f)
+                elif isinstance(tool, ToolRegistry):
+                    if self._function_map is None:
+                        self._function_map = {}
+                    self._function_map.update(tool.functions)
+                    logger.debug(f"Tools from {tool.name} added to Assistant.")
+        return self
+
     def load_from_storage(self):
         pass

@@ -102,7 +127,6 @@ def load_from_openai(self, openai_run: OpenAIRun):
         self.status = openai_run.status
         self.required_action = openai_run.required_action
         self.last_error = openai_run.last_error
-        self.is_active = openai_run.is_active
         self.created_at = openai_run.created_at
         self.started_at = openai_run.started_at
         self.expires_at = openai_run.expires_at
@@ -110,14 +134,19 @@
         self.failed_at = openai_run.failed_at
         self.completed_at = openai_run.completed_at
         self.file_ids = openai_run.file_ids
+        self.openai_run = openai_run

     def create(
         self, thread_id: Optional[str] = None, assistant: Optional[Assistant] = None, assistant_id: Optional[str] = None
     ) -> "Run":
-        if thread_id is None and self.thread_id is None:
+        _thread_id = thread_id or self.thread_id
+        if _thread_id is None:
             raise ThreadIdNotSet("Thread.id not set")

-        if (assistant is None or assistant.id is None) and assistant_id is None and self.assistant_id is None:
+        _assistant_id = assistant.get_id() if assistant is not None else assistant_id
+        if _assistant_id is None:
+            _assistant_id = self.assistant.get_id() if self.assistant is not None else self.assistant_id
+        if _assistant_id is None:
             raise AssistantIdNotSet("Assistant.id not set")

         request_body: Dict[str, Any] = {}
@@ -130,14 +159,20 @@
             for _tool in self.tools:
                 if isinstance(_tool, Tool):
                     _tools.append(_tool.to_dict())
-                else:
+                elif isinstance(_tool, dict):
                     _tools.append(_tool)
+                elif callable(_tool):
+                    func = Function.from_callable(_tool)
+                    _tools.append({"type": "function", "function": func.to_dict()})
+                elif isinstance(_tool, ToolRegistry):
+                    for _f in _tool.functions.values():
+                        _tools.append({"type": "function", "function": _f.to_dict()})
             request_body["tools"] = _tools
         if self.metadata is not None:
             request_body["metadata"] = self.metadata

         self.openai_run = self.client.beta.threads.runs.create(
-            thread_id=self.thread_id, assistant_id=self.assistant_id, **request_body
+            thread_id=_thread_id, assistant_id=_assistant_id, **request_body
         )
         self.load_from_openai(self.openai_run)
         logger.debug(f"Run created: {self.id}")
@@ -150,11 +185,9 @@ def get_id(self) -> Optional[str]:
             _id = self.id
         return _id

-    def get(self, use_cache: bool = True, thread_id: Optional[str] = None) -> "Run":
-        if self.openai_run is not None and use_cache:
-            return self
-
-        if thread_id is None and self.thread_id is None:
+    def get_from_openai(self, thread_id: Optional[str] = None) -> OpenAIRun:
+        _thread_id = thread_id or self.thread_id
+        if _thread_id is None:
             raise ThreadIdNotSet("Thread.id not set")

         _run_id = self.get_id()
@@ -162,10 +195,17 @@
             raise RunIdNotSet("Run.id not set")

         self.openai_run = self.client.beta.threads.runs.retrieve(
-            thread_id=self.thread_id,
+            thread_id=_thread_id,
             run_id=_run_id,
         )
         self.load_from_openai(self.openai_run)
+        return self.openai_run
+
+    def get(self, use_cache: bool = True, thread_id: Optional[str] = None) -> "Run":
+        if self.openai_run is not None and use_cache:
+            return self
+
+        self.get_from_openai(thread_id=thread_id)
         return self

     def get_or_create(
@@ -182,7 +222,7 @@

     def update(self, thread_id: Optional[str] = None) -> "Run":
         try:
-            run_to_update = self.get(thread_id=thread_id)
+            run_to_update = self.get_from_openai(thread_id=thread_id)
             if run_to_update is not None:
                 request_body: Dict[str, Any] = {}
                 if self.metadata is not None:
@@ -200,20 +240,94 @@
             logger.warning("Message not available")
             raise

-    def wait_for_completion(self, timeout: Optional[int] = None) -> OpenAIRun:
+    def wait(
+        self,
+        interval: int = 1,
+        timeout: Optional[int] = None,
+        thread_id: Optional[str] = None,
+        status: Optional[List[str]] = None,
+        callback: Optional[Callable[[OpenAIRun], None]] = None,
+    ) -> bool:
         import time

+        status_to_wait = status or ["requires_action", "cancelling", "cancelled", "failed", "completed", "expired"]
         start_time = time.time()
         while True:
             logger.debug(f"Waiting for run {self.id} to complete")
-            run = self.get(use_cache=False)
-            logger.debug(f"Run {run.id}: {run}")
-            logger.debug(f"Run {run.id} status: {run.status}")
-            if run.status == "completed":
-                return run
+            run = self.get_from_openai(thread_id=thread_id)
+            logger.debug(f"Run {run.id} {run.status}")
+            if callback is not None:
+                callback(run)
+            if run.status in status_to_wait:
+                return True
             if timeout is not None and time.time() - start_time > timeout:
-                raise TimeoutError(f"Run {run.id} did not complete within {timeout} seconds")
-            time.sleep(1)
+                logger.error(f"Run {run.id} did not complete within {timeout} seconds")
+                return False
+                # raise TimeoutError(f"Run {run.id} did not complete within {timeout} seconds")
+            time.sleep(interval)
+
+    def run(
+        self,
+        thread_id: Optional[str] = None,
+        assistant: Optional[Assistant] = None,
+        assistant_id: Optional[str] = None,
+        wait: bool = True,
+        callback: Optional[Callable[[OpenAIRun], None]] = None,
+    ) -> "Run":
+        # Update Run with new values
+        self.thread_id = thread_id or self.thread_id
+        self.assistant = assistant or self.assistant
+        self.assistant_id = assistant_id or self.assistant_id
+
+        # Create Run
+        self.create()
+
+        run_completed = not wait
+        while not run_completed:
+            self.wait(callback=callback)
+
+            # -*- Check if run requires action
+            if self.status == "requires_action":
+                if self.assistant is None:
+                    logger.warning("Assistant not available to complete required_action")
+                    return self
+                if self.required_action is not None:
+                    if self.required_action.type == "submit_tool_outputs":
+                        tool_calls: List[
+                            RequiredActionFunctionToolCall
+                        ] = self.required_action.submit_tool_outputs.tool_calls
+
+                        tool_outputs = []
+                        for tool_call in tool_calls:
+                            if tool_call.type == "function":
+                                function_call = self.assistant.get_function_call(
+                                    name=tool_call.function.name, arguments=tool_call.function.arguments
+                                )
+                                if function_call is None:
+                                    logger.error(f"Function {tool_call.function.name} not found")
+                                    continue
+
+                                # -*- Run function call
+                                success = function_call.execute()
+                                if not success:
+                                    logger.error(f"Function {tool_call.function.name} failed")
+                                    continue
+
+                                output = str(function_call.result) if function_call.result is not None else ""
+                                tool_outputs.append(ToolOutput(tool_call_id=tool_call.id, output=output))
+
+                        # -*- Submit tool outputs
+                        _oai_run = cast(OpenAIRun, self.openai_run)
+                        self.openai_run = self.client.beta.threads.runs.submit_tool_outputs(
+                            thread_id=_oai_run.thread_id,
+                            run_id=_oai_run.id,
+                            tool_outputs=tool_outputs,
+                        )
+
+                        self.load_from_openai(self.openai_run)
+            else:
+                run_completed = True
+        return self

     def to_dict(self) -> Dict[str, Any]:
         return self.model_dump(
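A sketch of the new `Run.run()` loop from the caller's side: create the run, poll with `wait()`, and let `requires_action` tool calls be executed and submitted automatically. The callback signature comes from `wait()`; the thread ID is a placeholder and the assistant is assumed to already exist on the OpenAI side.

```python
from phi.assistant.assistant import Assistant
from phi.assistant.run import Run


def get_time() -> str:
    """Returns the current UTC time."""
    from datetime import datetime, timezone

    return datetime.now(timezone.utc).isoformat()


def on_update(openai_run) -> None:
    # wait() invokes the callback with the raw OpenAIRun on every poll.
    print(openai_run.status)


assistant = Assistant(tools=[get_time])  # assumed created/creatable upstream
run = Run().run(
    thread_id="thread_456",  # placeholder ID
    assistant=assistant,
    wait=True,
    callback=on_update,
)
# When status hits "requires_action", run() resolves each function tool call
# via assistant.get_function_call(), executes it, and submits the outputs
# with client.beta.threads.runs.submit_tool_outputs() before polling again.
print(run.status)
```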
diff --git a/phi/assistant/thread.py b/phi/assistant/thread.py
index 6233029e9..41216180e 100644
--- a/phi/assistant/thread.py
+++ b/phi/assistant/thread.py
@@ -1,12 +1,13 @@
-from typing import Any, Optional, Dict, List
+from typing import Any, Optional, Dict, List, Union, Callable

 from pydantic import BaseModel, ConfigDict

 from phi.assistant.run import Run
 from phi.assistant.message import Message
 from phi.assistant.assistant import Assistant
-from phi.assistant.exceptions import AssistantIdNotSet, ThreadIdNotSet
+from phi.assistant.exceptions import ThreadIdNotSet
 from phi.utils.log import logger
+from phi.utils.timer import Timer

 try:
     from openai import OpenAI
@@ -26,7 +27,7 @@ class Thread(BaseModel):
     object: Optional[str] = None

     # A list of messages in this thread.
-    messages: List[Message | Dict] = []
+    messages: List[Union[Message, Dict]] = []

     # Assistant used for this thread
     assistant: Optional[Assistant] = None
@@ -60,8 +61,9 @@ def load_from_openai(self, openai_thread: OpenAIThread):
         self.id = openai_thread.id
         self.object = openai_thread.object
         self.created_at = openai_thread.created_at
+        self.openai_thread = openai_thread

-    def create(self, messages: Optional[List[Message | Dict]] = None) -> "Thread":
+    def create(self, messages: Optional[List[Union[Message, Dict]]] = None) -> "Thread":
         request_body: Dict[str, Any] = {}
         if messages is not None:
             _messages = []
@@ -77,7 +79,7 @@
         self.openai_thread = self.client.beta.threads.create(**request_body)
         self.load_from_openai(self.openai_thread)
         logger.debug(f"Thread created: {self.id}")
-        return self.openai_thread
+        return self

     def get_id(self) -> Optional[str]:
         _id = self.id or self.openai_thread.id if self.openai_thread else None
@@ -86,10 +88,7 @@
             _id = self.id
         return _id

-    def get(self, use_cache: bool = True) -> "Thread":
-        if self.openai_thread is not None and use_cache:
-            return self
-
+    def get_from_openai(self) -> OpenAIThread:
         _thread_id = self.get_id()
         if _thread_id is None:
             raise ThreadIdNotSet("Thread.id not set")
@@ -98,9 +97,16 @@
         self.openai_thread = self.client.beta.threads.retrieve(
             thread_id=_thread_id,
         )
         self.load_from_openai(self.openai_thread)
+        return self.openai_thread
+
+    def get(self, use_cache: bool = True) -> "Thread":
+        if self.openai_thread is not None and use_cache:
+            return self
+
+        self.get_from_openai()
         return self

-    def get_or_create(self, use_cache: bool = True, messages: Optional[List[Message | Dict]] = None) -> "Thread":
+    def get_or_create(self, use_cache: bool = True, messages: Optional[List[Union[Message, Dict]]] = None) -> "Thread":
         try:
             return self.get(use_cache=use_cache)
         except ThreadIdNotSet:
@@ -108,7 +114,7 @@

     def update(self) -> "Thread":
         try:
-            thread_to_update = self.get()
+            thread_to_update = self.get_from_openai()
             if thread_to_update is not None:
                 request_body: Dict[str, Any] = {}
                 if self.metadata is not None:
@@ -127,7 +133,7 @@

     def delete(self) -> OpenAIThreadDeleted:
         try:
-            thread_to_delete = self.get()
+            thread_to_delete = self.get_from_openai()
             if thread_to_delete is not None:
                 deletion_status = self.client.beta.threads.delete(
                     thread_id=thread_to_delete.id,
@@ -138,65 +144,60 @@
             logger.warning("Thread not available")
             raise

-    def add_message(self, message: Message | Dict) -> Message:
+    def add_message(self, message: Union[Message, Dict]) -> None:
         try:
             message = message if isinstance(message, Message) else Message(**message)
         except Exception as e:
             logger.error(f"Error creating Message: {e}")
             raise
-
         message.thread_id = self.id
         message.create()
-        return message

-    def add(self, messages: List[Message | Dict]) -> List[Message]:
+    def add(self, messages: List[Union[Message, Dict]]) -> None:
         existing_thread = self.get_id() is not None
         if existing_thread:
             for message in messages:
                 self.add_message(message=message)
         else:
             self.create(messages=messages)
-        return self.messages

-    def create_run(
-        self, assistant_id: Optional[str] = None, run: Optional[Run] = None, use_cache: bool = True, **kwargs
+    def run(
+        self,
+        run: Optional[Run] = None,
+        assistant: Optional[Assistant] = None,
+        assistant_id: Optional[str] = None,
+        wait: bool = True,
+        callback: Optional[Callable] = None,
     ) -> Run:
         try:
-            thread_to_run = self.get(use_cache=use_cache)
+            _thread_id = self.get_id()
+            if _thread_id is None:
+                _thread_id = self.get_from_openai().id
         except ThreadIdNotSet:
             logger.warning("Thread not available")
             raise

-        if assistant_id is None:
-            assistant_id = self.assistant_id
-        if assistant_id is None:
-            raise AssistantIdNotSet("Assistant.id not set")
-
-        try:
-            if run is None:
-                run = Run(**kwargs)
-        except Exception as e:
-            logger.error(f"Error creating run: {e}")
-            raise
-
-        run.thread_id = thread_to_run.id
-        run.assistant_id = assistant_id
+        _assistant = assistant or self.assistant
+        _assistant_id = assistant_id or self.assistant_id

-        logger.debug(f"Creating run: {run}")
-        run.create()
-        return run
+        _run = run or Run()
+        return _run.run(
+            thread_id=_thread_id, assistant=_assistant, assistant_id=_assistant_id, wait=wait, callback=callback
+        )

-    def get_messages(self, use_cache: bool = True) -> List[Message]:
+    def get_messages(self) -> List[Message]:
         try:
-            thread_to_read = self.get(use_cache=use_cache)
+            _thread_id = self.get_id()
+            if _thread_id is None:
+                _thread_id = self.get_from_openai().id
         except ThreadIdNotSet:
             logger.warning("Thread not available")
             raise

         thread_messages = self.client.beta.threads.messages.list(
-            thread_id=thread_to_read.id,
+            thread_id=_thread_id,
         )
-        return [Message(**message.model_dump()) for message in thread_messages]
+        return [Message.from_openai(message=message) for message in thread_messages]

     def to_dict(self) -> Dict[str, Any]:
         return self.model_dump(exclude_none=True, include={"id", "object", "messages", "metadata"})
@@ -207,6 +208,47 @@ def pprint(self):

         pprint(self.to_dict())

+    def print_response(self, message: str, assistant: Assistant) -> None:
+        from phi.cli.console import console
+        from rich.table import Table
+        from rich.box import ROUNDED
+        from rich.markdown import Markdown
+
+        # Start the response timer
+        response_timer = Timer()
+        response_timer.start()
+
+        # Add the message to the thread
+        self.add(messages=[Message(role="user", content=message)])
+
+        # Run the assistant
+        self.run(assistant=assistant)
+
+        # Stop the response timer
+        response_timer.stop()
+
+        # Get the messages from the thread
+        messages = self.get_messages()
+
+        # Get the assistant response
+        assistant_response: str = ""
+        for m in messages:
+            oai_message = m.openai_message
+            if oai_message and oai_message.role == "assistant":
+                for content in oai_message.content:
+                    if content.type == "text":
+                        text = content.text
+                        assistant_response += text.value
+                break
+
+        # Convert to markdown
+        md_response = Markdown(assistant_response)
+        table = Table(box=ROUNDED, border_style="blue")
+        table.add_column("Message")
+        table.add_column(message)
+        table.add_row(f"Response\n({response_timer.elapsed:.1f}s)", md_response)
+        console.print(table)
+
     def __str__(self) -> str:
         import json
diff --git a/phi/assistant/tool/__init__.py b/phi/assistant/tool/__init__.py
new file mode 100644
index 000000000..01bde82b3
--- /dev/null
+++ b/phi/assistant/tool/__init__.py
@@ -0,0 +1 @@
+from phi.assistant.tool.tool import Tool
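End to end, the Thread changes reduce the chat loop to roughly the following; `print_response()` is the convenience wrapper doing the same add → run → get_messages dance before rendering a rich table. The assistant is assumed to already exist (or be creatable) on the OpenAI side.

```python
from phi.assistant.assistant import Assistant
from phi.assistant.message import Message
from phi.assistant.thread import Thread

assistant = Assistant(name="demo-assistant")  # assumed created upstream

thread = Thread()
# add() creates the thread on first use, then appends messages to it.
thread.add(messages=[Message(role="user", content="Hello!")])
# run() delegates to Run.run(), which blocks until the run settles.
thread.run(assistant=assistant)

for m in thread.get_messages():  # now built via Message.from_openai()
    print(m.role, m.content)

# Or, equivalently, the one-call variant used for CLI output:
thread.print_response(message="Hello!", assistant=assistant)
```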
diff --git a/phi/assistant/tool/registry.py b/phi/assistant/tool/registry.py
new file mode 100644
index 000000000..9c562fc01
--- /dev/null
+++ b/phi/assistant/tool/registry.py
@@ -0,0 +1,27 @@
+from collections import OrderedDict
+from typing import Callable, Dict
+
+from phi.assistant.function import Function
+from phi.utils.log import logger
+
+
+class ToolRegistry:
+    def __init__(self, name: str = "default_registry"):
+        self.name: str = name
+        self.functions: Dict[str, Function] = OrderedDict()
+
+    def register(self, function: Callable):
+        try:
+            f = Function.from_callable(function)
+            self.functions[f.name] = f
+            logger.debug(f"Function: {f.name} registered with {self.name}")
+            logger.debug(f"Json Schema: {f.to_dict()}")
+        except Exception as e:
+            logger.warning(f"Failed to create Function for: {function.__name__}")
+            raise e
+
+    def __repr__(self):
+        return f"<{self.__class__.__name__} name={self.name} functions={list(self.functions.keys())}>"
+
+    def __str__(self):
+        return self.__repr__()
diff --git a/phi/assistant/tool/shell.py b/phi/assistant/tool/shell.py
new file mode 100644
index 000000000..8f2047319
--- /dev/null
+++ b/phi/assistant/tool/shell.py
@@ -0,0 +1,34 @@
+from typing import List
+
+from phi.assistant.tool.registry import ToolRegistry
+from phi.utils.log import logger
+
+
+class ShellTools(ToolRegistry):
+    def __init__(self):
+        super().__init__(name="shell_tools")
+        self.register(self.run_shell_command)
+
+    def run_shell_command(self, args: List[str], tail: int = 100) -> str:
+        """Runs a shell command and returns the output or error.
+
+        :param args: The command to run as a list of strings.
+        :param tail: The number of lines to return from the output.
+        :return: The output of the command.
+        """
+        logger.info(f"Running shell command: {args}")
+
+        import subprocess
+
+        try:
+            result = subprocess.run(args, capture_output=True, text=True)
+            logger.debug(f"Result: {result}")
+            logger.debug(f"Return code: {result.returncode}")
+            if result.returncode != 0:
+                return f"Error: {result.stderr}"
+
+            # return only the last n lines of the output
+            return "\n".join(result.stdout.split("\n")[-tail:])
+        except Exception as e:
+            logger.warning(f"Failed to run shell command: {e}")
+            return f"Error: {e}"
diff --git a/phi/assistant/tool.py b/phi/assistant/tool/tool.py
similarity index 100%
rename from phi/assistant/tool.py
rename to phi/assistant/tool/tool.py
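`ToolRegistry` is subclassable in the same way `ShellTools` is; a minimal custom registry, with a made-up helper function, might look like:

```python
from phi.assistant.tool.registry import ToolRegistry


class MathTools(ToolRegistry):
    """A hypothetical registry bundling a math helper."""

    def __init__(self):
        super().__init__(name="math_tools")
        self.register(self.add_numbers)

    def add_numbers(self, a: float, b: float) -> str:
        """Adds two numbers and returns the sum as a string."""
        return str(a + b)


# Passing the registry to an Assistant (or Run) merges registry.functions
# into its _function_map and advertises each function to the API:
# assistant = Assistant(tools=[MathTools()])
```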
diff --git a/phi/conversation/conversation.py b/phi/conversation/conversation.py
index d12232eca..234bbec0f 100644
--- a/phi/conversation/conversation.py
+++ b/phi/conversation/conversation.py
@@ -110,19 +110,19 @@ class Conversation(BaseModel):
     # -*- User prompt: provide the user prompt as a string or using a function
     # Note: this will ignore the message provided to the chat function
-    user_prompt: Optional[List[Dict] | str] = None
+    user_prompt: Optional[Union[List[Dict], str]] = None
     # Function to build the user prompt.
     # This function is provided the conversation and the user message as arguments
-    # and should return the user_prompt as a List[Dict] | str.
+    # and should return the user_prompt as a Union[List[Dict], str].
     # If add_references_to_prompt is True, then references are also provided as an argument.
     # If add_chat_history_to_prompt is True, then chat_history is also provided as an argument.
     # Signature:
     # def custom_user_prompt_function(
     #     conversation: Conversation,
-    #     message: List[Dict] | str,
+    #     message: Union[List[Dict], str],
     #     references: Optional[str] = None,
     #     chat_history: Optional[str] = None,
-    # ) -> List[Dict] | str:
+    # ) -> Union[List[Dict], str]:
     #     ...
     user_prompt_function: Optional[Callable[..., str]] = None
     # If True, the conversation provides a default user prompt
@@ -424,8 +424,8 @@ def get_formatted_chat_history(self) -> Optional[str]:
         return remove_indent(formatted_history)

     def get_user_prompt(
-        self, message: List[Dict] | str, references: Optional[str] = None, chat_history: Optional[str] = None
-    ) -> List[Dict] | str:
+        self, message: Union[List[Dict], str], references: Optional[str] = None, chat_history: Optional[str] = None
+    ) -> Union[List[Dict], str]:
         """Build the user prompt given a message, references and chat_history"""

         # If the user_prompt is set, return it
@@ -487,7 +487,7 @@
         _user_prompt = cast(str, remove_indent(_user_prompt))
         return _user_prompt

-    def get_text_from_message(self, message: List[Dict] | str) -> str:
+    def get_text_from_message(self, message: Union[List[Dict], str]) -> str:
         """Return the user texts from the message"""
         if isinstance(message, str):
             return message
@@ -508,7 +508,7 @@
             return "\n".join(text_messages)
         return ""

-    def _chat(self, message: List[Dict] | str, stream: bool = True) -> Iterator[str]:
+    def _chat(self, message: Union[List[Dict], str], stream: bool = True) -> Iterator[str]:
         logger.debug("*********** Conversation Chat Start ***********")
         # Load the conversation from the database if available
         self.read_from_storage()
@@ -537,7 +537,7 @@
             user_prompt_chat_history = self.get_formatted_chat_history()

         # -*- Build the user prompt
-        user_prompt: List[Dict] | str = self.get_user_prompt(
+        user_prompt: Union[List[Dict], str] = self.get_user_prompt(
             message=message, references=user_prompt_references, chat_history=user_prompt_chat_history
         )
@@ -602,7 +602,7 @@
             yield llm_response
         logger.debug("*********** Conversation Chat End ***********")

-    def _chat_tasks(self, message: List[Dict] | str, stream: bool = True) -> Iterator[str]:
+    def _chat_tasks(self, message: Union[List[Dict], str], stream: bool = True) -> Iterator[str]:
         if self.tasks is None or len(self.tasks) == 0:
             return ""
@@ -673,7 +673,7 @@
             yield full_response
         logger.debug("*********** Conversation Tasks End ***********")

-    def chat(self, message: List[Dict] | str, stream: bool = True) -> Union[Iterator[str], str]:
+    def chat(self, message: Union[List[Dict], str], stream: bool = True) -> Union[Iterator[str], str]:
         if self.tasks and len(self.tasks) > 0:
             resp = self._chat_tasks(message=message, stream=stream)
         else:
@@ -683,7 +683,7 @@
         else:
             return next(resp)

-    def run(self, message: List[Dict] | str, stream: bool = True) -> Union[Iterator[str], str]:
+    def run(self, message: Union[List[Dict], str], stream: bool = True) -> Union[Iterator[str], str]:
         return self.chat(message=message, stream=stream)

     def _chat_raw(
@@ -876,7 +876,7 @@ def search_knowledge_base(self, query: str) -> Optional[str]:
     # Print Response
     ###########################################################################

-    def print_response(self, message: List[Dict] | str, stream: bool = True) -> None:
+    def print_response(self, message: Union[List[Dict], str], stream: bool = True) -> None:
         from phi.cli.console import console
         from rich.live import Live
         from rich.table import Table
diff --git a/phi/llm/schemas.py b/phi/llm/schemas.py
index 3ff42972a..122b1ea9e 100644
--- a/phi/llm/schemas.py
+++ b/phi/llm/schemas.py
@@ -1,4 +1,4 @@
-from typing import Optional, Any, Dict, Callable, List, get_type_hints
+from typing import Optional, Any, Dict, Callable, List, get_type_hints, Union

 from pydantic import BaseModel, validate_call

 from phi.utils.log import logger
@@ -12,7 +12,7 @@ class Message(BaseModel):
     role: str
     # The contents of the message. content is required for all messages,
     # and may be null for assistant messages with function calls.
-    content: Optional[List[Dict] | str] = None
+    content: Optional[Union[List[Dict], str]] = None
     # The name of the author of this message. name is required if role is function,
     # and it should be the name of the function whose response is in the content.
     # May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
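The `List[Dict] | str` → `Union[List[Dict], str]` rewrites in this and the surrounding files presumably keep the annotations importable on Python versions before 3.10, where PEP 604 `X | Y` syntax is unavailable; behavior is unchanged, and both content shapes stay accepted. For illustration (the dict shape follows the OpenAI chat format):

```python
from phi.llm.schemas import Message

# Plain-text content.
text_message = Message(role="user", content="What is the capital of France?")

# Structured content as a list of parts.
parts_message = Message(
    role="user",
    content=[{"type": "text", "text": "What is the capital of France?"}],
)
```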
diff --git a/phi/llm/task/llm_task.py b/phi/llm/task/llm_task.py
index 507da577b..ad1998b55 100644
--- a/phi/llm/task/llm_task.py
+++ b/phi/llm/task/llm_task.py
@@ -75,15 +75,15 @@ class LLMTask(BaseModel):
     # -*- User prompt: provide the user prompt as a string or using a function
     # Note: this will ignore the message provided to the run function
-    user_prompt: Optional[List[Dict] | str] = None
+    user_prompt: Optional[Union[List[Dict], str]] = None
     # Function to build the user prompt.
     # This function is provided the task and the input message as arguments
-    # and should return the user_prompt as a List[Dict] | str.
+    # and should return the user_prompt as a Union[List[Dict], str].
     # If add_references_to_prompt is True, then references are also provided as an argument.
     # Signature:
     # def custom_user_prompt_function(
     #     task: Task,
-    #     message: Optional[List[Dict] | str] = None,
+    #     message: Optional[Union[List[Dict], str]] = None,
     #     references: Optional[str] = None,
     # ) -> str:
     #     ...
@@ -201,8 +201,8 @@ def get_references_from_knowledge_base(self, query: str, num_documents: Optional[int] = None):
         return json.dumps([doc.to_dict() for doc in relevant_docs])

     def get_user_prompt(
-        self, message: Optional[List[Dict] | str] = None, references: Optional[str] = None
-    ) -> List[Dict] | str:
+        self, message: Optional[Union[List[Dict], str]] = None, references: Optional[str] = None
+    ) -> Union[List[Dict], str]:
         """Build the user prompt given a message and references"""

         # If the user_prompt is set, return it
@@ -259,7 +259,7 @@
         _user_prompt = cast(str, remove_indent(_user_prompt))
         return _user_prompt

-    def get_text_from_message(self, message: List[Dict] | str) -> str:
+    def get_text_from_message(self, message: Union[List[Dict], str]) -> str:
         """Return the user texts from the message"""
         if isinstance(message, str):
             return message
@@ -280,7 +280,7 @@
             return "\n".join(text_messages)
         return ""

-    def _run(self, message: Optional[List[Dict] | str] = None, stream: bool = True) -> Iterator[str]:
+    def _run(self, message: Optional[Union[List[Dict], str]] = None, stream: bool = True) -> Iterator[str]:
         # -*- Set default LLM
         if self.llm is None:
             self.llm = OpenAIChat()
@@ -309,7 +309,7 @@
             logger.debug(f"Time to get references: {reference_timer.elapsed:.4f}s")

         # -*- Build the user prompt
-        user_prompt: List[Dict] | str = self.get_user_prompt(message=message, references=user_prompt_references)
+        user_prompt: Union[List[Dict], str] = self.get_user_prompt(message=message, references=user_prompt_references)

         # -*- Build the messages to send to the LLM
         # Create system message
@@ -358,7 +358,7 @@
         if not stream:
             yield llm_response

-    def run(self, message: Optional[List[Dict] | str] = None, stream: bool = True) -> Union[Iterator[str], str]:
+    def run(self, message: Optional[Union[List[Dict], str]] = None, stream: bool = True) -> Union[Iterator[str], str]:
         resp = self._run(message=message, stream=stream)
         if stream:
             return resp
@@ -422,7 +422,7 @@ def search_knowledge_base(self, query: str) -> Optional[str]:
     # Print Response
     ###########################################################################

-    def print_response(self, message: List[Dict] | str, stream: bool = True) -> None:
+    def print_response(self, message: Union[List[Dict], str], stream: bool = True) -> None:
         from phi.cli.console import console
         from rich.live import Live
         from rich.table import Table
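Per the signature documented in the LLMTask comments above, a custom user prompt function would plug in roughly like this. The field name `user_prompt_function` is an assumption carried over from Conversation (it is not shown in the LLMTask hunks), and the prompt wording is illustrative.

```python
from typing import Dict, List, Optional, Union

from phi.llm.task.llm_task import LLMTask


def custom_user_prompt_function(
    task: LLMTask,
    message: Optional[Union[List[Dict], str]] = None,
    references: Optional[str] = None,
) -> str:
    # Combine the user message with retrieved references, if any.
    prompt = f"Question: {message}"
    if references is not None:
        prompt += f"\n\nUse this information if it helps:\n{references}"
    return prompt


# Assumes the field is named user_prompt_function, as in Conversation.
task = LLMTask(user_prompt_function=custom_user_prompt_function)
# task.run("What is phidata?")
```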