Skip to content

Commit b153d3e

Browse files
authored
Added is_correct & reward flow through tool env (#277)
* Added is_correct & reward flow through tool env * Display rewards for all trajectories in episode, not just the first
1 parent e2b9240 commit b153d3e

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

rllm/agents/tool_agent.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ def update_from_env(self, observation: Any, reward: float, done: bool, info: dic
9494
self.messages.extend(obs_messages)
9595
self.current_observation = observation
9696

97+
if self._trajectory.steps:
98+
self._trajectory.steps[-1].reward = reward
99+
self._trajectory.steps[-1].done = done
100+
self._trajectory.steps[-1].info = info
101+
97102
def update_from_model(self, response: str, **kwargs) -> Action:
98103
"""
99104
Updates the agent's state based on the model's response.

rllm/engine/agent_workflow_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,9 @@ async def process_task_with_retry(self, task: dict, task_id: str, rollout_idx: i
8585
uid = f"{task_id}:{rollout_idx}"
8686
episode = await workflow.run_with_termination_handling(task=task, uid=uid, **kwargs)
8787

88-
colorful_print(f"[{uid}] Rollout completed with termination reason: {episode.termination_reason}", fg="green" if episode.is_correct else "yellow")
88+
# Display rewards for all trajectories
89+
rewards_str = ", ".join([f"{traj.name}: {traj.reward:.1f}" for traj in episode.trajectories])
90+
colorful_print(f"[{uid}] Rollout completed. Rewards: {rewards_str}, Termination: {episode.termination_reason}", fg="green" if episode.is_correct else "yellow")
8991

9092
if episode.termination_reason != TerminationReason.ERROR:
9193
return task_id, rollout_idx, episode

rllm/environments/tools/tool_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def step(self, action: list[dict] | str | dict):
9898

9999
task_info = self.task if self.task is not None else {}
100100
reward_output = self.reward_fn(task_info=task_info, action=llm_response)
101-
return {}, reward_output.reward, done, {"response": action, "metadata": reward_output.metadata}
101+
return {}, reward_output.reward, done, {"response": action, "metadata": reward_output.metadata, "is_correct": reward_output.is_correct}
102102

103103
tool_calls = action
104104
assert isinstance(tool_calls, list)

0 commit comments

Comments
 (0)