Skip to content

Commit a3274be

Browse files
committed
[Feature] Tool services
ghstack-source-id: bbad726 Pull-Request: #3220
1 parent 0b7bddd commit a3274be

File tree

9 files changed

+2349
-248
lines changed

9 files changed

+2349
-248
lines changed

.github/unittest/llm/scripts_llm/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ dependencies:
2222
- transformers
2323
- datasets
2424
- vllm
25+
- mcp

.github/unittest/llm/scripts_llm/install.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,17 @@ python -m pip install -e . --no-build-isolation
6161

6262
# smoke test
6363
python -c "import torchrl"
64+
65+
# Install MCP dependencies for tool execution tests
66+
printf "* Installing MCP dependencies (uvx, Deno)\n"
67+
68+
# Install uvx (universal package runner)
69+
pip install uvx
70+
71+
# Install Deno (required by mcp-run-python)
72+
curl -fsSL https://deno.land/install.sh | sh
73+
export PATH="$HOME/.deno/bin:$PATH"
74+
75+
# Verify installations
76+
uvx --version || echo "Warning: uvx not installed"
77+
deno --version || echo "Warning: Deno not installed"

examples/llm/python_mcp_tool.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Execute Python code using MCP server with mcp-run-python."""
7+
8+
import json
9+
import os
10+
11+
from tensordict import set_list_to_stack, TensorDict
12+
13+
from torchrl.data.llm import History
14+
from torchrl.envs.llm import ChatEnv
15+
from torchrl.envs.llm.transforms import MCPToolTransform
16+
17+
set_list_to_stack(True).set()
18+
19+
deno_path = os.path.expanduser("~/.deno/bin")
20+
if deno_path not in os.environ.get("PATH", ""):
21+
os.environ["PATH"] = f"{deno_path}:{os.environ['PATH']}"
22+
23+
servers = {
24+
"python": {
25+
"command": "uvx",
26+
"args": ["mcp-run-python", "stdio"],
27+
"env": os.environ.copy(),
28+
}
29+
}
30+
31+
env = ChatEnv(batch_size=(1,))
32+
env = env.append_transform(MCPToolTransform(servers=servers))
33+
34+
reset_data = TensorDict(query="You are a helpful assistant", batch_size=(1,))
35+
td = env.reset(reset_data)
36+
37+
history = td.get("history")
38+
39+
code = """
40+
import math
41+
result = math.sqrt(144) + math.pi
42+
print(f"Result: {result}")
43+
result
44+
"""
45+
46+
response = (
47+
History(
48+
role="assistant",
49+
content=f'Let me calculate that.\n<tool>python.run_python_code\n{json.dumps({"python_code": code})}</tool>',
50+
)
51+
.unsqueeze(0)
52+
.unsqueeze(0)
53+
)
54+
55+
history.full = history.prompt.extend(response, inplace=True, dim=-1)
56+
history.response = response
57+
58+
result = env.step(td.set("history", history))
59+
60+
print("Python code executed via MCP!")
61+
print("\nTool response:")
62+
tool_response = result["next", "history"].prompt[0, -1]
63+
print(f"Role: {tool_response.role}")
64+
print(f"Content: {tool_response.content}")
65+
66+
fibonacci_code = """
67+
def fibonacci(n):
68+
if n <= 1:
69+
return n
70+
return fibonacci(n-1) + fibonacci(n-2)
71+
72+
result = [fibonacci(i) for i in range(10)]
73+
print(f"Fibonacci sequence: {result}")
74+
result
75+
"""
76+
77+
history = result["next", "history"]
78+
response2 = (
79+
History(
80+
role="assistant",
81+
content=f'Now calculating Fibonacci.\n<tool>python.run_python_code\n{json.dumps({"python_code": fibonacci_code})}</tool>',
82+
)
83+
.unsqueeze(0)
84+
.unsqueeze(0)
85+
)
86+
87+
history.full = history.prompt.extend(response2, inplace=True, dim=-1)
88+
history.response = response2
89+
90+
result2 = env.step(result["next"].set("history", history))
91+
92+
print("\n\nSecond execution:")
93+
print("\nTool response:")
94+
tool_response2 = result2["next", "history"].prompt[0, -1]
95+
print(f"Role: {tool_response2.role}")
96+
print(f"Content: {tool_response2.content[:500]}...")

0 commit comments

Comments
 (0)