Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
fallback function support
Browse files Browse the repository at this point in the history
  • Loading branch information
jordan-wu-97 committed Jan 12, 2024
1 parent dc0c957 commit 9ba40af
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 14 deletions.
2 changes: 1 addition & 1 deletion packages/openassistants-fastapi/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "definitive-openassistants-fastapi"
version = "0.0.10"
version = "0.0.11"
description = ""
authors = ["Rick Lamers <[email protected]>"]
readme = "README.md"
Expand Down
18 changes: 14 additions & 4 deletions packages/openassistants/openassistants/functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,13 @@ def get_signature(self) -> str:
# convert JSON Schema types to Python types signature
params_repr = PyRepr.repr_json_schema(self.get_parameters_json_schema())

sample_question_text = "\n".join(f"* {q}" for q in self.get_sample_questions())

documentation = f"""\
{self.get_description()}
Example Questions:
{sample_question_text}
"""

if len(self.get_sample_questions()) > 0:
documentation += "\n".join(f"* {q}" for q in self.get_sample_questions())

# Construct the function signature
signature = f"""\
def {self.get_id()}({params_repr}) -> pd.DataFrame:
Expand All @@ -92,6 +91,13 @@ def {self.get_id()}({params_repr}) -> pd.DataFrame:
async def get_entity_configs(self) -> Mapping[str, IEntityConfig]:
pass

@abc.abstractmethod
def get_is_fallback(self) -> bool:
"""
Whether this function is a fallback function.
Fallback functions are functions that are always available to be chosen if no other non-fallback functions are relevant.
""" # noqa: E501

@abc.abstractmethod
async def execute(
self,
Expand Down Expand Up @@ -133,6 +139,7 @@ class BaseFunction(IFunction, BaseModel, abc.ABC):
sample_questions: List[str] = []
confirm: bool = False
parameters: BaseFunctionParameters = BaseFunctionParameters()
is_fallback: bool = False

def get_id(self) -> str:
return self.id
Expand All @@ -158,6 +165,9 @@ def get_parameters_json_schema(self) -> JSONSchema:
async def get_entity_configs(self) -> Mapping[str, IEntityConfig]:
return {}

def get_is_fallback(self) -> bool:
return self.is_fallback


class IFunctionLibrary(abc.ABC):
@abc.abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,13 @@ async def filter_functions(
functions_text = "\n".join([f.get_signature() for f in functions])
json_schema = {
"type": "object",
"properties": {"function_name": {"type": "string"}},
"properties": {"function_name": {"$ref": "#/definitions/functions"}},
"required": ["function_name"],
"definitions": {
"functions": {
"enum": [f.get_id() for f in functions],
}
},
}
messages = [
HumanMessage(
Expand Down Expand Up @@ -54,7 +59,10 @@ async def select_function(
user_query: str,
chunk_size: int = 4,
) -> SelectFunctionResult:
subsets = chunk_list_by_max_size(functions, chunk_size)
non_fallbacks = [f for f in functions if not f.get_is_fallback()]
fallbacks = [f for f in functions if f.get_is_fallback()]

subsets = chunk_list_by_max_size(non_fallbacks, chunk_size)

# Make LLM calls in parallel
tasks = [
Expand All @@ -77,21 +85,39 @@ async def select_function(
json_schema = {
"type": "object",
"properties": {
"function_name": {"type": "string"},
"suggested_function_names": {"type": "array", "items": {"type": "string"}},
"reason": {"type": "string"},
"function_name": {"$ref": "#/definitions/functions"},
"related_function_names": {
"type": "array",
"items": {"$ref": "#/definitions/functions"},
},
},
"definitions": {
"functions": {
"enum": [f.get_id() for f in selected_functions + fallbacks],
}
},
}

fallbacks_signatures = "\n".join([f.get_signature() for f in fallbacks])

selection_messages = [
HumanMessage(
content=f"""Prior selection reduced the candidates to these functions:
content=f"""\
Prior selection reduced the candidates to these functions:
{selected_functions_signatures}
These fallback functions can be used when none of the above functions are a good match:
{fallbacks_signatures}
Scenario 1: There is a function in the list of candidates that is a match to the user query.
Action: provide the name of the function as the 'function_name' argument.
Scenario 2: None of the functions in the list of candidates match the user query.
Action: select related functions from the list of candidates as the 'suggested_function_names' argument. You are also allowed to return an empty list of suggested functions if you think none of the functions are a good match.
Action: select related functions from the list of candidates as the 'related_function_names' argument.
You are also allowed to return an empty list of related functions if you think none of the functions are a good match.
First decide which of the two scenarios is the case. Then take the appropriate action.
Expand All @@ -114,10 +140,12 @@ async def select_function(
suggested_function_names = json_result.get("related_function_names", [])

selected_function = next(
(f for f in selected_functions if f.get_id() == function_name), None
(f for f in selected_functions + fallbacks if f.get_id() == function_name), None
)
suggested_functions = [
f for f in selected_functions if f.get_id() in suggested_function_names
f
for f in selected_functions + fallbacks
if f.get_id() in suggested_function_names
] or None

return SelectFunctionResult(
Expand Down
2 changes: 1 addition & 1 deletion packages/openassistants/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "definitive-openassistants"
version = "0.0.10"
version = "0.0.11"
description = ""
authors = ["Rick Lamers <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 9ba40af

Please sign in to comment.