Skip to content

Commit

Permalink
Resolve merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
bonk1t committed Dec 9, 2024
1 parent e3d772b commit cc04c70
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 192 deletions.
16 changes: 14 additions & 2 deletions agency_swarm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from .agency import Agency
from .agents import Agent
from .tools import BaseTool
from .util import get_openai_client, llm_validator, set_openai_client, set_openai_key
from .util.streaming import AgencyEventHandler, AgencyEventHandlerWithTracking
from .util import (
get_openai_client,
get_tracker,
llm_validator,
set_openai_client,
set_openai_key,
set_tracker,
)
from .util.streaming import (
AgencyEventHandler,
AgencyEventHandlerWithTracking,
)

__all__ = [
"Agency",
Expand All @@ -14,4 +24,6 @@
"set_openai_client",
"set_openai_key",
"llm_validator",
"set_tracker",
"get_tracker",
]
150 changes: 8 additions & 142 deletions agency_swarm/agency/agency.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,10 @@
)

from openai.lib._parsing._completions import type_to_response_format_param
from openai.types.beta.threads import Message
from openai.types.beta.threads.runs import RunStep
from openai.types.beta.threads.runs.tool_call import (
CodeInterpreterToolCall,
FileSearchToolCall,
FunctionToolCall,
)
from pydantic import BaseModel, Field, field_validator
from rich.console import Console

from agency_swarm.agents import Agent
from agency_swarm.messages.message_output import MessageOutputLive
from agency_swarm.threads import Thread
from agency_swarm.threads.thread_async import ThreadAsync
from agency_swarm.tools import BaseTool, CodeInterpreter, FileSearch
Expand All @@ -40,8 +32,11 @@
from agency_swarm.util.files import get_file_purpose, get_tools
from agency_swarm.util.oai import get_tracker
from agency_swarm.util.shared_state import SharedState
from agency_swarm.util.streaming import AgencyEventHandler
from agency_swarm.util.usage_tracking.tracker_factory import get_tracker
from agency_swarm.util.streaming import (
AgencyEventHandler,
create_gradio_handler,
create_term_handler,
)

