From 0c5c7b051ba3048eb0793aea9f3ba74a51f2c350 Mon Sep 17 00:00:00 2001
From: Ashpreet Bedi
Date: Thu, 28 Sep 2023 18:15:28 +0100
Subject: [PATCH] v2.0.8

---
 phi/ai/operator.py               |   1 -
 phi/ai/phi_ai.py                 | 224 +++++++++++++++++++++---
 phi/api/ai.py                    |  96 ++++++++++-
 phi/api/routes.py                |   1 +
 phi/conversation/conversation.py |  47 ++---
 phi/llm/base.py                  |  12 +-
 phi/llm/function/shell.py        |  19 ++-
 phi/llm/openai.py                | 284 ++++++++++++++++++++++++++-----
 phi/llm/schemas.py               |   3 +-
 phi/utils/functions.py           |  84 +++++++++
 pyproject.toml                   |   3 +-
 requirements.txt                 |  21 ++-
 12 files changed, 672 insertions(+), 123 deletions(-)
 create mode 100644 phi/utils/functions.py

diff --git a/phi/ai/operator.py b/phi/ai/operator.py
index 5d232cc28..3859f372a 100644
--- a/phi/ai/operator.py
+++ b/phi/ai/operator.py
@@ -10,7 +10,6 @@ def phi_ai_conversation(
     stream: bool = False,
 ) -> None:
     """Start a conversation with Phi AI."""
-
    from phi.ai.phi_ai import PhiAI

    conversation_type = ConversationType.AUTO if autonomous_conversation else ConversationType.RAG
diff --git a/phi/ai/phi_ai.py b/phi/ai/phi_ai.py
index 15f3104ac..393229de8 100644
--- a/phi/ai/phi_ai.py
+++ b/phi/ai/phi_ai.py
@@ -1,14 +1,24 @@
+import json
 from typing import Optional, Dict, List, Any, Iterator
 
+from rich import box
+from rich.prompt import Prompt
+from rich.live import Live
+from rich.table import Table
+from rich.markdown import Markdown
+
+from phi.api.ai import conversation_chat
 from phi.api.schemas.user import UserSchema
 from phi.api.schemas.ai import ConversationType
 from phi.cli.config import PhiCliConfig
 from phi.cli.console import console
 from phi.cli.settings import phi_cli_settings
-from phi.llm.schemas import Function
+from phi.llm.schemas import Function, Message, FunctionCall
 from phi.llm.function.shell import ShellScriptsRegistry
 from phi.workspace.config import WorkspaceConfig
 from phi.utils.log import logger
+from phi.utils.functions import get_function_call
+from phi.utils.timer import Timer
 from phi.utils.json_io import write_json_file, read_json_file
