Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2.4.25 #1038

Merged
merged 13 commits into from
Jul 3, 2024
6 changes: 4 additions & 2 deletions phi/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,15 @@ def run_function_calls(self, function_calls: List[FunctionCall], role: str = "to
# -*- Run function call
_function_call_timer = Timer()
_function_call_timer.start()
function_call.execute()
function_call_success = function_call.execute()
_function_call_timer.stop()

_function_call_result = Message(
role=role,
content=function_call.result,
content=function_call.result if function_call_success else function_call.error,
tool_call_id=function_call.call_id,
tool_call_name=function_call.function.name,
tool_call_error=not function_call_success,
metrics={"time": _function_call_timer.elapsed},
)
if "tool_call_times" not in self.metrics:
Expand Down
6 changes: 5 additions & 1 deletion phi/llm/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ class Message(BaseModel):
tool_call_id: Optional[str] = None
# The name of the tool call
tool_call_name: Optional[str] = None
# The error of the tool call
tool_call_error: bool = False
# The tool calls generated by the model, such as function calls.
tool_calls: Optional[List[Dict[str, Any]]] = None
# Metrics for the message, tokes + the time it took to generate the response.
Expand All @@ -44,7 +46,9 @@ def get_content_string(self) -> str:
return ""

def to_dict(self) -> Dict[str, Any]:
_dict = self.model_dump(exclude_none=True, exclude={"metrics", "tool_call_name", "internal_id"})
_dict = self.model_dump(
exclude_none=True, exclude={"metrics", "tool_call_name", "internal_id", "tool_call_error"}
)
# Manually add the content field if it is None
if self.content is None:
_dict["content"] = None
Expand Down
12 changes: 10 additions & 2 deletions phi/llm/mistral/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,11 @@ def response(self, messages: List[Message]) -> str:
)
continue
if _function_call.error is not None:
messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
messages.append(
Message(
role="tool", tool_call_id=_tool_call_id, tool_call_error=True, content=_function_call.error
)
)
continue
function_calls_to_run.append(_function_call)

Expand Down Expand Up @@ -259,7 +263,11 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
)
continue
if _function_call.error is not None:
messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
messages.append(
Message(
role="tool", tool_call_id=_tool_call_id, tool_call_error=True, content=_function_call.error
)
)
continue
function_calls_to_run.append(_function_call)

Expand Down
223 changes: 161 additions & 62 deletions phi/llm/ollama/chat.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion phi/llm/ollama/hermes.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def response(self, messages: List[Message]) -> str:
messages.append(Message(role="user", content="Could not find function to call."))
continue
if _function_call.error is not None:
messages.append(Message(role="user", content=_function_call.error))
messages.append(Message(role="user", tool_call_error=True, content=_function_call.error))
continue
function_calls_to_run.append(_function_call)

Expand Down
2 changes: 1 addition & 1 deletion phi/llm/ollama/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def response(self, messages: List[Message]) -> str:
messages.append(Message(role="user", content="Could not find function to call."))
continue
if _function_call.error is not None:
messages.append(Message(role="user", content=_function_call.error))
messages.append(Message(role="user", tool_call_error=True, content=_function_call.error))
continue
function_calls_to_run.append(_function_call)

Expand Down
85 changes: 85 additions & 0 deletions phi/llm/ollama/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import json
from typing import Optional, Dict, Literal, Union

from pydantic import BaseModel


class _MessageToolCallExtractionResult(BaseModel):
tool_calls: Optional[list] = None
invalid_json_format: bool = False


def _extract_json(s: str) -> Union[Optional[Dict], Literal[False]]:
"""
Extracts all valid JSON from a string then combines them and returns it as a dictionary.

Args:
s: The string to extract JSON from.

Returns:
A dictionary containing the extracted JSON, or None if no JSON was found or False if an invalid JSON was found.
"""
json_objects = []
start_idx = 0

while start_idx < len(s):
# Find the next '{' which indicates the start of a JSON block
json_start = s.find("{", start_idx)
if json_start == -1:
break # No more JSON objects found

