"""
MathAgent with few-shot prompting support to match tinker-cookbook math_rl.

This agent variant includes:
1. A few-shot prefix with a standard example (the "strawberry" question)
2. Instruction text matching math_rl: " Provide a numerical answer without units, written inside \\boxed{}."
"""
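# For reference, the conversation sent to the model looks roughly like this
# (a sketch; the exact strings come from STANDARD_FEWSHOT_PREFIX and
# update_from_env below):
#
#   user:      How many r's are in strawberry? ... written inside \boxed{}.
#   assistant: Let's spell the word out ... \boxed{3}
#   user:      <problem statement> Provide a numerical answer without units, written inside \boxed{}.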

import copy
from typing import Any

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory


class MathAgentWithFewshot(BaseAgent):
    """
    A math agent with few-shot prompting that matches tinker-cookbook math_rl behavior.
    """

    # Standard few-shot example from tinker-cookbook math_rl
    STANDARD_FEWSHOT_PREFIX = [
        {
            "role": "user",
            "content": "How many r's are in strawberry? Provide a numerical answer without units, written inside \\boxed{}.",
        },
        {
            "role": "assistant",
            "content": "Let's spell the word out and number all the letters: 1) s 2) t 3) r 4) a 5) w 6) b 7) e 8) r 9) r 10) y. We have r's at positions 3, 8, and 9. \\boxed{3}",
        },
    ]

    def __init__(self, accumulate_thinking: bool = True, use_fewshot: bool = True):
        """
        Initialize the MathAgent with few-shot support.

        Args:
            accumulate_thinking: Whether to accumulate thinking in conversation history
            use_fewshot: Whether to use few-shot prompting
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.accumulate_thinking = accumulate_thinking
        self.use_fewshot = use_fewshot

        # Add few-shot prefix if enabled
        if self.use_fewshot:
            self.messages.extend(copy.deepcopy(self.STANDARD_FEWSHOT_PREFIX))

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
        """Process environment feedback and update internal state."""

        # Reward update for the existing step (observation is None OR an empty dict)
        if observation is None or (isinstance(observation, dict) and observation == {}):
            if self.trajectory.steps:
                cur_step = self.get_current_state()
                cur_step.reward = reward
                cur_step.done = done
                cur_step.info = info
            return

        # Update reward/done/info on the existing step if we already have steps
        if self.trajectory.steps:
            cur_step = self.get_current_state()
            cur_step.reward = reward
            cur_step.done = done
            cur_step.info.update(info)

        if done:
            return

        # This is a new observation, create a new step
        if isinstance(observation, dict):
            if "question" not in observation:
                raise ValueError(f"Observation dict missing required 'question' field: {observation}")
            # Match math_rl instruction text exactly
            formatted_observation = observation["question"] + " Provide a numerical answer without units, written inside \\boxed{}."
        elif isinstance(observation, str):
            formatted_observation = observation + " Provide a numerical answer without units, written inside \\boxed{}."
        else:
            raise ValueError(f"Invalid observation type: {type(observation)}")

        self.messages.append({"role": "user", "content": formatted_observation})

        new_step = Step(observation=formatted_observation)
        self._trajectory.steps.append(new_step)

    def update_from_model(self, response: str, **kwargs) -> Action:
        """
        Updates the agent's internal state based on the model's response.
        """

        # Update the latest step
        self.messages.append({"role": "assistant", "content": response})

        cur_step = self.get_current_state()
        cur_step.chat_completions = self.chat_completions
        cur_step.model_response = response

        if response.count("</think>") == 1:
            thought, sep, action = response.partition("</think>")
            thought = thought + sep
            action = Action(action.strip())
        else:
            thought = None
            action = Action(response.strip())

        cur_step.thought = thought
        cur_step.action = action

        # TODO: remove this temporary fix
        return Action(response.strip())

    def reset(self) -> None:
        """Reset agent state for a new episode (wipes the trajectory but keeps the few-shot prefix)."""
        self._trajectory = Trajectory()
        self.messages = []

        # Re-add few-shot prefix after reset
        if self.use_fewshot:
            self.messages.extend(copy.deepcopy(self.STANDARD_FEWSHOT_PREFIX))

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        """Return conversation history for model interaction."""
        # If accumulate_thinking is disabled, strip the thinking section from every assistant message except the last one
        messages = copy.deepcopy(self.messages)
        if not self.accumulate_thinking:
            for msg in messages[:-1]:
                if msg["role"] == "assistant":
                    _, sep, after = msg["content"].partition("</think>")
                    if sep:
                        msg["content"] = after
        return messages

    @property
    def trajectory(self) -> Trajectory:
        """Return the complete interaction trajectory."""
        return self._trajectory

    def get_current_state(self) -> Step:
        """Returns the current step/state of the agent."""
        assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called."
        return self._trajectory.steps[-1]
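

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): shows how a driver loop would
    # exercise this agent. The question text and the fake model response below
    # are made up; in training, the rollout engine supplies real observations
    # and completions.
    agent = MathAgentWithFewshot(accumulate_thinking=True, use_fewshot=True)
    agent.reset()

    # A new problem arrives from the environment as a dict with a "question" field.
    agent.update_from_env(observation={"question": "What is 2 + 2?"}, reward=0.0, done=False, info={})

    # chat_completions now holds the few-shot prefix followed by the formatted question.
    for msg in agent.chat_completions:
        print(msg["role"], ":", msg["content"][:60])

    # Pretend the model answered; the agent records the step and returns an Action.
    action = agent.update_from_model("2 + 2 equals 4. \\boxed{4}")
    print("action:", action)

    # Final feedback for the episode: observation is None, so only the reward is recorded.
    agent.update_from_env(observation=None, reward=1.0, done=True, info={})
    print("final reward:", agent.trajectory.steps[-1].reward)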