40 changes: 38 additions & 2 deletions rllm/engine/agent_workflow_engine.py
@@ -22,7 +22,7 @@


class AgentWorkflowEngine:
def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs):
def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, episode_logger=None, **kwargs):
"""Initialize the AgentWorkflowEngine.

Args:
@@ -33,6 +33,7 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
n_parallel_tasks: Number of parallel workflow instances to maintain.
retry_limit: Maximum number of retry attempts for failed tasks.
raise_on_error: Whether to raise exceptions on permanent failures.
episode_logger: Optional logger for saving episode data to files.
**kwargs: Additional keyword arguments.
"""
self.workflow_cls = workflow_cls
@@ -49,6 +50,24 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en
self.executor = ThreadPoolExecutor(max_workers=self.n_parallel_tasks)
self.workflow_queue = None

# Episode logging support
self.episode_logger = episode_logger
self.current_step = 0
self.current_epoch = 0
self.current_mode = "train" # "train" or "val"

def set_training_step(self, step: int, mode: str = "train", epoch: int = 0):
"""Set current training step for episode logging.

Args:
step: Current training step number
mode: Mode identifier ('train' or 'val'), defaults to 'train'
epoch: Current epoch number, defaults to 0
"""
self.current_step = step
self.current_mode = mode
self.current_epoch = epoch

async def initialize_pool(self):
"""Initialize the workflow pool with parallel workflow instances.

@@ -154,6 +173,18 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No
sorted_tasks = sorted(task_states.keys(), key=lambda task_id: task_states[task_id]["idx"])
for task_id in sorted_tasks:
results.extend(task_states[task_id]["episodes"])

# Log episodes if logger is provided
if self.episode_logger is not None:
try:
logger.info(f"Logging {len(results)} episodes to step={self.current_step}, mode={self.current_mode}, epoch={self.current_epoch}")
self.episode_logger.log_episodes_batch(results, self.current_step, self.current_mode, self.current_epoch)
except Exception as e:
logger.error(f"Failed to log episodes: {e}")
import traceback

traceback.print_exc()

return results

async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
@@ -167,12 +198,17 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto":
DataProto: Transformed results compatible with Verl training.
"""
self.rollout_engine.wake_up()
if batch.meta_info.get("validate", False):
is_validation = batch.meta_info.get("validate", False)
if is_validation:
self.rollout_engine.validate = True
self.current_mode = "val"
else:
self.current_mode = "train"
tasks = batch.non_tensor_batch["extra_info"].tolist()
task_ids = batch.non_tensor_batch["task_ids"].tolist()
results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes
self.rollout_engine.validate = False
self.current_mode = "train"
self.rollout_engine.sleep()
return self.transform_results_for_verl(results, task_ids)

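Taken together, the engine-side changes are a thin hook: the caller supplies an optional `EpisodeLogger` at construction time and tags each rollout batch with `set_training_step` before executing tasks. Below is a minimal sketch of that wiring, assuming `workflow_cls`, `workflow_args`, `rollout_engine`, `tasks`, and `task_ids` are already defined by the surrounding training setup:

```python
import asyncio

from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.utils import EpisodeLogger

# Placeholders: workflow_cls, workflow_args, rollout_engine, tasks, and task_ids
# are assumed to come from the surrounding training setup.
episode_logger = EpisodeLogger(base_dir="logs/my_project/my_experiment")

engine = AgentWorkflowEngine(
    workflow_cls=workflow_cls,
    workflow_args=workflow_args,
    rollout_engine=rollout_engine,
    n_parallel_tasks=128,
    retry_limit=3,
    episode_logger=episode_logger,  # None (the default) disables episode logging
)

# Tag the next batch of rollouts; episodes are then written under
# <base_dir>/episodes/train_step_10_epoch_0/ once execute_tasks() finishes.
engine.set_training_step(step=10, mode="train", epoch=0)
episodes = asyncio.run(engine.execute_tasks(tasks, task_ids))
```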
2 changes: 2 additions & 0 deletions rllm/trainer/config/_generated_agent_ppo_trainer.yaml
@@ -201,6 +201,8 @@ trainer:
val_before_train: true
val_only: false
test_freq: -1
log_episodes: false
episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
critic_warmup: 0
default_hdfs_dir: null
del_local_ckpt_after_load: false
6 changes: 5 additions & 1 deletion rllm/trainer/config/agent_ppo_trainer.yaml
@@ -64,4 +64,8 @@ rllm:
fireworks:
deployment_id: null
model_id_prefix: test-model
concurrency: 32
concurrency: 32

trainer:
log_episodes: false
episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
6 changes: 5 additions & 1 deletion rllm/trainer/config/agent_ppo_trainer_megatron.yaml
@@ -58,4 +58,8 @@ rllm:
mask_timeout: True
rejection_sample:
enable: False
multiplier: 1
multiplier: 1

trainer:
log_episodes: false
episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
6 changes: 5 additions & 1 deletion rllm/trainer/config/agent_sft_trainer.yaml
@@ -9,4 +9,8 @@ defaults:

data:
rllm:
tokenize_and_mask_method: cumulative
tokenize_and_mask_method: cumulative

trainer:
log_episodes: false
episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name}
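The same two keys are added to each trainer config, with `episode_log_dir` defaulting to an interpolated path. A small sketch of how that default resolves, assuming OmegaConf-style interpolation (which the `${trainer.project_name}` syntax suggests) and placeholder project/experiment names:

```python
from omegaconf import OmegaConf

# Minimal illustration of how the episode_log_dir default resolves.
# "my_project" and "my_experiment" are placeholders.
cfg = OmegaConf.create(
    {
        "trainer": {
            "project_name": "my_project",
            "experiment_name": "my_experiment",
            "log_episodes": True,
            "episode_log_dir": "logs/${trainer.project_name}/${trainer.experiment_name}",
        }
    }
)

print(cfg.trainer.episode_log_dir)  # logs/my_project/my_experiment
```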
15 changes: 15 additions & 0 deletions rllm/trainer/verl/agent_workflow_trainer.py
@@ -12,6 +12,7 @@

from rllm.engine.agent_workflow_engine import AgentWorkflowEngine
from rllm.engine.rollout.verl_engine import VerlEngine
from rllm.utils.episode_logger import EpisodeLogger
from rllm.workflows.workflow import TerminationReason
from verl import DataProto
from verl.protocol import pad_dataproto_to_divisor
@@ -76,13 +77,21 @@ def init_workers(self):
tokenizer=self.tokenizer,
)

# Create episode logger if enabled in config
episode_logger = None
if self.config.trainer.get("log_episodes", False):
# Get episode log directory from config; defaults to logs/{project_name}/{experiment_name}
episode_log_dir = self.config.trainer.get("episode_log_dir", f"logs/{self.config.trainer.project_name}/{self.config.trainer.experiment_name}")
episode_logger = EpisodeLogger(base_dir=episode_log_dir, subdirectory="episodes")

self.agent_execution_engine = AgentWorkflowEngine(
workflow_cls=self.workflow_class,
workflow_args=self.workflow_args,
rollout_engine=rollout_engine,
config=self.config,
n_parallel_tasks=self.config.rllm.workflow.n_parallel_tasks,
retry_limit=self.config.rllm.workflow.retry_limit,
episode_logger=episode_logger,
)

# init workflow workers
@@ -111,6 +120,7 @@ def fit_agent(self):

start_time = time.time()
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=0)
val_metrics = self._validate_agent()
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
@@ -145,6 +155,9 @@

new_batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"])

# Update training step in engine for episode logging
self.agent_execution_engine.set_training_step(self.global_steps, mode="train", epoch=epoch)

with marked_timer("step", timing_raw):
# generate trajectories
final_gen_batch_output = self.generate_trajectories(batch=new_batch, timing_raw=timing_raw)
@@ -391,6 +404,7 @@ def fit_agent(self):
# validate
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and self.global_steps % self.config.trainer.test_freq == 0:
with marked_timer("testing", timing_raw, color="green"):
self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch)
val_metrics: dict = self._validate_agent()
metrics.update(val_metrics)

@@ -455,6 +469,7 @@ def fit_agent(self):
if self.global_steps >= self.total_training_steps:
# perform validation after training
if self.val_reward_fn is not None:
self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch)
val_metrics = self._validate_agent()
pprint(f"Final validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
5 changes: 5 additions & 0 deletions rllm/utils/__init__.py
@@ -0,0 +1,5 @@
"""Utilities for the rllm package."""

from rllm.utils.episode_logger import EpisodeLogger

__all__ = ["EpisodeLogger"]
180 changes: 180 additions & 0 deletions rllm/utils/episode_logger.py
@@ -0,0 +1,180 @@
"""Episode JSON Logger for saving detailed episode information."""

import hashlib
import json
from pathlib import Path
from typing import Any

from rllm.agents.agent import Episode


class EpisodeLogger:
"""Logger to save episodes to individual JSON files with step and data hash."""

def __init__(self, base_dir: str, subdirectory: str = "episodes"):
"""Initialize the episode logger.

Args:
base_dir: Base directory for episode logs. Can be configured via
config.trainer.episode_log_dir
(default: "logs/${trainer.project_name}/${trainer.experiment_name}")
subdirectory: Subdirectory within base_dir for episodes (default: "episodes")
Final path will be: {base_dir}/{subdirectory}/
"""
self.log_dir = Path(base_dir) / subdirectory
self.log_dir.mkdir(parents=True, exist_ok=True)

@staticmethod
def compute_task_hash(task: Any, length: int = 8) -> str:
"""Compute a hash from the task data.

Args:
task: The task dictionary or data
length: Length of the hash to use (default 8 chars)

Returns:
Hash string
"""
# Convert task to a stable string representation
task_str = json.dumps(task, sort_keys=True, default=str)
# Compute SHA256 hash
hash_obj = hashlib.sha256(task_str.encode("utf-8"))
# Return first `length` characters of hex digest
return hash_obj.hexdigest()[:length]

def get_step_dir(self, step: int, mode: str = "train", epoch: int = 0) -> Path:
"""Get the directory path for a specific training or validation step.

Args:
step: Current training/validation step
mode: Mode identifier ('train' or 'val'), defaults to 'train'
epoch: Current epoch number, defaults to 0

Returns:
Path object for the step directory
"""
step_dir = self.log_dir / f"{mode}_step_{step}_epoch_{epoch}"
step_dir.mkdir(parents=True, exist_ok=True)
return step_dir

def get_episode_filename(self, episode: Episode, step: int) -> str:
"""Generate filename for an episode.

Format: episode_hash{task_hash}_id{episode_id}.json

Args:
episode: The episode to save
step: Current training step (not used in filename, but kept for compatibility)

Returns:
Filename string
"""
task_hash = self.compute_task_hash(episode.task)
# Clean episode_id to make it filesystem-safe
episode_id_safe = str(episode.id).replace(":", "_").replace("/", "_")

filename = f"episode_hash{task_hash}_id{episode_id_safe}.json"
return filename

def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: int = 0):
"""Log a single episode to its own JSON file in a step-specific directory.

Args:
episode: The episode to log
step: Current training/validation step
mode: Mode identifier ('train' or 'val'), defaults to 'train'
epoch: Current epoch number, defaults to 0
"""
episode_data = {"training_step": step, "epoch": epoch, "episode_id": episode.id, "task": episode.task, "task_hash": self.compute_task_hash(episode.task), "is_correct": episode.is_correct, "termination_reason": episode.termination_reason.value if episode.termination_reason else None, "metrics": episode.metrics, "timing": episode.info.get("timing", {}), "trajectories": []}

for traj in episode.trajectories:
traj_data = {
"name": traj.name,
"uid": traj.uid,
"reward": traj.reward,
"num_steps": len(traj.steps),
"timing": traj.info.get("timing", {}),
"steps": [
{
"observation": step.observation,
"thought": step.thought,
"action": step.action,
"reward": step.reward,
"done": step.done,
"model_response": step.model_response,
"chat_completions": step.chat_completions,
"timing": step.info.get("timing", {}), # Add step-level timing
}
for step in traj.steps
],
}
episode_data["trajectories"].append(traj_data)

# Write to individual file in step-specific directory
step_dir = self.get_step_dir(step, mode, epoch)
filename = self.get_episode_filename(episode, step)
filepath = step_dir / filename

try:
with open(filepath, "w") as f:
json_str = json.dumps(episode_data, indent=2, default=str)
f.write(json_str + "\n")
f.flush() # Ensure data is written to disk
except Exception as e:
print(f"Error writing episode to {filepath}: {e}")
import traceback

traceback.print_exc()
raise

def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0):
"""Log multiple episodes, each to its own file.

Args:
episodes: List of episodes to log
step: Current training/validation step
mode: Mode identifier ('train' or 'val'), defaults to 'train'
epoch: Current epoch number, defaults to 0
"""
print(f"[EpisodeLogger] Logging {len(episodes)} episodes for step={step}, mode={mode}, epoch={epoch}")
for i, episode in enumerate(episodes):
try:
self.log_episode(episode, step, mode, epoch)
print(f"[EpisodeLogger] Successfully logged episode {i + 1}/{len(episodes)}: {episode.id}")
except Exception as e:
print(f"[EpisodeLogger] Failed to log episode {i + 1}/{len(episodes)}: {e}")
raise

def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0, batch_summary: bool = True):
"""Log multiple episodes and optionally create a batch summary in step-specific directory.

Args:
episodes: List of episodes to log
step: Current training/validation step
mode: Mode identifier ('train' or 'val'), defaults to 'train'
epoch: Current epoch number, defaults to 0
batch_summary: Whether to create a summary file for the batch
"""
# Log individual episodes
self.log_episodes(episodes, step, mode, epoch)

# Optionally create batch summary in step-specific directory
if batch_summary and episodes:
summary_data = {
"training_step": step,
"epoch": epoch,
"mode": mode,
"num_episodes": len(episodes),
"episode_files": [self.get_episode_filename(ep, step) for ep in episodes],
"summary_stats": {
"total_correct": sum(1 for ep in episodes if ep.is_correct),
"total_incorrect": sum(1 for ep in episodes if not ep.is_correct),
"accuracy": sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0,
"avg_trajectories_per_episode": sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0,
},
}

step_dir = self.get_step_dir(step, mode, epoch)
summary_file = step_dir / "batch_summary.json"
with open(summary_file, "w") as f:
json.dump(summary_data, f, indent=2)
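For reference, here is a short sketch of the layout `EpisodeLogger` produces on disk; the base directory and task contents are placeholders, and only the helpers that don't require an `Episode` object are exercised:

```python
from rllm.utils import EpisodeLogger

logger = EpisodeLogger(base_dir="logs/my_project/my_experiment")  # writes under .../episodes/

task = {"question": "2 + 2 = ?", "answer": "4"}
task_hash = EpisodeLogger.compute_task_hash(task)   # first 8 hex chars of the SHA-256 of the task JSON
step_dir = logger.get_step_dir(step=10, mode="train", epoch=0)

# step_dir -> logs/my_project/my_experiment/episodes/train_step_10_epoch_0
# log_episodes_batch() then writes one
#   episode_hash{task_hash}_id{episode_id}.json
# per episode into that directory, plus an optional batch_summary.json.
print(step_dir, task_hash)
```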
10 changes: 10 additions & 0 deletions rllm/workflows/__init__.py
@@ -11,6 +11,8 @@
"TerminationEvent",
"SingleTurnWorkflow",
"MultiTurnWorkflow",
"CumulativeWorkflow",
"TimingTrackingMixin",
]


@@ -23,4 +25,12 @@ def __getattr__(name):
from .multi_turn_workflow import MultiTurnWorkflow as _Multi

return _Multi
if name == "CumulativeWorkflow":
from .cumulative_workflow import CumulativeWorkflow as _Cumulative

return _Cumulative
if name == "TimingTrackingMixin":
from .timing_mixin import TimingTrackingMixin as _Mixin

return _Mixin
raise AttributeError(name)
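The new names follow the package's existing lazy-import pattern: nothing is imported when `rllm.workflows` itself is imported, and the first attribute access goes through the module-level `__getattr__` (PEP 562). Usage is unchanged from an eager export, as this small sketch shows:

```python
# Both names resolve through __getattr__ on first access, importing the
# underlying modules on demand.
from rllm.workflows import CumulativeWorkflow, TimingTrackingMixin

print(CumulativeWorkflow.__name__, TimingTrackingMixin.__name__)
```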