# Find the matching '}' for the found '{'
stack = []
i = json_start
while i < len(s):
if s[i] == "{":
stack.append("{")
elif s[i] == "}":
if stack:
stack.pop()
if not stack:
json_end = i
break
i += 1
else:
return False

json_str = s[json_start : json_end + 1]
try:
json_obj = json.loads(json_str)
json_objects.append(json_obj)
except ValueError:
return False

start_idx = json_end + 1

if not json_objects:
return None

# Combine all JSON objects into one
combined_json = {}
for obj in json_objects:
for key, value in obj.items():
if key not in combined_json:
combined_json[key] = value
elif isinstance(value, list) and isinstance(combined_json[key], list):
combined_json[key].extend(value)

return combined_json


def _extract_tool_calls(assistant_msg_content: str) -> _MessageToolCallExtractionResult:
ashpreetbedi marked this conversation as resolved.
Show resolved Hide resolved
json_obj = _extract_json(assistant_msg_content)
if json_obj is None:
return _MessageToolCallExtractionResult()

if json_obj is False or not isinstance(json_obj, dict):
return _MessageToolCallExtractionResult(invalid_json_format=True)

tool_calls: Optional[list] = json_obj.get("tool_calls")

# Not tool call json object
if not isinstance(tool_calls, list):
return _MessageToolCallExtractionResult(invalid_json_format=True)

return _MessageToolCallExtractionResult(tool_calls=tool_calls)
7 changes: 4 additions & 3 deletions phi/llm/openai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def run_function(self, function_call: Dict[str, Any]) -> Tuple[Message, Optional
if _function_call is None:
return Message(role="function", content="Could not find function to call."), None
if _function_call.error is not None:
return Message(role="function", content=_function_call.error), _function_call
return Message(role="function", tool_call_error=True, content=_function_call.error), _function_call

if self.function_call_stack is None:
self.function_call_stack = []
Expand All @@ -263,12 +263,13 @@ def run_function(self, function_call: Dict[str, Any]) -> Tuple[Message, Optional
self.function_call_stack.append(_function_call)
_function_call_timer = Timer()
_function_call_timer.start()
_function_call.execute()
function_call_success = _function_call.execute()
_function_call_timer.stop()
_function_call_message = Message(
role="function",
name=_function_call.function.name,
content=_function_call.result,
content=_function_call.result if function_call_success else _function_call.error,
tool_call_error=not function_call_success,
metrics={"time": _function_call_timer.elapsed},
)
if "function_call_times" not in self.metrics:
Expand Down
6 changes: 5 additions & 1 deletion phi/llm/together/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,11 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
)
continue
if _function_call.error is not None:
messages.append(Message(role="tool", tool_call_id=_tool_call_id, content=_function_call.error))
messages.append(
Message(
role="tool", tool_call_id=_tool_call_id, tool_call_error=True, content=_function_call.error
)
)
continue
function_calls_to_run.append(_function_call)

Expand Down
4 changes: 2 additions & 2 deletions phi/tools/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def execute(self) -> bool:
except Exception as e:
logger.warning(f"Could not run function {self.get_call_str()}")
logger.exception(e)
self.result = str(e)
self.error = str(e)
return False

try:
Expand All @@ -152,5 +152,5 @@ def execute(self) -> bool:
except Exception as e:
logger.warning(f"Could not run function {self.get_call_str()}")
logger.exception(e)
self.result = str(e)
self.error = str(e)
return False
5 changes: 4 additions & 1 deletion phi/utils/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def get_function_call(
_arguments = json.loads(arguments)
except Exception as e:
logger.error(f"Unable to decode function arguments:\n{arguments}\nError: {e}")
function_call.error = f"Error while decoding function arguments: {e}\n\n Please make sure we can json.loads() the arguments and retry."
function_call.error = (
f"Error while decoding function arguments: {e}\n\n"
f"Please make sure we can json.loads() the arguments and retry."
)
return function_call

if not isinstance(_arguments, dict):
Expand Down
Loading