@@ -30,9 +40,7 @@ def __init__(
         _active_workspace = _phi_config.get_active_ws_config()
 
         self.conversation_db: Optional[List[Dict[str, Any]]] = None
-        self.functions: Dict[str, Function] = {
-            "run_shell_command": Function.from_callable(ShellScriptsRegistry.run_shell_command)
-        }
+        self.functions: Dict[str, Function] = ShellScriptsRegistry().functions
         logger.debug(f"Functions: {self.functions.keys()}")
 
         _conversation_id = None
@@ -79,30 +87,34 @@ def __init__(
         self.save_conversation()
         logger.debug(f"--**-- Conversation: {self.conversation_id} --**--")
 
-    async def start_conversation(self, stream: bool = False):
-        from rich import box
-        from rich.prompt import Prompt
-        from rich.live import Live
-        from rich.table import Table
-        from rich.markdown import Markdown
-        from phi.api.ai import conversation_chat
-
+    def start_conversation(self, stream: bool = False):
         conversation_active = True
         while conversation_active:
            username = self.user.username or "You"
            console.rule()
-            user_message = Prompt.ask(f"[bold] :sunglasses: {username} [/bold]", console=console)
-            self.conversation_history.append({"role": "user", "content": user_message})
+            user_message_str_valid = False
+            while not user_message_str_valid:
+                user_message_str = Prompt.ask(f"[bold] :sunglasses: {username} [/bold]", console=console)
+                if (
+                    user_message_str is None
+                    or user_message_str == ""
+                    or user_message_str == "{}"
+                    or len(user_message_str) < 2
+                ):
+                    console.print("Please enter a valid message")
+                    continue
+                user_message_str_valid = True
+            self.conversation_history.append({"role": "user", "content": user_message_str})
 
             # -*- Quit conversation
-            if user_message in ("exit", "quit", "bye"):
+            if user_message_str in ("exit", "quit", "bye"):
                 conversation_active = False
 
             # -*- Send message to Phi AI
             api_response: Optional[Iterator[str]] = conversation_chat(
                 user=self.user,
                 conversation_id=self.conversation_id,
-                message=user_message,
+                message=Message(role="user", content=user_message_str),
                 conversation_type=self.conversation_type,
                 functions=self.functions,
                 stream=stream,
@@ -112,22 +124,126 @@ def start_conversation(self, stream: bool = False):
                 conversation_active = False
             else:
                 with Live(console=console) as live:
+                    response_content = ""
                     if stream:
-                        chat_response = ""
                         for _response in api_response:
-                            chat_response += _response
-                            table = Table(show_header=False, box=box.ROUNDED)
-                            table.add_row(Markdown(chat_response))
-                            live.update(table)
-                        self.conversation_history.append({"role": "assistant", "content": chat_response})
+                            if _response is None or _response == "" or _response == "{}":
+                                continue
+                            response_dict = json.loads(_response)
+                            if "content" in response_dict and response_dict.get("content") is not None:
+                                response_content += response_dict.get("content")
+                                table = Table(show_header=False, box=box.ROUNDED)
+                                table.add_row(Markdown(response_content))
+                                live.update(table)
+                            elif "function_call" in response_dict:
+                                for function_response in self.run_function_stream(response_dict.get("function_call")):
+                                    response_content += function_response
+                                    table = Table(show_header=False, box=box.ROUNDED)
+                                    table.add_row(Markdown(response_content))
+                                    live.update(table)
                     else:
-                        chat_response = next(api_response)
+                        _response = next(api_response)
+                        if _response is None or _response == "" or _response == "{}":
+                            response_content = "Something went wrong, please try again."
+                        else:
+                            response_dict = json.loads(_response)
+                            if "content" in response_dict and response_dict.get("content") is not None:
+                                response_content = response_dict.get("content")
+                            elif "function_call" in response_dict:
+                                response_content = self.run_function(response_dict.get("function_call"))
                         table = Table(show_header=False, box=box.ROUNDED)
-                        table.add_row(Markdown(chat_response))
+                        table.add_row(Markdown(response_content))
                         console.print(table)
-                        self.conversation_history.append({"role": "assistant", "content": chat_response})
+                self.conversation_history.append({"role": "assistant", "content": response_content})
             self.save_conversation()
 
+    def run_function_stream(self, function_call: Dict[str, Any]) -> Iterator[str]:
+        _function_name = function_call.get("name")
+        _function_arguments_str = function_call.get("arguments")
+        if _function_name is not None:
+            function_call_obj: Optional[FunctionCall] = get_function_call(
+                name=_function_name, arguments=_function_arguments_str, functions=self.functions
+            )
+            if function_call_obj is None:
+                return "Something went wrong, please try again."
+
+            # -*- Run function call
+            yield f"Running: {function_call_obj.get_call_str()}\n\n"
+            function_call_timer = Timer()
+            function_call_timer.start()
+            function_call_obj.run()
+            function_call_timer.stop()
+            function_call_message = Message(
+                role="function",
+                name=function_call_obj.function.name,
+                content=function_call_obj.result,
+                metrics={"time": function_call_timer.elapsed},
+            )
+            # -*- Send message to Phi AI
+            api_response: Optional[Iterator[str]] = conversation_chat(
+                user=self.user,
+                conversation_id=self.conversation_id,
+                message=function_call_message,
+                conversation_type=self.conversation_type,
+                functions=self.functions,
+                stream=True,
+            )
+            if api_response is not None:
+                for _response in api_response:
+                    if _response is None or _response == "" or _response == "{}":
+                        continue
+                    response_dict = json.loads(_response)
+                    if "content" in response_dict and response_dict.get("content") is not None:
+                        yield response_dict.get("content")
+                    elif "function_call" in response_dict:
+                        yield from self.run_function_stream(response_dict.get("function_call"))
+        else:
+            yield "Could not run function, please try again."
+
+    def run_function(self, function_call: Dict[str, Any]) -> str:
+        _function_name = function_call.get("name")
+        _function_arguments_str = function_call.get("arguments")
+        if _function_name is not None:
+            function_call_obj: Optional[FunctionCall] = get_function_call(
+                name=_function_name, arguments=_function_arguments_str, functions=self.functions
+            )
+            if function_call_obj is None:
+                return "Something went wrong, please try again."
+
+            # -*- Run function call
+            function_run_response = f"Running: {function_call_obj.get_call_str()}\n\n"
+            function_call_timer = Timer()
+            function_call_timer.start()
+            function_call_obj.run()
+            function_call_timer.stop()
+            function_call_message = Message(
+                role="function",
+                name=function_call_obj.function.name,
+                content=function_call_obj.result,
+                metrics={"time": function_call_timer.elapsed},
+            )
+            # -*- Send message to Phi AI
+            api_response: Optional[Iterator[str]] = conversation_chat(
+                user=self.user,
+                conversation_id=self.conversation_id,
+                message=function_call_message,
+                conversation_type=self.conversation_type,
+                functions=self.functions,
+                stream=False,
+            )
+            if api_response is not None:
+                _response = next(api_response)
+                if _response is None or _response == "" or _response == "{}":
+                    function_run_response += "Something went wrong, please try again."
+                else:
+                    response_dict = json.loads(_response)
+                    if "content" in response_dict and response_dict.get("content") is not None:
+                        function_run_response += response_dict.get("content")
+                    elif "function_call" in response_dict:
+                        function_run_response += self.run_function(response_dict.get("function_call"))
+            return function_run_response
+        return "Something went wrong, please try again."
+
     def print_conversation_history(self):
         from rich import box
         from rich.table import Table
 
@@ -191,3 +307,61 @@ def get_latest_conversation(self) -> Optional[Dict[str, Any]]:
         if len(conversations) == 0:
             return None
         return conversations[0]
+
+    # async def conversation(self, stream: bool = False):
+    #     from rich import box
+    #     from rich.prompt import Prompt
+    #     from rich.live import Live
+    #     from rich.table import Table
+    #     from rich.markdown import Markdown
+    #     from phi.api.ai import ai_ws_connect
+    #
+    #     logger.info("Starting conversation with Phi AI")
+    #
+    #     conversation_active = True
+    #     username = self.user.username or "You"
+    #     async with ai_ws_connect(
+    #         user=self.user,
+    #         conversation_id=self.conversation_id,
+    #         conversation_type=self.conversation_type,
+    #         stream=stream,
+    #     ) as ai_ws:
+    #         while conversation_active:
+    #             console.rule()
+    #             user_message = Prompt.ask(f"[bold] :sunglasses: {username} [/bold]", console=console)
+    #             self.conversation_history.append({"role": "user", "content": user_message})
+    #
+    #             # -*- Quit conversation
+    #             if user_message in ("exit", "quit", "bye"):
+    #                 conversation_active = False
+    #
+    #             # -*- Send message to Phi AI
+    #             await ai_ws.send(user_message)
+    #             with Live(console=console) as live:
+    #                 if stream:
+    #                     chat_response = ""
+    #                     ai_response_chunk = await ai_ws.recv()
+    #                     while ai_response_chunk is not None and ai_response_chunk != "AI_RESPONSE_STOP_STREAM":
+    #                         chat_response += ai_response_chunk
+    #                         table = Table(show_header=False, box=box.ROUNDED)
+    #                         table.add_row(Markdown(chat_response))
+    #                         live.update(table)
+    #                         ai_response_chunk = await ai_ws.recv()
+    #                         if ai_response_chunk is None or ai_response_chunk == "AI_RESPONSE_STOP_STREAM":
+    #                             break
+    #                         if ai_response_chunk.startswith("{"):
+    #                             await ai_ws.send("function_result:one")
+    #                     self.conversation_history.append({"role": "assistant", "content": chat_response})
+    #                 else:
+    #                     ai_response = await ai_ws.recv()
+    #                     logger.info(f"ai_response: {type(ai_response)} | {ai_response}")
+    #
+    #                     if ai_response is None:
+    #                         logger.error("Could not reach Phi AI")
+    #                         conversation_active = False
+    #                     chat_response = ai_response
+    #                     table = Table(show_header=False, box=box.ROUNDED)
+    #                     table.add_row(Markdown(chat_response))
+    #                     console.print(table)
+    #                     self.conversation_history.append({"role": "assistant", "content": chat_response})
+    #             self.save_conversation()
diff --git a/phi/api/ai.py b/phi/api/ai.py
index 6e2809032..b1b69c0a7 100644
--- a/phi/api/ai.py
+++ b/phi/api/ai.py
@@ -10,7 +10,7 @@
     ConversationCreateResponse,
 )
 from phi.api.schemas.user import UserSchema
-from phi.llm.schemas import Function
+from phi.llm.schemas import Function, Message
 from phi.utils.log import logger
 
 
@@ -55,7 +55,7 @@ def conversation_create(
 def conversation_chat(
     user: UserSchema,
     conversation_id: int,
-    message: str,
+    message: Message,
     conversation_type: ConversationType = ConversationType.RAG,
     functions: Optional[Dict[str, Function]] = None,
     stream: bool = True,
@@ -71,7 +71,7 @@ def conversation_chat(
             "user": user.model_dump(include={"id_user", "email"}),
             "conversation": {
                 "id": conversation_id,
-                "message": message,
+                "message": message.model_dump(exclude_none=True),
                 "type": conversation_type,
                 "client": ConversationClient.CLI,
                 "functions": {
@@ -97,9 +97,15 @@ def conversation_chat(
                 "user": user.model_dump(include={"id_user", "email"}),
                 "conversation": {
                     "id": conversation_id,
-                    "message": message,
-                    "type": ConversationType.RAG,
+                    "message": message.model_dump(exclude_none=True),
+                    "type": conversation_type,
                     "client": ConversationClient.CLI,
+                    "functions": {
+                        k: v.model_dump(include={"name", "description", "parameters"})
+                        for k, v in functions.items()
+                    }
+                    if functions is not None
+                    else None,
                     "stream": stream,
                 },
             },
@@ -115,3 +121,83 @@ def conversation_chat(
     except Exception as e:
         logger.debug(f"Failed conversation chat: {e}")
         return None
+
+
+# import os
+# import base64
+# import contextlib
+# import wsproto
+#
+# class ConnectionClosed(Exception):
+#     logger.debug("Connection closed")
+#     pass
+#
+#
+# class WebsocketConnection:
+#     def __init__(self, network_steam):
+#         self._ws_connection = wsproto.Connection(wsproto.ConnectionType.CLIENT)
+#         self._network_stream = network_steam
+#         self._events = []
+#
+#     async def send(self, text):
+#         """
+#         Send a text frame over the websocket connection.
+#         """
+#         try:
+#             event = wsproto.events.TextMessage(text)
+#             data = self._ws_connection.send(event)
+#             await self._network_stream.write(data)
+#         except Exception as e:
+#             logger.debug(f"Failed to send: {e}")
+#             logger.info("Connection closed, please start chat again.")
+#             exit(1)
+#
+#     async def recv(self):
+#         """
+#         Receive the next text frame from the websocket connection.
+#         """
+#         try:
+#             while not self._events:
+#                 data = await self._network_stream.read(max_bytes=4096)
+#                 self._ws_connection.receive_data(data)
+#                 self._events = list(self._ws_connection.events())
+#
+#             event = self._events.pop(0)
+#             if isinstance(event, wsproto.events.TextMessage):
+#                 return event.data
+#             elif isinstance(event, wsproto.events.CloseConnection):
+#                 raise ConnectionClosed()
+#         except Exception as e:
+#             logger.debug(f"Failed to receive: {e}")
+#             logger.info("Connection closed, please start chat again.")
+#             exit(1)
+#
+#
+# @contextlib.asynccontextmanager
+# async def ai_ws_connect(
+#     user: UserSchema,
+#     conversation_id: int,
+#     conversation_type: ConversationType = ConversationType.RAG,
+#     stream: bool = True,
+# ):
+#     async with api.AuthenticatedAsyncClient() as api_client:
+#         headers = {
+#             "connection": "upgrade",
+#             "upgrade": "websocket",
+#             "sec-websocket-key": base64.b64encode(os.urandom(16)),
+#             "sec-websocket-version": "13",
+#             "X-PHIDATA-USER-ID": f"{user.id_user}",
+#             "X-PHIDATA-CONVERSATION-ID": f"{conversation_id}",
+#             "X-PHIDATA-CONVERSATION-TYPE": conversation_type.value,
+#             "X-PHIDATA-CONVERSATION-STREAM": "true" if stream else "false",
+#         }
+#         headers.update(api.authenticated_headers)
+#         logger.debug(f"Connecting to {ApiRoutes.AI_CONVERSATION_CHAT_WS}. Headers: {headers}")
+#         async with api_client.stream(
+#             "GET",
+#             ApiRoutes.AI_CONVERSATION_CHAT_WS,
+#             headers=headers,  # type: ignore
+#         ) as response:
+#             network_steam = response.extensions["network_stream"]
+#             yield WebsocketConnection(network_steam)
+
diff --git a/phi/api/routes.py b/phi/api/routes.py
index 5028b5568..6f1b36f8f 100644
--- a/phi/api/routes.py
+++ b/phi/api/routes.py
@@ -33,3 +33,4 @@ class ApiRoutes:
     # ai paths
     AI_CONVERSATION_CREATE: str = "/v1/ai/conversation/create"
     AI_CONVERSATION_CHAT: str = "/v1/ai/conversation/chat"
+    AI_CONVERSATION_CHAT_WS: str = "/v1/ai/conversation/chat_ws"
diff --git a/phi/conversation/conversation.py b/phi/conversation/conversation.py
index bf2478575..1fe1ba101 100644
--- a/phi/conversation/conversation.py
+++ b/phi/conversation/conversation.py
@@ -1,3 +1,4 @@
+import json
 from datetime import datetime
 from typing import List, Any, Optional, Dict, Iterator, Callable, cast, Union
 
@@ -310,8 +311,6 @@ def get_references(self, query: str) -> Optional[str]:
         if self.knowledge_base is None:
             return None
 
-        import json
-
         relevant_docs: List[Document] = self.knowledge_base.search(query=query)
         return json.dumps([doc.to_dict() for doc in relevant_docs])
 
@@ -425,14 +424,14 @@ def _chat(self, message: str, stream: bool = True) -> Iterator[str]:
             messages += self.history.get_last_n_messages(last_n=self.num_history_messages)
         messages += [user_prompt_message]
 
-        # -*- Generate response
+        # -*- Generate response: including running function calls
         llm_response = ""
         if stream:
-            for response_chunk in self.llm.response_stream(messages=messages):
+            for response_chunk in self.llm.parsed_response_stream(messages=messages):
                 llm_response += response_chunk
                 yield response_chunk
         else:
-            llm_response = self.llm.response(messages=messages)
+            llm_response = self.llm.parsed_response(messages=messages)
 
         # -*- Add messages to the history
         # Add the system prompt to the history - added only if this is the first message to the LLM
@@ -476,10 +475,10 @@ def chat(self, message: str, stream: bool = True) -> Union[Iterator[str], str]:
         else:
             return next(resp)
 
-    def _prompt(
+    def _chat_raw(
         self, messages: List[Message], user_message: Optional[str] = None, stream: bool = True
-    ) -> Iterator[str]:
-        logger.debug("*********** Conversation Prompt Start ***********")
+    ) -> Iterator[Dict]:
+        logger.debug("*********** Conversation Raw Chat Start ***********")
 
         # Load the conversation from the database if available
         self.read_from_storage()
@@ -492,16 +491,20 @@
         self.history.add_user_prompt(message=message)
 
         # -*- Generate response
-        llm_response = ""
+        batch_llm_response = {}
         if stream:
-            for response_chunk in self.llm.response_stream(messages=messages):
-                llm_response += response_chunk
-                yield response_chunk
+            for response_delta in self.llm.response_delta(messages=messages):
+                yield response_delta
         else:
-            llm_response = self.llm.response(messages=messages)
+            batch_llm_response = self.llm.response_message(messages=messages)
 
         # -*- Add response to the history - this is added to the chat and llm history
-        self.history.add_llm_response(Message(role="assistant", content=llm_response))
+        # Last message is the llm response
+        llm_response_message = messages[-1]
+        try:
+            self.history.add_llm_response(llm_response_message)
+        except Exception as e:
+            logger.warning(f"Failed to add llm response to history: {e}")
 
         # -*- Save conversation to storage
         self.write_to_storage()
 
         # -*- Send conversation event
         event_data = {
             "user_message": user_message,
-            "llm_response": llm_response,
+            "llm_response": llm_response_message,
             "messages": [m.model_dump(exclude_none=True) for m in messages],
             "metrics": self.llm.metrics,
         }
@@ -517,13 +520,13 @@
 
         # -*- Return final response if not streaming
         if not stream:
-            yield llm_response
-        logger.debug("*********** Conversation Prompt End ***********")
+            yield batch_llm_response
+        logger.debug("*********** Conversation Raw Chat End ***********")
 
-    def prompt(
+    def chat_raw(
         self, messages: List[Message], user_message: Optional[str] = None, stream: bool = True
-    ) -> Union[Iterator[str], str]:
-        resp = self._prompt(messages=messages, user_message=user_message, stream=stream)
+    ) -> Union[Iterator[Dict], Dict]:
+        resp = self._chat_raw(messages=messages, user_message=user_message, stream=stream)
         if stream:
             return resp
         else:
@@ -557,7 +560,7 @@ def generate_name(self) -> str:
         )
         user_message = Message(role="user", content=_conv)
         generate_name_message = [system_message, user_message]
-        generated_name = self.llm.response(messages=generate_name_message)
+        generated_name = self.llm.parsed_response(messages=generate_name_message)
         if len(generated_name.split()) > 15:
             logger.error("Generated name is too long. Trying again.")
             return self.generate_name()
@@ -635,8 +638,6 @@ def get_last_n_chats(self, num_chats: Optional[int] = None) -> str:
         Each chat contains 2 messages. One from the user and one from the assistant.
         :return: A list of dictionaries representing the chat history.
         """
-        import json
-
         history: List[Dict[str, Any]] = []
         all_chats = self.history.get_chats()
         if len(all_chats) == 0:
diff --git a/phi/llm/base.py b/phi/llm/base.py
index 12b75e27b..c487f9ef1 100644
--- a/phi/llm/base.py
+++ b/phi/llm/base.py
@@ -18,13 +18,21 @@ class LLM(BaseModel):
     function_call_limit: int = 50
     function_call_stack: Optional[List[FunctionCall]] = None
     show_function_calls: Optional[bool] = None
+    # If True, runs function calls before sending back the response content.
+    run_function_calls: bool = True
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    def response(self, messages: List[Message]) -> str:
+    def parsed_response(self, messages: List[Message]) -> str:
         raise NotImplementedError
 
-    def response_stream(self, messages: List[Message]) -> Iterator[str]:
+    def response_message(self, messages: List[Message]) -> Dict:
+        raise NotImplementedError
+
+    def parsed_response_stream(self, messages: List[Message]) -> Iterator[str]:
+        raise NotImplementedError
+
+    def response_delta(self, messages: List[Message]) -> Iterator[Dict]:
         raise NotImplementedError
 
     def to_dict(self) -> Dict[str, Any]:
diff --git a/phi/llm/function/shell.py b/phi/llm/function/shell.py
index 2dc8f6ea8..f0a7be4b4 100644
--- a/phi/llm/function/shell.py
+++ b/phi/llm/function/shell.py
@@ -19,12 +19,13 @@ def run_shell_command(self, args: List[str]) -> str:
 
         import subprocess
 
-        result = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-
-        logger.debug("Return code:", result.returncode)
-        logger.debug("Have {} bytes in stdout:\n{}".format(len(result.stdout), result.stdout.decode()))
-        logger.debug("Have {} bytes in stderr:\n{}".format(len(result.stderr), result.stderr.decode()))
-
-        if result.returncode != 0:
-            return f"error: {result.stderr.decode()}"
-        return result.stdout.decode()
+        try:
+            result = subprocess.run(args, capture_output=True, text=True)
+            logger.debug(f"Result: {result}")
+            logger.debug(f"Return code: {result.returncode}")
+            if result.returncode != 0:
+                return f"Error: {result.stderr}"
+            return result.stdout
+        except Exception as e:
+            logger.warning(f"Failed to run shell command: {e}")
+            return f"Error: {e}"
diff --git a/phi/llm/openai.py b/phi/llm/openai.py
index b8004de41..b5ecacdba 100644
--- a/phi/llm/openai.py
+++ b/phi/llm/openai.py
@@ -31,7 +31,7 @@ def api_kwargs(self) -> dict:
             kwargs["function_call"] = self.function_call
         return kwargs
 
-    def response(self, messages: List[Message]) -> str:
+    def parsed_response(self, messages: List[Message]) -> str:
         logger.debug("---------- OpenAI Response Start ----------")
         # -*- Log messages for debugging
         for m in messages:
             m.log()
@@ -103,7 +103,7 @@ def response(self, messages: List[Message]) -> str:
             return assistant_message.content
 
         # -*- Parse and run function call
-        if assistant_message.function_call is not None:
+        if assistant_message.function_call is not None and self.run_function_calls:
             _function_name = assistant_message.function_call.get("name")
             _function_arguments_str = assistant_message.function_call.get("arguments")
             if _function_name is not None:
@@ -142,24 +142,95 @@ def response(self, messages: List[Message]) -> str:
             final_response = ""
             if self.show_function_calls:
                 final_response += f"Running: {function_call.get_call_str()}\n\n"
-            final_response += self.response(messages=messages)
+            final_response += self.parsed_response(messages=messages)
             return final_response
         logger.debug("---------- OpenAI Response End ----------")
         return "Something went wrong, please try again."
 
-    def response_stream(self, messages: List[Message]) -> Iterator[str]:
+    def response_message(self, messages: List[Message]) -> Dict:
         logger.debug("---------- OpenAI Response Start ----------")
         # -*- Log messages for debugging
         for m in messages:
             m.log()
 
+        response_timer = Timer()
+        response_timer.start()
+        response: OpenAIObject = ChatCompletion.create(
+            model=self.model,
+            messages=[m.to_dict() for m in messages],
+            **self.api_kwargs(),
+        )
+        response_timer.stop()
+        logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
+        # logger.debug(f"OpenAI response type: {type(response)}")
+        # logger.debug(f"OpenAI response: {response}")
+
+        # -*- Parse response
+        response_message = response.choices[0].message
+        response_role = response_message.get("role")
+        response_content = response_message.get("content")
+        response_function_call = response_message.get("function_call")
+
         # -*- Create assistant message
-        assistant_message = Message(role="assistant", content="")
+        assistant_message = Message(
+            role=response_role or "assistant",
+            content=response_content,
+        )
+        if response_function_call is not None and isinstance(response_function_call, OpenAIObject):
+            assistant_message.function_call = response_function_call.to_dict()
+
+        # -*- Update usage metrics
+        # Add response time to metrics
+        assistant_message.metrics["time"] = response_timer.elapsed
+        if "response_times" not in self.metrics:
+            self.metrics["response_times"] = []
+        self.metrics["response_times"].append(response_timer.elapsed)
+
+        # Add token usage to metrics
+        response_usage = response.usage
+        prompt_tokens = response_usage.get("prompt_tokens")
+        if prompt_tokens is not None:
+            assistant_message.metrics["prompt_tokens"] = prompt_tokens
+            if "prompt_tokens" not in self.metrics:
+                self.metrics["prompt_tokens"] = prompt_tokens
+            else:
+                self.metrics["prompt_tokens"] += prompt_tokens
+        completion_tokens = response_usage.get("completion_tokens")
+        if completion_tokens is not None:
+            assistant_message.metrics["completion_tokens"] = completion_tokens
+            if "completion_tokens" not in self.metrics:
+                self.metrics["completion_tokens"] = completion_tokens
+            else:
+                self.metrics["completion_tokens"] += completion_tokens
+        total_tokens = response_usage.get("total_tokens")
+        if total_tokens is not None:
+            assistant_message.metrics["total_tokens"] = total_tokens
+            if "total_tokens" not in self.metrics:
+                self.metrics["total_tokens"] = total_tokens
+            else:
+                self.metrics["total_tokens"] += total_tokens
+
         # -*- Add assistant message to messages
         messages.append(assistant_message)
+        assistant_message.log()
+
+        # -*- Return response
+        response_message_dict = response_message.to_dict_recursive()
+        logger.debug("---------- OpenAI Response End ----------")
+        return response_message_dict
+
+    def parsed_response_stream(self, messages: List[Message]) -> Iterator[str]:
+        logger.debug("---------- OpenAI Response Start ----------")
+        # -*- Log messages for debugging
+        for m in messages:
+            m.log()
 
-        _function_name = ""
-        _function_arguments_str = ""
+        # -*- Create assistant message
+        assistant_message = Message(role="assistant", content="")
+
+        assistant_message_content = ""
+        assistant_message_function_name = ""
+        assistant_message_function_arguments_str = ""
         completion_tokens = 0
         response_timer = Timer()
         response_timer.start()
@@ -180,20 +251,31 @@
 
             # -*- Return content if present, otherwise get function call
             if response_content is not None:
-                assistant_message.content += response_content
+                assistant_message_content += response_content
                 yield response_content
 
             # -*- Parse function call
             if response_function_call is not None:
                 _function_name_stream = response_function_call.get("name")
                 if _function_name_stream is not None:
-                    _function_name += _function_name_stream
+                    assistant_message_function_name += _function_name_stream
                 _function_args_stream = response_function_call.get("arguments")
                 if _function_args_stream is not None:
-                    _function_arguments_str += _function_args_stream
+                    assistant_message_function_arguments_str += _function_args_stream
         response_timer.stop()
         logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
 
+        # -*- Add content to assistant message
+        if assistant_message_content != "":
+            assistant_message.content = assistant_message_content
+
+        # -*- Add function call to assistant message
+        if assistant_message_function_name != "":
+            assistant_message.function_call = {
+                "name": assistant_message_function_name,
+                "arguments": assistant_message_function_arguments_str,
+            }
+
         # -*- Update usage metrics
         # Add response time to metrics
         assistant_message.metrics["time"] = response_timer.elapsed
@@ -222,52 +304,166 @@
         else:
             self.metrics["total_tokens"] += total_tokens
 
+        # -*- Add assistant message to messages
+        messages.append(assistant_message)
+        assistant_message.log()
+
         # -*- Parse and run function call
-        if _function_name is not None and _function_name != "":
-            # Update assistant message to reflect function call
-            if assistant_message.content == "":
-                assistant_message.content = None
+        if assistant_message.function_call is not None and self.run_function_calls:
+            _function_name = assistant_message.function_call.get("name")
+            _function_arguments_str = assistant_message.function_call.get("arguments")
+            if _function_name is not None:
+                # Get function call
+                function_call: Optional[FunctionCall] = self.get_function_call(
+                    name=_function_name, arguments=_function_arguments_str
+                )
+                if function_call is None:
+                    return "Something went wrong, please try again."
+
+                if self.function_call_stack is None:
+                    self.function_call_stack = []
+
+                # -*- Check function call limit
+                if len(self.function_call_stack) > self.function_call_limit:
+                    return f"Function call limit ({self.function_call_limit}) exceeded."
+
+                # -*- Run function call
+                self.function_call_stack.append(function_call)
+                if self.show_function_calls:
+                    yield f"Running: {function_call.get_call_str()}\n\n"
+                function_call_timer = Timer()
+                function_call_timer.start()
+                function_call.run()
+                function_call_timer.stop()
+                function_call_message = Message(
+                    role="function",
+                    name=function_call.function.name,
+                    content=function_call.result,
+                    metrics={"time": function_call_timer.elapsed},
+                )
+                messages.append(function_call_message)
+                if "function_call_times" not in self.metrics:
+                    self.metrics["function_call_times"] = {}
+                if function_call.function.name not in self.metrics["function_call_times"]:
+                    self.metrics["function_call_times"][function_call.function.name] = []
+                self.metrics["function_call_times"][function_call.function.name].append(function_call_timer.elapsed)
+
+                # -*- Yield new response using result of function call
+                yield from self.parsed_response_stream(messages=messages)
+        logger.debug("---------- OpenAI Response End ----------")
+
+    def response_delta(self, messages: List[Message]) -> Iterator[Dict]:
+        logger.debug("---------- OpenAI Response Start ----------")
+        # -*- Log messages for debugging
+        for m in messages:
+            m.log()
+
+        # -*- Create assistant message
+        assistant_message = Message(role="assistant")
+
+        assistant_message_content = ""
+        assistant_message_function_name = ""
+        assistant_message_function_arguments_str = ""
+        completion_tokens = 0
+        response_timer = Timer()
+        response_timer.start()
+        for response in ChatCompletion.create(
+            model=self.model,
+            messages=[m.to_dict() for m in messages],
+            stream=True,
+            **self.api_kwargs(),
+        ):
+            # logger.debug(f"OpenAI response type: {type(response)}")
+            # logger.debug(f"OpenAI response: {response}")
+            completion_tokens += 1
+
+            # -*- Parse response
+            response_delta = response.choices[0].delta
+
+            # -*- Read content
+            response_content = response_delta.get("content")
+            if response_content is not None:
+                assistant_message_content += response_content
+
+            # -*- Read function call
+            response_function_call = response_delta.get("function_call")
+            if response_function_call is not None:
+                _function_name_stream = response_function_call.get("name")
+                if _function_name_stream is not None:
+                    assistant_message_function_name += _function_name_stream
+                _function_args_stream = response_function_call.get("arguments")
+                if _function_args_stream is not None:
+                    assistant_message_function_arguments_str += _function_args_stream
+
+            yield response_delta.to_dict_recursive()
         response_timer.stop()
         logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
 
+        # -*- Add content to assistant message
+        if assistant_message_content != "":
+            assistant_message.content = assistant_message_content
+
+        # -*- Add function call to assistant message
+        if assistant_message_function_name != "":
             assistant_message.function_call = {
-                "name": _function_name,
-                "arguments": _function_arguments_str,
+                "name": assistant_message_function_name,
+                "arguments": assistant_message_function_arguments_str,
             }
-            assistant_message.log()
 
-            # Get function call
-            function_call: Optional[FunctionCall] = self.get_function_call(
-                name=_function_name, arguments=_function_arguments_str
-            )
-            if function_call is None:
-                return "Something went wrong, please try again."
 
+        # -*- Update usage metrics
+        # Add response time to metrics
+        assistant_message.metrics["time"] = response_timer.elapsed
+        if "response_times" not in self.metrics:
+            self.metrics["response_times"] = []
+        self.metrics["response_times"].append(response_timer.elapsed)
+
+        # Add token usage to metrics
+        # TODO: compute prompt tokens
+        prompt_tokens = 0
+        assistant_message.metrics["prompt_tokens"] = prompt_tokens
+        if "prompt_tokens" not in self.metrics:
+            self.metrics["prompt_tokens"] = prompt_tokens
+        else:
+            self.metrics["prompt_tokens"] += prompt_tokens
+        logger.debug(f"Estimated completion tokens: {completion_tokens}")
+        assistant_message.metrics["completion_tokens"] = completion_tokens
+        if "completion_tokens" not in self.metrics:
+            self.metrics["completion_tokens"] = completion_tokens
+        else:
+            self.metrics["completion_tokens"] += completion_tokens
 
-            if self.function_call_stack is None:
-                self.function_call_stack = []
+        total_tokens = prompt_tokens + completion_tokens
+        assistant_message.metrics["total_tokens"] = total_tokens
+        if "total_tokens" not in self.metrics:
+            self.metrics["total_tokens"] = total_tokens
+        else:
+            self.metrics["total_tokens"] += total_tokens
 
-            # -*- Check function call limit
-            if len(self.function_call_stack) > self.function_call_limit:
-                return f"Function call limit ({self.function_call_limit}) exceeded."
+        # -*- Add assistant message to messages
+        messages.append(assistant_message)
+        assistant_message.log()
+        logger.debug("---------- OpenAI Response End ----------")
+
+    def run_function_call(self, function_call: Dict[str, Any]) -> Optional[Message]:
+        _function_name = function_call.get("name")
+        _function_arguments_str = function_call.get("arguments")
+        if _function_name is not None:
+            function_call_obj: Optional[FunctionCall] = self.get_function_call(
+                name=_function_name, arguments=_function_arguments_str
+            )
+            if function_call_obj is None:
+                return None
 
             # -*- Run function call
-            self.function_call_stack.append(function_call)
-            if self.show_function_calls:
-                yield f"Running: {function_call.get_call_str()}\n\n"
             function_call_timer = Timer()
             function_call_timer.start()
-            function_call.run()
+            function_call_obj.run()
             function_call_timer.stop()
             function_call_message = Message(
                 role="function",
-                name=function_call.function.name,
-                content=function_call.result,
+                name=function_call_obj.function.name,
+                content=function_call_obj.result,
                 metrics={"time": function_call_timer.elapsed},
             )
-            messages.append(function_call_message)
-            if "function_call_times" not in self.metrics:
-                self.metrics["function_call_times"] = {}
-            if function_call.function.name not in self.metrics["function_call_times"]:
-                self.metrics["function_call_times"][function_call.function.name] = []
-            self.metrics["function_call_times"][function_call.function.name].append(function_call_timer.elapsed)
-
-            # -*- Yield new response using result of function call
-            yield from self.response_stream(messages=messages)
-        logger.debug("---------- OpenAI Response End ----------")
+            return function_call_message
+        return None
diff --git a/phi/llm/schemas.py b/phi/llm/schemas.py
index c4d765a20..f6d0e336c 100644
--- a/phi/llm/schemas.py
+++ b/phi/llm/schemas.py
@@ -132,7 +132,8 @@ def run(self) -> bool:
             self.result = self.function.entrypoint()
             return True
         except Exception as e:
-            logger.warning(f"Could not run function {self.get_call_str()}: {e}")
+            logger.warning(f"Could not run function {self.get_call_str()}")
+            logger.error(e)
             return False
 
     # Validate the arguments if provided.
diff --git a/phi/utils/functions.py b/phi/utils/functions.py
new file mode 100644
index 000000000..0b783c217
--- /dev/null
+++ b/phi/utils/functions.py
@@ -0,0 +1,84 @@
+import json
+from typing import Optional, Dict, Any
+
+from phi.llm.schemas import Function, FunctionCall
+from phi.utils.log import logger
+
+
+def get_function_call(
+    name: str, arguments: Optional[str] = None, functions: Optional[Dict[str, Function]] = None
+) -> Optional[FunctionCall]:
+    logger.debug(f"Getting function {name}. Args: {arguments}")
+    if functions is None:
+        return None
+
+    function_to_call: Optional[Function] = None
+    if name in functions:
+        function_to_call = functions[name]
+    if function_to_call is None:
+        logger.error(f"Function {name} not found")
+        return None
+
+    function_call = FunctionCall(function=function_to_call)
+    if arguments is not None and arguments != "":
+        try:
+            if "None" in arguments:
+                arguments = arguments.replace("None", "null")
+            if "True" in arguments:
+                arguments = arguments.replace("True", "true")
+            if "False" in arguments:
+                arguments = arguments.replace("False", "false")
+            _arguments = json.loads(arguments)
+        except Exception as e:
+            logger.error(f"Unable to decode function arguments {arguments}: {e}")
+            return None
+
+        if not isinstance(_arguments, dict):
+            logger.error(f"Function arguments {arguments} is not a valid JSON object")
+            return None
+
+        try:
+            clean_arguments: Dict[str, Any] = {}
+            for k, v in _arguments.items():
+                if isinstance(v, str):
+                    _v = v.strip().lower()
+                    if _v in ("none", "null"):
+                        clean_arguments[k] = None
+                    elif _v == "true":
+                        clean_arguments[k] = True
+                    elif _v == "false":
+                        clean_arguments[k] = False
+                    else:
+                        clean_arguments[k] = v.strip()
+                else:
+                    clean_arguments[k] = v
+
+            function_call.arguments = clean_arguments
+        except Exception as e:
+            logger.error(f"Unable to parse function arguments {arguments}: {e}")
+            return None
+    return function_call
+
+
+# def run_function(func, *args, **kwargs):
+#     if asyncio.iscoroutinefunction(func):
+#         logger.debug("Running asynchronous function")
+#         try:
+#             loop = asyncio.get_running_loop()
+#         except RuntimeError as e:  # No running event loop
+#             logger.debug(f"Could not get running event loop: {e}")
+#             logger.debug("Running with a new event loop")
+#             loop = asyncio.new_event_loop()
+#             asyncio.set_event_loop(loop)
+#             result = loop.run_until_complete(func(*args, **kwargs))
+#             loop.close()
+#             logger.debug("Done running with a new event loop")
+#             return result
+#         else:  # There is a running event loop
+#             logger.debug("Running in existing event loop")
+#             result = loop.run_until_complete(func(*args, **kwargs))
+#             logger.debug("Done running in existing event loop")
+#             return result
+#     else:  # The function is a synchronous function
+#         logger.debug("Running synchronous function")
+#         return func(*args, **kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index 1ffdf8c1a..8605328be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "phidata"
-version = "2.0.7"
+version = "2.0.8"
 description = "AI Toolkit for Engineers"
 requires-python = ">=3.7"
 readme = "README.md"
@@ -21,7 +21,6 @@ dependencies = [
   "tomli",
   "typer",
   "typing-extensions",
-  "wsproto",
 ]
 
 [project.optional-dependencies]
diff --git a/requirements.txt b/requirements.txt
index d55d09fa4..a78f76e75 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,20 +1,19 @@
 #
-# This file is autogenerated by pip-compile with Python 3.9
+# This file is autogenerated by pip-compile with Python 3.11
 # by the following command:
 #
 #    ./scripts/upgrade.sh
 #
 annotated-types==0.5.0
 anyio==4.0.0
-boto3==1.28.45
-botocore==1.31.45
+boto3==1.28.56
+botocore==1.31.56
 certifi==2023.7.22
 charset-normalizer==3.2.0
 click==8.1.7
 docker==6.1.3
-exceptiongroup==1.1.3
 gitdb==4.0.10
-gitpython==3.1.35
+gitpython==3.1.37
 h11==0.14.0
 httpcore==0.18.0
 httpx==0.25.0
@@ -23,21 +22,21 @@ jmespath==1.0.1
 markdown-it-py==3.0.0
 mdurl==0.1.2
 packaging==23.1
-pydantic==2.3.0
-pydantic-core==2.6.3
+pydantic==2.4.2
+pydantic-core==2.10.1
 pydantic-settings==2.0.3
 pygments==2.16.1
 python-dateutil==2.8.2
 python-dotenv==1.0.0
 pyyaml==6.0.1
 requests==2.31.0
-rich==13.5.2
-s3transfer==0.6.2
+rich==13.5.3
+s3transfer==0.7.0
 six==1.16.0
-smmap==5.0.0
+smmap==5.0.1
 sniffio==1.3.0
 tomli==2.0.1
 typer==0.9.0
-typing-extensions==4.7.1
+typing-extensions==4.8.0
 urllib3==1.26.16
 websocket-client==1.6.3