Skip to content

Commit

Permalink
formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobweiss2305 committed Sep 30, 2024
1 parent d8d23af commit 545c21b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 17 deletions.
7 changes: 6 additions & 1 deletion cookbook/providers/cohere/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

agent = Agent(
model=CohereChat(id="command-r-08-2024"),
tools=[YFinanceTools(company_info=True, stock_fundamentals=True,)],
tools=[
YFinanceTools(
company_info=True,
stock_fundamentals=True,
)
],
show_tool_calls=True,
debug_mode=True,
markdown=True,
Expand Down
28 changes: 12 additions & 16 deletions phi/model/cohere/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
logger.error("`cohere` not installed")
raise


@dataclass
class StreamData:
response_content: str = ""
Expand All @@ -43,6 +44,7 @@ class StreamData:
time_to_first_token: Optional[float] = None
response_timer: Timer = field(default_factory=Timer)


class CohereChat(Model):
id: str = "command-r-plus"
name: str = "cohere"
Expand Down Expand Up @@ -276,11 +278,11 @@ def _prepare_function_calls(self, agent_message: Message) -> Tuple[List[Function
return function_calls_to_run, error_messages

def _handle_tool_calls(
self,
assistant_message: Message,
messages: List[Message],
self,
assistant_message: Message,
messages: List[Message],
response_tool_calls: List[Any],
model_response: ModelResponse
model_response: ModelResponse,
) -> Optional[Any]:
"""
Handle tool calls in the assistant message.
Expand Down Expand Up @@ -320,8 +322,7 @@ def _handle_tool_calls(
)
)
continue
function_calls_to_run.append(_function_call)

function_calls_to_run.append(_function_call)

if self.show_tool_calls:
model_response.content += "\nRunning:"
Expand All @@ -334,9 +335,7 @@ def _handle_tool_calls(
function_calls_to_run, error_messages = self._prepare_function_calls(assistant_message)

for _ in self.run_function_calls(
function_calls=function_calls_to_run,
function_call_results=function_call_results,
tool_role=tool_role
function_calls=function_calls_to_run, function_call_results=function_call_results, tool_role=tool_role
):
pass

Expand All @@ -354,7 +353,7 @@ def _handle_tool_calls(
]
else:
tool_results = None

return tool_results

def response(self, messages: List[Message], tool_results: Optional[List[ToolResult]] = None) -> ModelResponse:
Expand Down Expand Up @@ -404,7 +403,7 @@ def response(self, messages: List[Message], tool_results: Optional[List[ToolResu
assistant_message=assistant_message,
messages=messages,
response_tool_calls=response_tool_calls,
model_response=model_response
model_response=model_response,
)

# Make a recursive call with tool results if available
Expand Down Expand Up @@ -493,7 +492,7 @@ def response_stream(
stream_data.completion_tokens += 1
if stream_data.completion_tokens == 1:
stream_data.time_to_first_token = stream_data.response_timer.elapsed
logger.debug(f"Time to first token: {stream_data.time_to_first_token:.4f}s")
logger.debug(f"Time to first token: {stream_data.time_to_first_token:.4f}s")
yield ModelResponse(content=response.text)

if isinstance(response, ToolCallsChunkStreamedChatResponse):
Expand Down Expand Up @@ -611,10 +610,7 @@ def response_stream(
# Constructs a list named tool_results, where each element is a ToolResult that contains details of a tool call and its outputs.
# It pairs each tool call in response_tool_calls with its corresponding result in function_call_results.
tool_results = [
ToolResult(
call=tool_call,
outputs=[tool_call.parameters, {"result": fn_result.content}]
)
ToolResult(call=tool_call, outputs=[tool_call.parameters, {"result": fn_result.content}])
for tool_call, fn_result in zip(stream_data.response_tool_calls, function_call_results)
]
messages.append(Message(role="user", content=""))
Expand Down

0 comments on commit 545c21b

Please sign in to comment.