Skip to content

Commit 8e51f12

Browse files
Merge pull request #283 from thwu1/nightly
Tinker Implementation
2 parents 0360a25 + 75a20e2 commit 8e51f12

18 files changed

+2874
-3
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
MathAgent with few-shot prompting support to match tinker-cookbook math_rl.
3+
4+
This agent variant includes:
5+
1. Few-shot prefix with a standard example (strawberry)
6+
2. Instruction text matching math_rl: " Write your answer in \\boxed{} format."
7+
"""
8+
9+
import copy
10+
from typing import Any
11+
12+
from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
13+
14+
15+
class MathAgentWithFewshot(BaseAgent):
16+
"""
17+
A math agent with few-shot prompting that matches tinker-cookbook math_rl behavior.
18+
"""
19+
20+
# Standard few-shot example from tinker-cookbook math_rl
21+
STANDARD_FEWSHOT_PREFIX = [
22+
{
23+
"role": "user",
24+
"content": "How many r's are in strawberry? Provide a numerical answer without units, written inside \\boxed{}.",
25+
},
26+
{
27+
"role": "assistant",
28+
"content": "Let's spell the word out and number all the letters: 1) s 2) t 3) r 4) a 5) w 6) b 7) e 8) r 9) r 10) y. We have r's at positions 3, 8, and 9. \\boxed{3}",
29+
},
30+
]
31+
32+
def __init__(self, accumulate_thinking=True, use_fewshot=True):
33+
"""
34+
Initialize the MathAgent with few-shot support.
35+
36+
Args:
37+
accumulate_thinking: Whether to accumulate thinking in conversation history
38+
use_fewshot: Whether to use few-shot prompting
39+
"""
40+
self._trajectory = Trajectory()
41+
self.messages = []
42+
self.accumulate_thinking = accumulate_thinking
43+
self.use_fewshot = use_fewshot
44+
45+
# Add few-shot prefix if enabled
46+
if self.use_fewshot:
47+
self.messages.extend(copy.deepcopy(self.STANDARD_FEWSHOT_PREFIX))
48+
49+
def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
50+
"""Process environment feedback and update internal state."""
51+
52+
# Reward update for existing step (None OR empty dict)
53+
if observation is None or (isinstance(observation, dict) and observation == {}):
54+
if self.trajectory.steps:
55+
cur_step = self.get_current_state()
56+
cur_step.reward = reward
57+
cur_step.done = done
58+
cur_step.info = info
59+
return
60+
61+
# Update reward/done/info on existing step if we have steps already
62+
if self.trajectory.steps:
63+
cur_step = self.get_current_state()
64+
cur_step.reward = reward
65+
cur_step.done = done
66+
cur_step.info.update(info)
67+
68+
if done:
69+
return
70+
71+
# This is a new observation, create a new step
72+
if isinstance(observation, dict):
73+
if "question" not in observation:
74+
raise ValueError(f"Observation dict missing required 'question' field: {observation}")
75+
# Match math_rl instruction text exactly
76+
formatted_observation = observation["question"] + " Provide a numerical answer without units, written inside \\boxed{}."
77+
elif isinstance(observation, str):
78+
formatted_observation = observation + " Provide a numerical answer without units, written inside \\boxed{}."
79+
else:
80+
raise ValueError(f"Invalid observation type: {type(observation)}")
81+
82+
self.messages.append({"role": "user", "content": formatted_observation})
83+
84+
new_step = Step(observation=formatted_observation)
85+
self._trajectory.steps.append(new_step)
86+
87+
def update_from_model(self, response: str, **kwargs) -> Action:
88+
"""
89+
Updates the agent's internal state based on the model's response.
90+
"""
91+
92+
# Update the latest step
93+
self.messages.append({"role": "assistant", "content": response})
94+
95+
cur_step = self.get_current_state()
96+
cur_step.chat_completions = self.chat_completions
97+
cur_step.model_response = response
98+
99+
if response.count("</think>") == 1:
100+
thought, sep, action = response.partition("</think>")
101+
thought = thought + sep
102+
action = Action(action.strip())
103+
else:
104+
thought = None
105+
action = Action(response.strip())
106+
107+
cur_step.thought = thought
108+
cur_step.action = action
109+
110+
# TODO: remove this temporary fix
111+
return Action(response.strip())
112+
113+
def reset(self) -> None:
114+
"""Reset agent state for new episode (wipes trajectory but keeps few-shot prefix)."""
115+
self._trajectory = Trajectory()
116+
self.messages = []
117+
118+
# Re-add few-shot prefix after reset
119+
if self.use_fewshot:
120+
self.messages.extend(copy.deepcopy(self.STANDARD_FEWSHOT_PREFIX))
121+
122+
@property
123+
def chat_completions(self) -> list[dict[str, str]]:
124+
"""Return conversation history for model interaction."""
125+
# remove thinking from assistant messages if not accumulate_thinking except the last one
126+
messages = copy.deepcopy(self.messages)
127+
if not self.accumulate_thinking:
128+
for msg in messages[:-1]:
129+
if msg["role"] == "assistant":
130+
_, sep, after = msg["content"].partition("</think>")
131+
if sep:
132+
msg["content"] = after
133+
return messages
134+
135+
@property
136+
def trajectory(self) -> Trajectory:
137+
"""Return complete interaction trajectory."""
138+
return self._trajectory
139+
140+
def get_current_state(self) -> Step:
141+
"""Returns the current step/state of the agent."""
142+
assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called."
143+
return self._trajectory.steps[-1]

0 commit comments

Comments
 (0)