diff --git a/rllm/engine/agent_workflow_engine.py b/rllm/engine/agent_workflow_engine.py index 516297420..61b53229e 100644 --- a/rllm/engine/agent_workflow_engine.py +++ b/rllm/engine/agent_workflow_engine.py @@ -22,7 +22,7 @@ class AgentWorkflowEngine: - def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, **kwargs): + def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_engine: RolloutEngine, config=None, n_parallel_tasks: int = 128, retry_limit: int = 3, raise_on_error: bool = True, episode_logger=None, **kwargs): """Initialize the AgentWorkflowEngine. Args: @@ -33,6 +33,7 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en n_parallel_tasks: Number of parallel workflow instances to maintain. retry_limit: Maximum number of retry attempts for failed tasks. raise_on_error: Whether to raise exceptions on permanent failures. + episode_logger: Optional logger for saving episode data to files. **kwargs: Additional keyword arguments. """ self.workflow_cls = workflow_cls @@ -49,6 +50,24 @@ def __init__(self, workflow_cls: type[Workflow], workflow_args: dict, rollout_en self.executor = ThreadPoolExecutor(max_workers=self.n_parallel_tasks) self.workflow_queue = None + # Episode logging support + self.episode_logger = episode_logger + self.current_step = 0 + self.current_epoch = 0 + self.current_mode = "train" # "train" or "val" + + def set_training_step(self, step: int, mode: str = "train", epoch: int = 0): + """Set current training step for episode logging. + + Args: + step: Current training step number + mode: Mode identifier ('train' or 'val'), defaults to 'train' + epoch: Current epoch number, defaults to 0 + """ + self.current_step = step + self.current_mode = mode + self.current_epoch = epoch + async def initialize_pool(self): """Initialize the workflow pool with parallel workflow instances. @@ -154,6 +173,18 @@ async def execute_tasks(self, tasks: list[dict], task_ids: list[str] | None = No sorted_tasks = sorted(task_states.keys(), key=lambda task_id: task_states[task_id]["idx"]) for task_id in sorted_tasks: results.extend(task_states[task_id]["episodes"]) + + # Log episodes if logger is provided + if self.episode_logger is not None: + try: + logger.info(f"Logging {len(results)} episodes to step={self.current_step}, mode={self.current_mode}, epoch={self.current_epoch}") + self.episode_logger.log_episodes_batch(results, self.current_step, self.current_mode, self.current_epoch) + except Exception as e: + logger.error(f"Failed to log episodes: {e}") + import traceback + + traceback.print_exc() + return results async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": @@ -167,12 +198,17 @@ async def execute_tasks_verl(self, batch: "DataProto", **kwargs) -> "DataProto": DataProto: Transformed results compatible with Verl training. """ self.rollout_engine.wake_up() - if batch.meta_info.get("validate", False): + is_validation = batch.meta_info.get("validate", False) + if is_validation: self.rollout_engine.validate = True + self.current_mode = "val" + else: + self.current_mode = "train" tasks = batch.non_tensor_batch["extra_info"].tolist() task_ids = batch.non_tensor_batch["task_ids"].tolist() results = await self.execute_tasks(tasks, task_ids, **kwargs) # list of Episodes self.rollout_engine.validate = False + self.current_mode = "train" self.rollout_engine.sleep() return self.transform_results_for_verl(results, task_ids) diff --git a/rllm/trainer/config/_generated_agent_ppo_trainer.yaml b/rllm/trainer/config/_generated_agent_ppo_trainer.yaml index 3853fa371..b5ae5f035 100644 --- a/rllm/trainer/config/_generated_agent_ppo_trainer.yaml +++ b/rllm/trainer/config/_generated_agent_ppo_trainer.yaml @@ -201,6 +201,8 @@ trainer: val_before_train: true val_only: false test_freq: -1 + log_episodes: false + episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} critic_warmup: 0 default_hdfs_dir: null del_local_ckpt_after_load: false diff --git a/rllm/trainer/config/agent_ppo_trainer.yaml b/rllm/trainer/config/agent_ppo_trainer.yaml index a73a767b6..47f40aee6 100644 --- a/rllm/trainer/config/agent_ppo_trainer.yaml +++ b/rllm/trainer/config/agent_ppo_trainer.yaml @@ -64,4 +64,8 @@ rllm: fireworks: deployment_id: null model_id_prefix: test-model - concurrency: 32 \ No newline at end of file + concurrency: 32 + +trainer: + log_episodes: false + episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} \ No newline at end of file diff --git a/rllm/trainer/config/agent_ppo_trainer_megatron.yaml b/rllm/trainer/config/agent_ppo_trainer_megatron.yaml index 714ff4afd..0a6878515 100644 --- a/rllm/trainer/config/agent_ppo_trainer_megatron.yaml +++ b/rllm/trainer/config/agent_ppo_trainer_megatron.yaml @@ -58,4 +58,8 @@ rllm: mask_timeout: True rejection_sample: enable: False - multiplier: 1 \ No newline at end of file + multiplier: 1 + +trainer: + log_episodes: false + episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} \ No newline at end of file diff --git a/rllm/trainer/config/agent_sft_trainer.yaml b/rllm/trainer/config/agent_sft_trainer.yaml index 1310f5182..b96b1525e 100644 --- a/rllm/trainer/config/agent_sft_trainer.yaml +++ b/rllm/trainer/config/agent_sft_trainer.yaml @@ -9,4 +9,8 @@ defaults: data: rllm: - tokenize_and_mask_method: cumulative \ No newline at end of file + tokenize_and_mask_method: cumulative + +trainer: + log_episodes: false + episode_log_dir: logs/${trainer.project_name}/${trainer.experiment_name} \ No newline at end of file diff --git a/rllm/trainer/verl/agent_workflow_trainer.py b/rllm/trainer/verl/agent_workflow_trainer.py index bcd6eecd5..cb711818a 100644 --- a/rllm/trainer/verl/agent_workflow_trainer.py +++ b/rllm/trainer/verl/agent_workflow_trainer.py @@ -12,6 +12,7 @@ from rllm.engine.agent_workflow_engine import AgentWorkflowEngine from rllm.engine.rollout.verl_engine import VerlEngine +from rllm.utils.episode_logger import EpisodeLogger from rllm.workflows.workflow import TerminationReason from verl import DataProto from verl.protocol import pad_dataproto_to_divisor @@ -76,6 +77,13 @@ def init_workers(self): tokenizer=self.tokenizer, ) + # Create episode logger if enabled in config + episode_logger = None + if self.config.trainer.get("log_episodes", False): + # Get episode log directory from config, default to "logs/my_project/my_experiment" + episode_log_dir = self.config.trainer.get("episode_log_dir", f"logs/{self.config.trainer.project_name}/{self.config.trainer.experiment_name}") + episode_logger = EpisodeLogger(base_dir=episode_log_dir, subdirectory="episodes") + self.agent_execution_engine = AgentWorkflowEngine( workflow_cls=self.workflow_class, workflow_args=self.workflow_args, @@ -83,6 +91,7 @@ def init_workers(self): config=self.config, n_parallel_tasks=self.config.rllm.workflow.n_parallel_tasks, retry_limit=self.config.rllm.workflow.retry_limit, + episode_logger=episode_logger, ) # init workflow workers @@ -111,6 +120,7 @@ def fit_agent(self): start_time = time.time() if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=0) val_metrics = self._validate_agent() pprint(f"Initial validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) @@ -145,6 +155,9 @@ def fit_agent(self): new_batch.pop(batch_keys=["input_ids", "attention_mask", "position_ids"], non_tensor_batch_keys=["raw_prompt_ids"]) + # Update training step in engine for episode logging + self.agent_execution_engine.set_training_step(self.global_steps, mode="train", epoch=epoch) + with marked_timer("step", timing_raw): # generate trajectories final_gen_batch_output = self.generate_trajectories(batch=new_batch, timing_raw=timing_raw) @@ -391,6 +404,7 @@ def fit_agent(self): # validate if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and self.global_steps % self.config.trainer.test_freq == 0: with marked_timer("testing", timing_raw, color="green"): + self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch) val_metrics: dict = self._validate_agent() metrics.update(val_metrics) @@ -455,6 +469,7 @@ def fit_agent(self): if self.global_steps >= self.total_training_steps: # perform validation after training if self.val_reward_fn is not None: + self.agent_execution_engine.set_training_step(self.global_steps, mode="val", epoch=epoch) val_metrics = self._validate_agent() pprint(f"Final validation metrics: {val_metrics}") logger.log(data=val_metrics, step=self.global_steps) diff --git a/rllm/utils/__init__.py b/rllm/utils/__init__.py new file mode 100644 index 000000000..7951a3944 --- /dev/null +++ b/rllm/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utilities for the rllm package.""" + +from rllm.utils.episode_logger import EpisodeLogger + +__all__ = ["EpisodeLogger"] diff --git a/rllm/utils/episode_logger.py b/rllm/utils/episode_logger.py new file mode 100644 index 000000000..bdd665bac --- /dev/null +++ b/rllm/utils/episode_logger.py @@ -0,0 +1,180 @@ +"""Episode JSON Logger for saving detailed episode information.""" + +import hashlib +import json +from pathlib import Path +from typing import Any + +from rllm.agents.agent import Episode + + +class EpisodeLogger: + """Logger to save episodes to individual JSON files with step and data hash.""" + + def __init__(self, base_dir: str, subdirectory: str = "episodes"): + """Initialize the episode logger. + + Args: + base_dir: Base directory for episode logs. Can be configured via + config.trainer.episode_log_dir + (default: "logs/${trainer.project_name}/${trainer.experiment_name}") + subdirectory: Subdirectory within base_dir for episodes (default: "episodes") + Final path will be: {base_dir}/{subdirectory}/ + """ + self.log_dir = Path(base_dir) / subdirectory + self.log_dir.mkdir(parents=True, exist_ok=True) + + @staticmethod + def compute_task_hash(task: Any, length: int = 8) -> str: + """Compute a hash from the task data. + + Args: + task: The task dictionary or data + length: Length of the hash to use (default 8 chars) + + Returns: + Hash string + """ + # Convert task to a stable string representation + task_str = json.dumps(task, sort_keys=True, default=str) + # Compute SHA256 hash + hash_obj = hashlib.sha256(task_str.encode("utf-8")) + # Return first `length` characters of hex digest + return hash_obj.hexdigest()[:length] + + def get_step_dir(self, step: int, mode: str = "train", epoch: int = 0) -> Path: + """Get the directory path for a specific training or validation step. + + Args: + step: Current training/validation step + mode: Mode identifier ('train' or 'val'), defaults to 'train' + epoch: Current epoch number, defaults to 0 + + Returns: + Path object for the step directory + """ + step_dir = self.log_dir / f"{mode}_step_{step}_epoch_{epoch}" + step_dir.mkdir(parents=True, exist_ok=True) + return step_dir + + def get_episode_filename(self, episode: Episode, step: int) -> str: + """Generate filename for an episode. + + Format: episode_hash{task_hash}_id{episode_id}.json + + Args: + episode: The episode to save + step: Current training step (not used in filename, but kept for compatibility) + + Returns: + Filename string + """ + task_hash = self.compute_task_hash(episode.task) + # Clean episode_id to make it filesystem-safe + episode_id_safe = str(episode.id).replace(":", "_").replace("/", "_") + + filename = f"episode_hash{task_hash}_id{episode_id_safe}.json" + return filename + + def log_episode(self, episode: Episode, step: int, mode: str = "train", epoch: int = 0): + """Log a single episode to its own JSON file in a step-specific directory. + + Args: + episode: The episode to log + step: Current training/validation step + mode: Mode identifier ('train' or 'val'), defaults to 'train' + epoch: Current epoch number, defaults to 0 + """ + episode_data = {"training_step": step, "epoch": epoch, "episode_id": episode.id, "task": episode.task, "task_hash": self.compute_task_hash(episode.task), "is_correct": episode.is_correct, "termination_reason": episode.termination_reason.value if episode.termination_reason else None, "metrics": episode.metrics, "timing": episode.info.get("timing", {}), "trajectories": []} + + for traj in episode.trajectories: + traj_data = { + "name": traj.name, + "uid": traj.uid, + "reward": traj.reward, + "num_steps": len(traj.steps), + "timing": traj.info.get("timing", {}), + "steps": [ + { + "observation": step.observation, + "thought": step.thought, + "action": step.action, + "reward": step.reward, + "done": step.done, + "model_response": step.model_response, + "chat_completions": step.chat_completions, + "timing": step.info.get("timing", {}), # Add step-level timing + } + for step in traj.steps + ], + } + episode_data["trajectories"].append(traj_data) + + # Write to individual file in step-specific directory + step_dir = self.get_step_dir(step, mode, epoch) + filename = self.get_episode_filename(episode, step) + filepath = step_dir / filename + + try: + with open(filepath, "w") as f: + json_str = json.dumps(episode_data, indent=2, default=str) + f.write(json_str + "\n") + f.flush() # Ensure data is written to disk + except Exception as e: + print(f"Error writing episode to {filepath}: {e}") + import traceback + + traceback.print_exc() + raise + + def log_episodes(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0): + """Log multiple episodes, each to its own file. + + Args: + episodes: List of episodes to log + step: Current training/validation step + mode: Mode identifier ('train' or 'val'), defaults to 'train' + epoch: Current epoch number, defaults to 0 + """ + print(f"[EpisodeLogger] Logging {len(episodes)} episodes for step={step}, mode={mode}, epoch={epoch}") + for i, episode in enumerate(episodes): + try: + self.log_episode(episode, step, mode, epoch) + print(f"[EpisodeLogger] Successfully logged episode {i + 1}/{len(episodes)}: {episode.id}") + except Exception as e: + print(f"[EpisodeLogger] Failed to log episode {i + 1}/{len(episodes)}: {e}") + raise + + def log_episodes_batch(self, episodes: list[Episode], step: int, mode: str = "train", epoch: int = 0, batch_summary: bool = True): + """Log multiple episodes and optionally create a batch summary in step-specific directory. + + Args: + episodes: List of episodes to log + step: Current training/validation step + mode: Mode identifier ('train' or 'val'), defaults to 'train' + epoch: Current epoch number, defaults to 0 + batch_summary: Whether to create a summary file for the batch + """ + # Log individual episodes + self.log_episodes(episodes, step, mode, epoch) + + # Optionally create batch summary in step-specific directory + if batch_summary and episodes: + summary_data = { + "training_step": step, + "epoch": epoch, + "mode": mode, + "num_episodes": len(episodes), + "episode_files": [self.get_episode_filename(ep, step) for ep in episodes], + "summary_stats": { + "total_correct": sum(1 for ep in episodes if ep.is_correct), + "total_incorrect": sum(1 for ep in episodes if not ep.is_correct), + "accuracy": sum(1 for ep in episodes if ep.is_correct) / len(episodes) if episodes else 0, + "avg_trajectories_per_episode": sum(len(ep.trajectories) for ep in episodes) / len(episodes) if episodes else 0, + }, + } + + step_dir = self.get_step_dir(step, mode, epoch) + summary_file = step_dir / "batch_summary.json" + with open(summary_file, "w") as f: + json.dump(summary_data, f, indent=2) diff --git a/rllm/workflows/__init__.py b/rllm/workflows/__init__.py index 03e5fbc2d..d33e04f1d 100644 --- a/rllm/workflows/__init__.py +++ b/rllm/workflows/__init__.py @@ -11,6 +11,8 @@ "TerminationEvent", "SingleTurnWorkflow", "MultiTurnWorkflow", + "CumulativeWorkflow", + "TimingTrackingMixin", ] @@ -23,4 +25,12 @@ def __getattr__(name): from .multi_turn_workflow import MultiTurnWorkflow as _Multi return _Multi + if name == "CumulativeWorkflow": + from .cumulative_workflow import CumulativeWorkflow as _Cumulative + + return _Cumulative + if name == "TimingTrackingMixin": + from .timing_mixin import TimingTrackingMixin as _Mixin + + return _Mixin raise AttributeError(name) diff --git a/rllm/workflows/cumulative_workflow.py b/rllm/workflows/cumulative_workflow.py index c244a9933..45fa5e30b 100644 --- a/rllm/workflows/cumulative_workflow.py +++ b/rllm/workflows/cumulative_workflow.py @@ -2,10 +2,11 @@ from rllm.agents.agent import Episode from rllm.engine.rollout.rollout_engine import ModelOutput +from rllm.workflows.timing_mixin import TimingTrackingMixin from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow -class CumulativeWorkflow(Workflow): +class CumulativeWorkflow(TimingTrackingMixin, Workflow): def __init__( self, agent_cls, @@ -35,7 +36,7 @@ def __init__( async def run(self, task: dict, uid: str, **kwargs) -> Episode | None: """Execute a multi-step workflow""" - observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) + observation, info = await self.timed_env_call(self.reset, task=task, uid=uid) self.agent.update_from_env(observation, 0, False, info) @@ -48,12 +49,12 @@ async def run(self, task: dict, uid: str, **kwargs) -> Episode | None: if max_tokens <= 0: raise TerminationEvent(TerminationReason.MAX_RESPONSE_LENGTH_EXCEEDED) - output: ModelOutput = await self.rollout_engine.get_model_response(self.agent.chat_completions, application_id=uid, accumulate_reasoning=True, enforce_max_prompt_length=False, max_tokens=max_tokens, **kwargs) + output: ModelOutput = await self.timed_llm_call(self.agent.chat_completions, application_id=uid, accumulate_reasoning=True, enforce_max_prompt_length=False, max_tokens=max_tokens, **kwargs) response = output.text action = self.agent.update_from_model(response) - next_obs, reward, done, info = await self.run_in_executor(self.env.step, action) + next_obs, reward, done, info = await self.timed_env_call(self.env.step, action) self.agent.update_from_env(next_obs, reward, done, info) if output.finish_reason == "length": diff --git a/rllm/workflows/multi_turn_workflow.py b/rllm/workflows/multi_turn_workflow.py index 268722a3c..249fbda23 100644 --- a/rllm/workflows/multi_turn_workflow.py +++ b/rllm/workflows/multi_turn_workflow.py @@ -2,10 +2,11 @@ from rllm.agents.agent import Episode from rllm.engine.rollout.rollout_engine import ModelOutput +from rllm.workflows.timing_mixin import TimingTrackingMixin from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow -class MultiTurnWorkflow(Workflow): +class MultiTurnWorkflow(TimingTrackingMixin, Workflow): def __init__( self, agent_cls, @@ -33,17 +34,17 @@ def __init__( async def run(self, task: dict, uid: str, **kwargs) -> Episode | None: """Execute a multi-step workflow""" - observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) + observation, info = await self.timed_env_call(self.reset, task=task, uid=uid) self.agent.update_from_env(observation, 0, False, info) for _ in range(1, self.max_steps + 1): - output: ModelOutput = await self.rollout_engine.get_model_response(self.agent.chat_completions, application_id=uid, **kwargs) + output: ModelOutput = await self.timed_llm_call(self.agent.chat_completions, application_id=uid, **kwargs) response = output.text action = self.agent.update_from_model(response) - next_obs, reward, done, info = await self.run_in_executor(self.env.step, action) + next_obs, reward, done, info = await self.timed_env_call(self.env.step, action) self.agent.update_from_env(next_obs, reward, done, info) if output.finish_reason == "length": diff --git a/rllm/workflows/single_turn_workflow.py b/rllm/workflows/single_turn_workflow.py index 0792b0309..a951122af 100644 --- a/rllm/workflows/single_turn_workflow.py +++ b/rllm/workflows/single_turn_workflow.py @@ -2,10 +2,11 @@ from rllm.agents.agent import Episode from rllm.engine.rollout.rollout_engine import ModelOutput +from rllm.workflows.timing_mixin import TimingTrackingMixin from rllm.workflows.workflow import TerminationEvent, TerminationReason, Workflow -class SingleTurnWorkflow(Workflow): +class SingleTurnWorkflow(TimingTrackingMixin, Workflow): def __init__( self, agent_cls, @@ -31,16 +32,16 @@ def __init__( async def run(self, task: dict, uid: str, **kwargs) -> Episode | None: """Execute a single-step workflow""" - observation, info = await self.run_in_executor(self.reset, task=task, uid=uid) + observation, info = await self.timed_env_call(self.reset, task=task, uid=uid) self.agent.update_from_env(observation, 0, False, info) - output: ModelOutput = await self.rollout_engine.get_model_response(self.agent.chat_completions, application_id=uid, skip_special_tokens=True, **kwargs) + output: ModelOutput = await self.timed_llm_call(self.agent.chat_completions, application_id=uid, skip_special_tokens=True, **kwargs) response = output.text action = self.agent.update_from_model(response) - _, reward, done, info = await self.run_in_executor(self.env.step, action) + _, reward, done, info = await self.timed_env_call(self.env.step, action) self.agent.update_from_env({}, reward, done, info) if output.finish_reason == "length": diff --git a/rllm/workflows/timing_mixin.py b/rllm/workflows/timing_mixin.py new file mode 100644 index 000000000..e5d90e7f5 --- /dev/null +++ b/rllm/workflows/timing_mixin.py @@ -0,0 +1,208 @@ +"""Timing tracking mixin for workflows to measure LLM, environment, and total execution time.""" + +import time +from datetime import datetime, timezone +from typing import Any + + +class TimingTrackingMixin: + """Mixin to add timing tracking to workflows.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._timing_data = { + "llm_time": 0.0, + "env_time": 0.0, + "reward_time": 0.0, + "total_time": 0.0, + "start_time": None, + "end_time": None, + } + # Track per-step timing + self._step_timings = [] + self._current_step_timing = None + + def start_timing(self): + """Start timing the episode.""" + self._timing_data["start_time"] = time.time() # Keep for calculation + self._timing_data["start_timestamp"] = datetime.now(timezone.utc).isoformat() + self._timing_data["end_time"] = None + self._timing_data["end_timestamp"] = None + self._timing_data["llm_time"] = 0.0 + self._timing_data["env_time"] = 0.0 + self._timing_data["reward_time"] = 0.0 + self._step_timings = [] + self._current_step_timing = None + + def _start_new_step_timing(self): + """Start timing for a new step.""" + if self._current_step_timing is not None: + # Finish the previous step timing BEFORE appending + self._finish_current_step_timing() + self._step_timings.append(self._current_step_timing) + + self._current_step_timing = { + "llm_time": 0.0, + "env_time": 0.0, + "step_start_time": time.time(), # Keep for calculation + "step_start_timestamp": datetime.now(timezone.utc).isoformat(), + "step_end_time": None, + "step_end_timestamp": None, + } + + def _finish_current_step_timing(self): + """Finish timing for the current step.""" + if self._current_step_timing is not None: + self._current_step_timing["step_end_time"] = time.time() + self._current_step_timing["step_end_timestamp"] = datetime.now(timezone.utc).isoformat() + + async def timed_llm_call(self, *args, **kwargs): + """Wrapper for LLM calls with timing. + + Returns: + ModelOutput from the rollout engine + """ + # Start a new step timing when we make an LLM call + self._start_new_step_timing() + + start = time.time() + result = await self.rollout_engine.get_model_response(*args, **kwargs) + duration = time.time() - start + + self._timing_data["llm_time"] += duration + if self._current_step_timing is not None: + self._current_step_timing["llm_time"] += duration + + return result + + async def timed_env_call(self, func, *args, **kwargs): + """Wrapper for environment calls with timing. + + Args: + func: The function to call (typically env.reset or env.step) + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function + + Returns: + Result from the function call + """ + start = time.time() + result = await self.run_in_executor(func, *args, **kwargs) + duration = time.time() - start + + self._timing_data["env_time"] += duration + if self._current_step_timing is not None: + self._current_step_timing["env_time"] += duration + + return result + + def add_reward_time(self, duration: float): + """Add reward computation time. + + Args: + duration: Time spent computing rewards + """ + self._timing_data["reward_time"] += duration + + def finalize_timing(self): + """Finalize timing calculations. + + Returns: + Dictionary with timing information including start_timestamp and end_timestamp as ISO 8601 strings + """ + # Finish the last step timing if it exists + if self._current_step_timing is not None: + self._finish_current_step_timing() + self._step_timings.append(self._current_step_timing) + self._current_step_timing = None + + # Set end time for the episode + if self._timing_data["start_time"] is not None: + self._timing_data["end_time"] = time.time() + self._timing_data["end_timestamp"] = datetime.now(timezone.utc).isoformat() + self._timing_data["total_time"] = self._timing_data["end_time"] - self._timing_data["start_time"] + + return { + "start_timestamp": self._timing_data["start_timestamp"], # ISO 8601 timestamp string + "end_timestamp": self._timing_data["end_timestamp"], # ISO 8601 timestamp string + "llm_time": self._timing_data["llm_time"], + "env_time": self._timing_data["env_time"], + "reward_time": self._timing_data["reward_time"], + "total_time": self._timing_data["total_time"], + } + + def postprocess_episode(self, episode, termination_reason=None, error=None): + """Override to add timing metrics to episode. + + This should be called from the subclass's postprocess_episode method. + """ + # Get timing data before calling parent + timing_metrics = self.finalize_timing() + + # Call parent's postprocess if it exists + if hasattr(super(), "postprocess_episode"): + episode = super().postprocess_episode(episode, termination_reason, error) + + # Add timing to episode info + episode.info["timing"] = timing_metrics + + # Add per-trajectory metrics + for trajectory in episode.trajectories: + trajectory.info["num_steps"] = len(trajectory.steps) + + # Calculate trajectory-level timing (start and end timestamps) + trajectory_start = None + trajectory_end = None + + if trajectory.steps: + # Get start time from first step + if self._step_timings: + trajectory_start = self._step_timings[0].get("step_start_timestamp") + trajectory_end = self._step_timings[-1].get("step_end_timestamp") + elif timing_metrics.get("start_timestamp") is not None: + # Fallback to episode timing if no step timings available + trajectory_start = timing_metrics["start_timestamp"] + trajectory_end = timing_metrics["end_timestamp"] + + # Add trajectory-level timing with start and end timestamps as ISO 8601 strings + trajectory.info["timing"] = { + "start_timestamp": trajectory_start, # ISO 8601 timestamp string + "end_timestamp": trajectory_end, # ISO 8601 timestamp string + "llm_time": timing_metrics["llm_time"], + "env_time": timing_metrics["env_time"], + "reward_time": timing_metrics["reward_time"], + "total_time": timing_metrics["total_time"], + } + + # Add per-step timing to each step (with ISO 8601 timestamps) + for i, step in enumerate(trajectory.steps): + if i < len(self._step_timings): + step_timing = self._step_timings[i] + step.info["timing"] = { + "start_timestamp": step_timing.get("step_start_timestamp"), # ISO 8601 timestamp string + "end_timestamp": step_timing.get("step_end_timestamp"), # ISO 8601 timestamp string + "llm_time": step_timing.get("llm_time", 0.0), + "env_time": step_timing.get("env_time", 0.0), + } + else: + # No timing data available for this step + step.info["timing"] = { + "start_timestamp": None, + "end_timestamp": None, + "llm_time": 0.0, + "env_time": 0.0, + } + + return episode + + def reset(self, task: Any = None, uid: str | None = None): + """Override reset to start timing. + + Subclasses should call this if they override reset. + """ + # Start timing when resetting + self.start_timing() + + # Call parent's reset if it exists + if hasattr(super(), "reset"): + return super().reset(task, uid)