Skip to content
This repository has been archived by the owner on Aug 13, 2024. It is now read-only.

Commit

Permalink
expose assistant function selection
Browse files Browse the repository at this point in the history
* improve json detection for weaker models
* fix weird escapes for mixtral
  • Loading branch information
jordan-wu-97 committed Jan 24, 2024
1 parent 8eaeef6 commit ad1e6ff
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 13 deletions.
2 changes: 1 addition & 1 deletion packages/openassistants-fastapi/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "definitive-openassistants-fastapi"
version = "0.0.12"
version = "0.0.13"
description = ""
authors = ["Rick Lamers <[email protected]>"]
readme = "README.md"
Expand Down
28 changes: 25 additions & 3 deletions packages/openassistants/openassistants/core/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
generate_argument_decisions,
generate_arguments,
)
from openassistants.llm_function_calling.selection import select_function
from openassistants.llm_function_calling.selection import (
SelectFunctionResult,
select_function,
)
from openassistants.utils.async_utils import AsyncStreamVersion
from openassistants.utils.langchain_util import LangChainCachedEmbeddings
from openassistants.utils.vision import image_url_to_text
Expand Down Expand Up @@ -194,6 +197,25 @@ async def do_infilling(

return complete, arguments

async def run_function_selection(
self,
chat_history: List[OpasMessage],
) -> SelectFunctionResult:
all_functions = await self.get_all_functions()

last_message = chat_history[-1]

assert isinstance(last_message, OpasUserMessage)
assert isinstance(last_message.content, str)

select_function_result = await select_function(
self.function_identification,
all_functions,
last_message.content,
)

return select_function_result

async def handle_user_plaintext(
self,
message: OpasUserMessage,
Expand All @@ -214,8 +236,8 @@ async def handle_user_plaintext(
selected_function = filtered[0]

if selected_function is None:
function_selection = await select_function(
self.function_identification, all_functions, message.content
function_selection = await self.run_function_selection(
chat_history=chat_history,
)

if function_selection.function:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ async def select_function(

# Ensure the selected function names are in the loaded signatures
selected_functions = [f for f in functions if f.get_id() in function_names]
if not selected_functions:

if len(selected_functions) == 0 and len(fallbacks) == 0:
return SelectFunctionResult()

# Include the signatures of all the selected functions in the final evaluation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,29 @@ def chunk_list_by_max_size(lst, max_size):
return [chunk.tolist() for chunk in np.array_split(lst, n)]


def find_json_substring(s):
def find_indexes_of_char(string: str, char: str):
return [index for index, c in enumerate(string) if c == char]


def find_json_substring(s: str):
start = s.find("{")
end = s.rfind("}") + 1
if start != -1 and end != -1:
return s[start:end]
return None

all_closing_brackets = find_indexes_of_char(s, "}")

for i in all_closing_brackets:
if i > start:
try:
json.loads(s[start : i + 1])
return s[start : i + 1]
except json.JSONDecodeError:
pass

raise ValueError(f"Could not find JSON substring in response content: {s}")


async def generate_to_json(
chat: BaseChatModel,
messages,
messages: list[BaseMessage],
output_json_schema: Optional[dict],
task_name: str,
tags: Optional[list[str]] = None,
Expand All @@ -104,7 +116,7 @@ async def generate_to_json(


async def generate_to_json_generic(
chat,
chat: BaseChatModel,
messages: list[BaseMessage],
output_json_schema: Optional[dict],
tags: list[str],
Expand All @@ -122,8 +134,12 @@ async def generate_to_json_generic(

messages = [system_message] + messages
response = await chat.ainvoke(ensure_alternating(messages), {"tags": tags})
assert isinstance(response.content, str)
content = response.content

# replace weird \_ escaping of underscores when using mixtral
content = content.replace(r"\_", "_")

json_substring = find_json_substring(content)
if json_substring is not None:
try:
Expand Down
2 changes: 1 addition & 1 deletion packages/openassistants/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "definitive-openassistants"
version = "0.0.12"
version = "0.0.13"
description = ""
authors = ["Rick Lamers <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit ad1e6ff

Please sign in to comment.