From 9ba40af49be9008f38fd809a3a110d1cafaad35b Mon Sep 17 00:00:00 2001 From: Jordan Wu <101218661+jordan-definitive@users.noreply.github.com> Date: Thu, 11 Jan 2024 15:37:48 -0800 Subject: [PATCH] fallback function support --- .../openassistants-fastapi/pyproject.toml | 2 +- .../openassistants/functions/base.py | 18 ++++++-- .../llm_function_calling/selection.py | 44 +++++++++++++++---- packages/openassistants/pyproject.toml | 2 +- 4 files changed, 52 insertions(+), 14 deletions(-) diff --git a/packages/openassistants-fastapi/pyproject.toml b/packages/openassistants-fastapi/pyproject.toml index 8dc366b..9db0ebd 100644 --- a/packages/openassistants-fastapi/pyproject.toml +++ b/packages/openassistants-fastapi/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "definitive-openassistants-fastapi" -version = "0.0.10" +version = "0.0.11" description = "" authors = ["Rick Lamers "] readme = "README.md" diff --git a/packages/openassistants/openassistants/functions/base.py b/packages/openassistants/openassistants/functions/base.py index 8d53bc8..ff058a5 100644 --- a/packages/openassistants/openassistants/functions/base.py +++ b/packages/openassistants/openassistants/functions/base.py @@ -71,14 +71,13 @@ def get_signature(self) -> str: # convert JSON Schema types to Python types signature params_repr = PyRepr.repr_json_schema(self.get_parameters_json_schema()) - sample_question_text = "\n".join(f"* {q}" for q in self.get_sample_questions()) - documentation = f"""\ {self.get_description()} -Example Questions: -{sample_question_text} """ + if len(self.get_sample_questions()) > 0: + documentation += "\n".join(f"* {q}" for q in self.get_sample_questions()) + # Construct the function signature signature = f"""\ def {self.get_id()}({params_repr}) -> pd.DataFrame: @@ -92,6 +91,13 @@ def {self.get_id()}({params_repr}) -> pd.DataFrame: async def get_entity_configs(self) -> Mapping[str, IEntityConfig]: pass + @abc.abstractmethod + def get_is_fallback(self) -> bool: + """ + Whether this function is a fallback function. + Fallback functions are functions that are always available to be chosen if no other non-fallback functions are relevant. + """ # noqa: E501 + @abc.abstractmethod async def execute( self, @@ -133,6 +139,7 @@ class BaseFunction(IFunction, BaseModel, abc.ABC): sample_questions: List[str] = [] confirm: bool = False parameters: BaseFunctionParameters = BaseFunctionParameters() + is_fallback: bool = False def get_id(self) -> str: return self.id @@ -158,6 +165,9 @@ def get_parameters_json_schema(self) -> JSONSchema: async def get_entity_configs(self) -> Mapping[str, IEntityConfig]: return {} + def get_is_fallback(self) -> bool: + return self.is_fallback + class IFunctionLibrary(abc.ABC): @abc.abstractmethod diff --git a/packages/openassistants/openassistants/llm_function_calling/selection.py b/packages/openassistants/openassistants/llm_function_calling/selection.py index 32e0ccb..4349adb 100644 --- a/packages/openassistants/openassistants/llm_function_calling/selection.py +++ b/packages/openassistants/openassistants/llm_function_calling/selection.py @@ -17,8 +17,13 @@ async def filter_functions( functions_text = "\n".join([f.get_signature() for f in functions]) json_schema = { "type": "object", - "properties": {"function_name": {"type": "string"}}, + "properties": {"function_name": {"$ref": "#/definitions/functions"}}, "required": ["function_name"], + "definitions": { + "functions": { + "enum": [f.get_id() for f in functions], + } + }, } messages = [ HumanMessage( @@ -54,7 +59,10 @@ async def select_function( user_query: str, chunk_size: int = 4, ) -> SelectFunctionResult: - subsets = chunk_list_by_max_size(functions, chunk_size) + non_fallbacks = [f for f in functions if not f.get_is_fallback()] + fallbacks = [f for f in functions if f.get_is_fallback()] + + subsets = chunk_list_by_max_size(non_fallbacks, chunk_size) # Make LLM calls in parallel tasks = [ @@ -77,21 +85,39 @@ async def select_function( json_schema = { "type": "object", "properties": { - "function_name": {"type": "string"}, - "suggested_function_names": {"type": "array", "items": {"type": "string"}}, + "reason": {"type": "string"}, + "function_name": {"$ref": "#/definitions/functions"}, + "related_function_names": { + "type": "array", + "items": {"$ref": "#/definitions/functions"}, + }, + }, + "definitions": { + "functions": { + "enum": [f.get_id() for f in selected_functions + fallbacks], + } }, } + fallbacks_signatures = "\n".join([f.get_signature() for f in fallbacks]) + selection_messages = [ HumanMessage( - content=f"""Prior selection reduced the candidates to these functions: + content=f"""\ +Prior selection reduced the candidates to these functions: + {selected_functions_signatures} +These fallback functions can be used when none of the above functions are a good match: + +{fallbacks_signatures} + Scenario 1: There is a function in the list of candidates that is a match to the user query. Action: provide the name of the function as the 'function_name' argument. Scenario 2: None of the functions in the list of candidates match the user query. -Action: select related functions from the list of candidates as the 'suggested_function_names' argument. You are also allowed to return an empty list of suggested functions if you think none of the functions are a good match. +Action: select related functions from the list of candidates as the 'related_function_names' argument. +You are also allowed to return an empty list of related functions if you think none of the functions are a good match. First decide which of the two scenarios is the case. Then take the appropriate action. @@ -114,10 +140,12 @@ async def select_function( suggested_function_names = json_result.get("related_function_names", []) selected_function = next( - (f for f in selected_functions if f.get_id() == function_name), None + (f for f in selected_functions + fallbacks if f.get_id() == function_name), None ) suggested_functions = [ - f for f in selected_functions if f.get_id() in suggested_function_names + f + for f in selected_functions + fallbacks + if f.get_id() in suggested_function_names ] or None return SelectFunctionResult( diff --git a/packages/openassistants/pyproject.toml b/packages/openassistants/pyproject.toml index 406dcc5..64593c2 100644 --- a/packages/openassistants/pyproject.toml +++ b/packages/openassistants/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "definitive-openassistants" -version = "0.0.10" +version = "0.0.11" description = "" authors = ["Rick Lamers "] readme = "README.md"