From ad1e6fffef7055a90645a5722d508b3aaf78adf1 Mon Sep 17 00:00:00 2001
From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com>
Date: Wed, 24 Jan 2024 15:00:44 -0800
Subject: [PATCH] expose assistant function selection

* improve json detection for weaker models
* fix weird escapes for mixtral
---
 .../openassistants-fastapi/pyproject.toml    |  2 +-
 .../openassistants/core/assistant.py         | 28 +++++++++++++++--
 .../llm_function_calling/selection.py        |  3 +-
 .../llm_function_calling/utils.py            | 30 ++++++++++++++-----
 packages/openassistants/pyproject.toml       |  2 +-
 5 files changed, 52 insertions(+), 13 deletions(-)

diff --git a/packages/openassistants-fastapi/pyproject.toml b/packages/openassistants-fastapi/pyproject.toml
index 965d828..eed4286 100644
--- a/packages/openassistants-fastapi/pyproject.toml
+++ b/packages/openassistants-fastapi/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "definitive-openassistants-fastapi"
-version = "0.0.12"
+version = "0.0.13"
 description = ""
 authors = ["Rick Lamers "]
 readme = "README.md"
diff --git a/packages/openassistants/openassistants/core/assistant.py b/packages/openassistants/openassistants/core/assistant.py
index faa2b42..7c5932b 100644
--- a/packages/openassistants/openassistants/core/assistant.py
+++ b/packages/openassistants/openassistants/core/assistant.py
@@ -26,7 +26,10 @@
     generate_argument_decisions,
     generate_arguments,
 )
-from openassistants.llm_function_calling.selection import select_function
+from openassistants.llm_function_calling.selection import (
+    SelectFunctionResult,
+    select_function,
+)
 from openassistants.utils.async_utils import AsyncStreamVersion
 from openassistants.utils.langchain_util import LangChainCachedEmbeddings
 from openassistants.utils.vision import image_url_to_text
@@ -194,6 +197,25 @@ async def do_infilling(
 
         return complete, arguments
 
+    async def run_function_selection(
+        self,
+        chat_history: List[OpasMessage],
+    ) -> SelectFunctionResult:
+        all_functions = await self.get_all_functions()
+
+        last_message = chat_history[-1]
+
+        assert isinstance(last_message, OpasUserMessage)
+        assert isinstance(last_message.content, str)
+
+        select_function_result = await select_function(
+            self.function_identification,
+            all_functions,
+            last_message.content,
+        )
+
+        return select_function_result
+
     async def handle_user_plaintext(
         self,
         message: OpasUserMessage,
@@ -214,8 +236,8 @@ async def handle_user_plaintext(
             selected_function = filtered[0]
 
         if selected_function is None:
-            function_selection = await select_function(
-                self.function_identification, all_functions, message.content
+            function_selection = await self.run_function_selection(
+                chat_history=chat_history,
             )
 
             if function_selection.function:
diff --git a/packages/openassistants/openassistants/llm_function_calling/selection.py b/packages/openassistants/openassistants/llm_function_calling/selection.py
index 4349adb..b2180cc 100644
--- a/packages/openassistants/openassistants/llm_function_calling/selection.py
+++ b/packages/openassistants/openassistants/llm_function_calling/selection.py
@@ -74,7 +74,8 @@ async def select_function(
 
     # Ensure the selected function names are in the loaded signatures
     selected_functions = [f for f in functions if f.get_id() in function_names]
-    if not selected_functions:
+
+    if len(selected_functions) == 0 and len(fallbacks) == 0:
         return SelectFunctionResult()
 
     # Include the signatures of all the selected functions in the final evaluation
diff --git a/packages/openassistants/openassistants/llm_function_calling/utils.py b/packages/openassistants/openassistants/llm_function_calling/utils.py
index f37d549..1660c61 100644
--- a/packages/openassistants/openassistants/llm_function_calling/utils.py
+++ b/packages/openassistants/openassistants/llm_function_calling/utils.py
@@ -77,17 +77,29 @@ def chunk_list_by_max_size(lst, max_size):
     return [chunk.tolist() for chunk in np.array_split(lst, n)]
 
 
-def find_json_substring(s):
+def find_indexes_of_char(string: str, char: str):
+    return [index for index, c in enumerate(string) if c == char]
+
+
+def find_json_substring(s: str):
     start = s.find("{")
-    end = s.rfind("}") + 1
-    if start != -1 and end != -1:
-        return s[start:end]
-    return None
+
+    all_closing_brackets = find_indexes_of_char(s, "}")
+
+    for i in all_closing_brackets:
+        if i > start:
+            try:
+                json.loads(s[start : i + 1])
+                return s[start : i + 1]
+            except json.JSONDecodeError:
+                pass
+
+    raise ValueError(f"Could not find JSON substring in response content: {s}")
 
 
 async def generate_to_json(
     chat: BaseChatModel,
-    messages,
+    messages: list[BaseMessage],
     output_json_schema: Optional[dict],
     task_name: str,
     tags: Optional[list[str]] = None,
@@ -104,7 +116,7 @@ async def generate_to_json(
 
 
 async def generate_to_json_generic(
-    chat,
+    chat: BaseChatModel,
     messages: list[BaseMessage],
     output_json_schema: Optional[dict],
     tags: list[str],
@@ -122,8 +134,12 @@ async def generate_to_json_generic(
     messages = [system_message] + messages
 
     response = await chat.ainvoke(ensure_alternating(messages), {"tags": tags})
+    assert isinstance(response.content, str)
     content = response.content
 
+    # replace weird \_ escaping of underscores when using mixtral
+    content = content.replace(r"\_", "_")
+
     json_substring = find_json_substring(content)
     if json_substring is not None:
         try:
diff --git a/packages/openassistants/pyproject.toml b/packages/openassistants/pyproject.toml
index a0fd660..e7ccab6 100644
--- a/packages/openassistants/pyproject.toml
+++ b/packages/openassistants/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "definitive-openassistants"
-version = "0.0.12"
+version = "0.0.13"
 description = ""
 authors = ["Rick Lamers "]
 readme = "README.md"
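
Note on the exposed selection entry point: run_function_selection wraps select_function so a caller can ask "which function would the assistant pick for this chat history?" without running the full message-handling flow. A hypothetical call site is sketched below; only run_function_selection and SelectFunctionResult come from this patch, while the Assistant constructor arguments, the OpasUserMessage import path, and its fields are assumptions made for illustration.

# Hypothetical usage sketch. run_function_selection / SelectFunctionResult are
# from this patch; everything else here is an assumption, not part of the diff.
import asyncio

from openassistants.core.assistant import Assistant
from openassistants.data_models.chat_messages import OpasUserMessage  # assumed import path


async def main() -> None:
    assistant = Assistant(libraries=["piedpiper"])  # assumed example function library

    # Per the new method's asserts, the last chat message must be a plain-text
    # user message before it delegates to select_function.
    result = await assistant.run_function_selection(
        chat_history=[OpasUserMessage(content="show me revenue by month")],
    )

    print(result.function.get_id() if result.function else "no function selected")


asyncio.run(main())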
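
Note on the JSON-detection change in utils.py: instead of greedily slicing from the first "{" to the last "}", the reworked helper tries every closing brace after the first opening brace and returns the first slice that actually parses, raising if none does; together with the r"\_" -> "_" replacement this tolerates the chatter and odd escaping that weaker models such as Mixtral wrap around their JSON. A standalone sketch of the same strategy, mirroring the patched helpers but independent of the package:

import json


def find_indexes_of_char(string: str, char: str) -> list[int]:
    # Positions of every occurrence of `char` in `string`.
    return [index for index, c in enumerate(string) if c == char]


def find_json_substring(s: str) -> str:
    # Try each closing brace after the first opening brace and return the first
    # slice that parses as valid JSON; raise if no candidate parses.
    start = s.find("{")
    for i in find_indexes_of_char(s, "}"):
        if i > start:
            try:
                json.loads(s[start : i + 1])
                return s[start : i + 1]
            except json.JSONDecodeError:
                pass
    raise ValueError(f"Could not find JSON substring in response content: {s}")


if __name__ == "__main__":
    # A model reply with prose around the JSON, a Mixtral-style "\_" escape, and a
    # stray "}" in the trailing text that would break a greedy first-"{"-to-last-"}" slice.
    reply = 'Here you go: {"name": "get\\_sales", "args": {"q": "north"}} -- stray } here.'
    print(find_json_substring(reply.replace(r"\_", "_")))
    # -> {"name": "get_sales", "args": {"q": "north"}}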