Clean up the MLflow trace for tool calling #47

Open · wants to merge 3 commits into base: main
@@ -116,62 +116,72 @@ def predict(
)

# Call the LLM, which recursively calls tools and eventually delivers a generation to send back to the user
(
model_response,
messages_log_with_tool_calls,
) = self.recursively_call_and_run_tools(messages=messages)

# If your front end keeps track of conversation history and automatically appends the bot's response to the message history, remove this line.
messages_log_with_tool_calls.append(
model_response.choices[0].message.to_dict()
) # OpenAI client
messages_log_with_tool_calls = self.recursively_call_and_run_tools(
messages=messages
)

# remove the system prompt - this should not be exposed to the Agent caller
messages_log_with_tool_calls = messages_log_with_tool_calls[1:]

return {
"content": model_response.choices[0].message.content,
# "content": model_response.choices[0].message.content,
"content": messages_log_with_tool_calls[-1]["content"],
# messages should be returned to the Review App (or any other front-end app) and stored there so they can be passed back to this stateless agent on the next turn of the conversation.
"messages": messages_log_with_tool_calls,
}

@mlflow.trace(span_type="AGENT")
@mlflow.trace(span_type="CHAIN")
Collaborator: QQ, why change from agent to chain?

Collaborator: Mostly asking since this does still seem like the core agentic loop logic.

Collaborator (Author): (same thinking as below)

def recursively_call_and_run_tools(self, max_iter=10, **kwargs):
messages = kwargs["messages"]
del kwargs["messages"]
i = 0
while i < max_iter:
response = self.chat_completion(messages=messages, tools=True)
assistant_message = response.choices[0].message # openai client
tool_calls = assistant_message.tool_calls # openai
if tool_calls is None:
# the tool execution finished, and we have a generation
return (response, messages)
tool_messages = []
for tool_call in tool_calls: # TODO: should run in parallel
function = tool_call.function # openai
args = json.loads(function.arguments) # openai
result = execute_function(self.tool_functions[function.name], args)
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"content": result,
} # openai

tool_messages.append(tool_message)
assistant_message_dict = assistant_message.dict().copy() # openai
del assistant_message_dict["content"]
del assistant_message_dict["function_call"] # openai only
if "audio" in assistant_message_dict:
del assistant_message_dict["audio"] # llama70b hack
messages = (
messages
+ [
assistant_message_dict,
]
+ tool_messages
)
i += 1
with mlflow.start_span(name=f"iteration_{i}", span_type="CHAIN") as span:
Collaborator: Same question here RE "CHAIN" vs "AGENT"

Collaborator (Author): I mirrored what LangGraph traces looked like here - they use a CHAIN for this type of logic. My thinking: this span represents logic that is actually a deterministic chain of steps.

Collaborator: Got it, thanks! Do you have a pointer to a LangGraph "CHAIN" tracing example? I'm wondering if LangGraph uses that span type mostly when writing a fixed chain, since the logic in this function-calling agent looks like the Python equivalent of a LangGraph agent (e.g. if you translated something like LangGraph create_react_agent into Python code).

response = self.chat_completion(messages=messages, tools=True)
assistant_message = response.choices[0].message # openai client
tool_calls = assistant_message.tool_calls # openai
if tool_calls is None:
# the tool execution finished, and we have a generation
messages.append(assistant_message.to_dict())
return messages
tool_messages = []
for tool_call in tool_calls: # TODO: should run in parallel
with mlflow.start_span(
name="execute_tool", span_type="TOOL"
) as span:
function = tool_call.function # openai
args = json.loads(function.arguments) # openai
span.set_inputs(
{
"function_name": function.name,
"function_args_raw": function.arguments,
"function_args_loaded": args,
}
)
result = execute_function(
self.tool_functions[function.name], args
)
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"content": result,
} # openai

tool_messages.append(tool_message)
span.set_outputs({"new_message": tool_message})
assistant_message_dict = assistant_message.dict().copy() # openai
del assistant_message_dict["content"]
del assistant_message_dict["function_call"] # openai only
if "audio" in assistant_message_dict:
del assistant_message_dict["audio"] # llama70b hack
messages = (
messages
+ [
assistant_message_dict,
]
+ tool_messages
)
i += 1
# TODO: Handle more gracefully
raise "ERROR: max iter reached"

@@ -198,8 +208,6 @@ def chat_completion(self, messages: List[Dict[str, str]], tools: bool = False):
return traced_create(model=endpoint_name, messages=messages, **llm_options)


logging.basicConfig(level=logging.INFO)

# tell MLflow logging where to find the agent's code
set_model(FunctionCallingAgent())

@@ -212,20 +220,24 @@ def chat_completion(self, messages: List[Dict[str, str]], tools: bool = False):
# print(find_config_folder_location())
# print(os.path.abspath(os.getcwd()))
# mlflow.tracing.disable()
logging.basicConfig(level=logging.DEBUG)
agent = FunctionCallingAgent()

vibe_check_query = {
"messages": [
# {"role": "user", "content": f"what is agent evaluation?"},
# {"role": "user", "content": f"How does the blender work?"},
{
"role": "user",
"content": f"How does the BlendMaster Elite 4000 blender work?",
},
# {
# "role": "user",
# "content": f"find all docs from the section header 'Databricks documentation archive' or 'Work with files on Databricks'",
# },
{
"role": "user",
"content": "Translate the sku `OLD-abs-1234` to the new format",
}
# {
# "role": "user",
# "content": "Translate the sku `OLD-abs-1234` to the new format",
# }
# {
# "role": "user",
# "content": f"convert sku 'OLD-XXX-1234' to the new format",
@@ -239,3 +251,12 @@ def chat_completion(self, messages: List[Dict[str, str]], tools: bool = False):

output = agent.predict(model_input=vibe_check_query)
print(output["content"])

second_turn = {
"messages": output["messages"]
+ [{"role": "user", "content": "How do I turn it on?"}]
}

# Run the Agent again with the same input to continue the conversation
second_turn_output = agent.predict(model_input=second_turn)
print(second_turn_output["content"])
@@ -2,7 +2,6 @@
import json


@mlflow.trace(span_type="FUNCTION")
def execute_function(tool, args):
result = tool(**args)
return json.dumps(result)
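
With the `@mlflow.trace(span_type="FUNCTION")` decorator removed, tool execution is traced by the caller-side `execute_tool` TOOL span opened in `recursively_call_and_run_tools`, rather than by a decorator on the helper, so each tool call appears once in the trace. A minimal sketch of that caller-side pattern, assuming this reading of the diff (the `get_weather` tool and its arguments are illustrative, not part of the PR):

import json
import mlflow
from mlflow.entities import SpanType

def execute_function(tool, args):
    # Intentionally undecorated; the caller owns the span.
    return json.dumps(tool(**args))

def get_weather(city):
    # Illustrative tool, not part of the PR.
    return {"city": city, "forecast": "sunny"}

with mlflow.start_span(name="execute_tool", span_type=SpanType.TOOL) as span:
    args = {"city": "Paris"}
    span.set_inputs({"function_name": "get_weather", "function_args_loaded": args})
    result = execute_function(get_weather, args)
    span.set_outputs({"new_message": {"role": "tool", "content": result}})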