console = Console()
T = TypeVar("T", bound=BaseModel)
Expand Down Expand Up @@ -159,8 +154,7 @@ def __init__(
self._create_special_tools()
self._init_agents()

self.usage_tracker = get_tracker(usage_tracker)

@get_tracker().get_observe_decorator()
def get_completion(
self,
message: str,
Expand Down Expand Up @@ -357,10 +351,6 @@ def demo_gradio(self, height=450, dark_mode=True, **kwargs):
gradio_handler = create_gradio_handler(chatbot_queue=chatbot_queue)

with gr.Blocks(js=js) as demo:
chatbot_queue = queue.Queue()
event_handler = GradioEventHandler(
chatbot_queue=chatbot_queue, usage_tracker=self.usage_tracker
)
chatbot = gr.Chatbot(height=height)
with gr.Row():
with gr.Column(scale=9):
Expand Down Expand Up @@ -541,7 +531,7 @@ def bot(original_message, history, dropdown):
target=self.get_completion_stream,
args=(
original_message,
event_handler,
gradio_handler,
[],
recipient_agent,
"",
Expand Down Expand Up @@ -664,131 +654,7 @@ def run_demo(self):
"""
Executes agency in the terminal with autocomplete for recipient agent names.
"""
outer_self = self
from agency_swarm import AgencyEventHandlerWithTracking

class TermEventHandler(AgencyEventHandlerWithTracking):
message_output = None

@override
def on_message_created(self, message: Message) -> None:
if message.role == "user":
self.message_output = MessageOutputLive(
"text", self.agent_name, self.recipient_agent_name, ""
)
self.message_output.cprint_update(message.content[0].text.value)
else:
self.message_output = MessageOutputLive(
"text", self.recipient_agent_name, self.agent_name, ""
)

@override
def on_message_done(self, message: Message) -> None:
self.message_output = None

@override
def on_text_delta(self, delta, snapshot):
self.message_output.cprint_update(snapshot.value)

@override
def on_tool_call_created(self, tool_call):
if isinstance(tool_call, dict):
if "type" not in tool_call:
tool_call["type"] = "function"

if tool_call["type"] == "function":
tool_call = FunctionToolCall(**tool_call)
elif tool_call["type"] == "code_interpreter":
tool_call = CodeInterpreterToolCall(**tool_call)
elif (
tool_call["type"] == "file_search"
or tool_call["type"] == "retrieval"
):
tool_call = FileSearchToolCall(**tool_call)
else:
raise ValueError("Invalid tool call type: " + tool_call["type"])

# TODO: add support for code interpreter and retirieval tools

if tool_call.type == "function":
self.message_output = MessageOutputLive(
"function",
self.recipient_agent_name,
self.agent_name,
str(tool_call.function),
)

@override
def on_tool_call_delta(self, delta, snapshot):
if isinstance(snapshot, dict):
if "type" not in snapshot:
snapshot["type"] = "function"

if snapshot["type"] == "function":
snapshot = FunctionToolCall(**snapshot)
elif snapshot["type"] == "code_interpreter":
snapshot = CodeInterpreterToolCall(**snapshot)
elif snapshot["type"] == "file_search":
snapshot = FileSearchToolCall(**snapshot)
else:
raise ValueError("Invalid tool call type: " + snapshot["type"])

self.message_output.cprint_update(str(snapshot.function))

@override
def on_tool_call_done(self, snapshot):
self.message_output = None

# TODO: add support for code interpreter and retrieval tools
if snapshot.type != "function":
return

if snapshot.function.name == "SendMessage" and not (
hasattr(
outer_self.send_message_tool_class.ToolConfig,
"output_as_result",
)
and outer_self.send_message_tool_class.ToolConfig.output_as_result
):
try:
args = eval(snapshot.function.arguments)
recipient = args["recipient"]
self.message_output = MessageOutputLive(
"text", self.recipient_agent_name, recipient, ""
)

self.message_output.cprint_update(args["message"])
except Exception as e:
pass

self.message_output = None

@override
def on_run_step_done(self, run_step: RunStep) -> None:
super().on_run_step_done(run_step)

if run_step.type == "tool_calls":
for tool_call in run_step.step_details.tool_calls:
if tool_call.type != "function":
continue

if tool_call.function.name == "SendMessage":
continue

self.message_output = None
self.message_output = MessageOutputLive(
"function_output",
tool_call.function.name,
self.recipient_agent_name,
tool_call.function.output,
)
self.message_output.cprint_update(tool_call.function.output)

self.message_output = None

@override
def on_end(self):
self.message_output = None
term_handler = create_term_handler(agency=self)

self.recipient_agents = [str(agent.name) for agent in self.main_recipients]

Expand Down
8 changes: 4 additions & 4 deletions agency_swarm/threads/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ def _create_run(
"parallel_tool_calls": recipient_agent.parallel_tool_calls
},
metadata={
"sender_agent": self.agent.name,
"recipient_agent": recipient_agent.name,
"sender_agent_name": self.agent.name,
"recipient_agent_name": recipient_agent.name,
},
response_format=response_format,
) as stream:
Expand All @@ -489,8 +489,8 @@ def _create_run(
parallel_tool_calls=recipient_agent.parallel_tool_calls,
response_format=response_format,
metadata={
"sender_agent": self.agent.name,
"recipient_agent": recipient_agent.name,
"sender_agent_name": self.agent.name,
"recipient_agent_name": recipient_agent.name,
},
)
self._run = self.client.beta.threads.runs.poll(
Expand Down
10 changes: 8 additions & 2 deletions agency_swarm/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from .cli.create_agent_template import create_agent_template
from .cli.import_agent import import_agent
from .files import get_file_purpose, get_tools
from .oai import get_openai_client, set_openai_client, set_openai_key
from .usage_tracking import AbstractTracker, LangfuseUsageTracker, SQLiteUsageTracker
from .oai import (
get_openai_client,
get_tracker,
set_openai_client,
set_openai_key,
set_tracker,
)
from .tracking import AbstractTracker, LangfuseTracker, SQLiteTracker
from .validators import llm_validator
42 changes: 0 additions & 42 deletions agency_swarm/util/streaming.py

This file was deleted.

0 comments on commit cc04c70

Please sign in to comment.