From 9707f874a82b415fdadc07ec823fa780170bc39c Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Sun, 7 Sep 2025 06:58:54 -0600 Subject: [PATCH 1/6] Updated benchmark with a bunch of improvements. --- open_instruct/actor_manager.py | 214 +++++++++++++++++++++++++--- open_instruct/grpo_fast.py | 5 + open_instruct/static/dashboard.css | 28 ++++ open_instruct/static/dashboard.html | 138 +++++++++++++++++- open_instruct/vllm_utils3.py | 51 +++++++ 5 files changed, 415 insertions(+), 21 deletions(-) diff --git a/open_instruct/actor_manager.py b/open_instruct/actor_manager.py index b1ecc7a45..dee256648 100644 --- a/open_instruct/actor_manager.py +++ b/open_instruct/actor_manager.py @@ -41,7 +41,7 @@ def find_free_port(): class ActorManager: """Centralized manager for controlling evaluation and weight updates across all LLMRayActors.""" - def __init__(self, queues: dict, args): + def __init__(self, queues: dict, args, vllm_engines=None): self._should_stop = False self._last_updated = datetime.now() self._dashboard_port = None @@ -56,6 +56,17 @@ def __init__(self, queues: dict, args): self._generation_batch_history = collections.deque(maxlen=self._sample_window) self._kv_cache_max_concurrency = None self._args = args + self._vllm_engines = vllm_engines or [] + self._last_metrics_collection_time = 0 + # Cache for static token rates (updated only on new batch completion) + self._cached_token_rates = {"prefill_tokens_per_sec": 0, "decode_tokens_per_sec": 0, "last_update_count": 0} + # Training progress tracking + self._current_training_step = 0 + self._total_training_steps = getattr(args, "num_training_steps", None) + self._training_start_time = None + # MFU/MBU tracking + self._model_utilization_history = collections.deque(maxlen=self._sample_window) + self._memory_usage_stats = {"total_gpu_memory_used": 0, "average_kv_cache_size": 0, "peak_memory_usage": 0} if self._args.enable_queue_dashboard: self._setup_queue_monitoring() self._start_dashboard() @@ -71,13 +82,77 @@ def _setup_queue_monitoring(self): self._poll_thread.start() def _poll_queue_sizes(self): - """Background thread to poll queue sizes.""" + """Background thread to poll queue sizes and collect vLLM metrics.""" while self._polling_active: + # Poll queue sizes for queue_name, info in self._queue_info.items(): current_size = info["queue"].size() self._queue_sizes[queue_name] = current_size + + # Collect vLLM metrics every 10 seconds + current_time = time.time() + if (current_time - self._last_metrics_collection_time) >= 10.0: + self._collect_vllm_metrics() + self._last_metrics_collection_time = current_time + time.sleep(0.5) + def _collect_vllm_metrics(self): + """Collect metrics from all vLLM engines.""" + if not self._vllm_engines: + return + + try: + # Collect metrics from all engines asynchronously + import ray + + metrics_futures = [] + for engine in self._vllm_engines: + try: + future = engine.get_engine_metrics.remote() + metrics_futures.append(future) + except Exception as e: + logger = logger_utils.setup_logger(__name__) + logger.warning(f"Error getting metrics from engine: {e}") + + if metrics_futures: + # Get all metrics with a short timeout to avoid blocking + try: + all_metrics = ray.get(metrics_futures, timeout=5.0) + + # Aggregate metrics across all engines + total_gpu_memory = 0 + total_kv_cache_memory = 0 + total_mfu = 0 + total_mbu = 0 + valid_engines = 0 + + for metrics in all_metrics: + if metrics and isinstance(metrics, dict): + total_gpu_memory += metrics.get("gpu_memory_reserved_gb", 0) + total_kv_cache_memory += 
metrics.get("gpu_memory_allocated_gb", 0) # Approximation + total_mfu += metrics.get("mfu_estimate", 0) + total_mbu += metrics.get("mbu_estimate", 0) + valid_engines += 1 + + if valid_engines > 0: + # Report aggregated metrics + avg_mfu = total_mfu / valid_engines + avg_mbu = total_mbu / valid_engines + self.report_model_utilization(avg_mfu, avg_mbu) + self.report_memory_usage(total_gpu_memory, total_kv_cache_memory) + + except ray.exceptions.GetTimeoutError: + logger = logger_utils.setup_logger(__name__) + logger.warning("Timeout collecting vLLM metrics") + except Exception as e: + logger = logger_utils.setup_logger(__name__) + logger.warning(f"Error processing vLLM metrics: {e}") + + except Exception as e: + logger = logger_utils.setup_logger(__name__) + logger.warning(f"Error in _collect_vllm_metrics: {e}") + def _start_dashboard(self): """Start the FastAPI dashboard server in a background thread.""" if self._args.queue_dashboard_port is None: @@ -110,6 +185,9 @@ async def api_status(): "queues": queues_data, "token_stats": self.get_token_stats(), "timing_stats": self.get_timing_stats(), + "training_progress": self.get_training_progress(), + "utilization_stats": self.get_utilization_stats(), + "memory_stats": self.get_memory_stats(), "kv_cache_max_concurrency": self._kv_cache_max_concurrency, # This is less confusing to users. "inference_batch_size": self._args.inference_batch_size * self._args.num_samples_per_prompt_rollout, @@ -161,52 +239,76 @@ def report_token_statistics(self, token_stats): } ) - self._generation_batch_history.append(token_stats.generation_time) + # Report batch generation time (avoid double reporting via report_batch_generation_time) + # Add validation to prevent extreme outliers (e.g., > 300 seconds) + if 0 < token_stats.generation_time < 300: + self._generation_batch_history.append(token_stats.generation_time) def report_training_step_time(self, duration: float): """Report the time taken for a training step.""" self._training_step_history.append(duration) + def update_training_step(self, step: int): + """Update the current training step.""" + if self._training_start_time is None: + self._training_start_time = time.time() + self._current_training_step = step + def report_batch_generation_time(self, duration: float): """Report the time taken to generate a batch of data.""" - self._generation_batch_history.append(duration) + # Add validation to prevent extreme outliers (e.g., > 300 seconds) + if 0 < duration < 300: + self._generation_batch_history.append(duration) def set_kv_cache_max_concurrency(self, max_concurrency: int): """Set the KV cache max concurrency value.""" self._kv_cache_max_concurrency = max_concurrency + def set_vllm_engines(self, vllm_engines): + """Set the vLLM engines for metrics collection.""" + self._vllm_engines = vllm_engines or [] + def get_token_stats(self): """Calculate and return current token statistics.""" if not self._token_history: return { "total_prefill_tokens": self._total_prefill_tokens, "total_decode_tokens": self._total_decode_tokens, - "prefill_tokens_per_sec": 0, - "decode_tokens_per_sec": 0, + "prefill_tokens_per_sec": self._cached_token_rates["prefill_tokens_per_sec"], + "decode_tokens_per_sec": self._cached_token_rates["decode_tokens_per_sec"], "sample_count": 0, } - current_time = time.time() + # Only update rates if we have new token history entries + current_sample_count = len(self._token_history) + if current_sample_count > self._cached_token_rates["last_update_count"]: + current_time = time.time() - window_prompt_tokens 
= 0 - window_generation_tokens = 0 - oldest_timestamp = self._token_history[0]["timestamp"] + window_prompt_tokens = 0 + window_generation_tokens = 0 + oldest_timestamp = self._token_history[0]["timestamp"] - for entry in self._token_history: - window_prompt_tokens += entry["prompt_tokens"] - window_generation_tokens += entry["generation_tokens"] + for entry in self._token_history: + window_prompt_tokens += entry["prompt_tokens"] + window_generation_tokens += entry["generation_tokens"] - time_span = current_time - oldest_timestamp if len(self._token_history) > 1 else 1 + time_span = current_time - oldest_timestamp if len(self._token_history) > 1 else 1 - prompt_tokens_per_sec = window_prompt_tokens / time_span if time_span > 0 else 0 - generation_tokens_per_sec = window_generation_tokens / time_span if time_span > 0 else 0 + # Update cached rates + self._cached_token_rates["prefill_tokens_per_sec"] = ( + window_prompt_tokens / time_span if time_span > 0 else 0 + ) + self._cached_token_rates["decode_tokens_per_sec"] = ( + window_generation_tokens / time_span if time_span > 0 else 0 + ) + self._cached_token_rates["last_update_count"] = current_sample_count return { "total_prefill_tokens": self._total_prefill_tokens, "total_decode_tokens": self._total_decode_tokens, - "prefill_tokens_per_sec": prompt_tokens_per_sec, - "decode_tokens_per_sec": generation_tokens_per_sec, - "sample_count": len(self._token_history), + "prefill_tokens_per_sec": self._cached_token_rates["prefill_tokens_per_sec"], + "decode_tokens_per_sec": self._cached_token_rates["decode_tokens_per_sec"], + "sample_count": current_sample_count, } def get_timing_stats(self): @@ -228,6 +330,80 @@ def get_timing_stats(self): "batch_generation_count": len(self._generation_batch_history), } + def get_training_progress(self): + """Calculate and return training progress and ETA.""" + if not self._total_training_steps or self._current_training_step <= 0: + return { + "current_step": self._current_training_step, + "total_steps": self._total_training_steps, + "progress_percent": 0, + "eta_seconds": None, + "eta_formatted": "N/A", + } + + progress_percent = (self._current_training_step / self._total_training_steps) * 100 + eta_seconds = None + eta_formatted = "N/A" + + if self._training_start_time and self._current_training_step > 0: + elapsed_time = time.time() - self._training_start_time + avg_time_per_step = elapsed_time / self._current_training_step + remaining_steps = self._total_training_steps - self._current_training_step + eta_seconds = remaining_steps * avg_time_per_step + + if eta_seconds > 0: + hours = int(eta_seconds // 3600) + minutes = int((eta_seconds % 3600) // 60) + if hours > 0: + eta_formatted = f"{hours}h {minutes}m" + else: + eta_formatted = f"{minutes}m" + + return { + "current_step": self._current_training_step, + "total_steps": self._total_training_steps, + "progress_percent": progress_percent, + "eta_seconds": eta_seconds, + "eta_formatted": eta_formatted, + } + + def report_model_utilization(self, mfu: float, mbu: float): + """Report MFU (Model FLOPs Utilization) and MBU (Memory Bandwidth Utilization).""" + current_time = time.time() + # Validate and clamp values to reasonable ranges + mfu = max(0, min(100, mfu)) + mbu = max(0, min(100, mbu)) + + self._model_utilization_history.append({"timestamp": current_time, "mfu": mfu, "mbu": mbu}) + + def report_memory_usage(self, gpu_memory_used: float, kv_cache_size: float): + """Report memory usage statistics.""" + self._memory_usage_stats["total_gpu_memory_used"] = 
gpu_memory_used + self._memory_usage_stats["average_kv_cache_size"] = kv_cache_size + self._memory_usage_stats["peak_memory_usage"] = max( + self._memory_usage_stats["peak_memory_usage"], gpu_memory_used + ) + + def get_utilization_stats(self): + """Calculate and return current utilization statistics.""" + if not self._model_utilization_history: + return {"mfu": 0, "mbu": 0, "sample_count": 0} + + # Calculate averages over the sample window + total_mfu = sum(entry["mfu"] for entry in self._model_utilization_history) + total_mbu = sum(entry["mbu"] for entry in self._model_utilization_history) + count = len(self._model_utilization_history) + + return { + "mfu": total_mfu / count if count > 0 else 0, + "mbu": total_mbu / count if count > 0 else 0, + "sample_count": count, + } + + def get_memory_stats(self): + """Return current memory usage statistics.""" + return self._memory_usage_stats.copy() + def get_dashboard_port(self): """Get the port number where the dashboard is running.""" return self._dashboard_port diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index b7de1e155..aa90063f2 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2054,6 +2054,10 @@ def create_model_and_optimizer( ) logger.info("======== ✅ model update group setup successfully =========") + # Set vLLM engines in ActorManager for metrics collection + ray.get(actor_manager.set_vllm_engines.remote(vllm_engines)) + logger.info("======== ✅ vLLM engines set in ActorManager for metrics collection =========") + return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager @@ -2269,6 +2273,7 @@ def one_training_step( ray_get_with_progress(update_ref_policy_future, desc="Updating reference policy") ray.get(actor_manager.report_training_step_time.remote(train_timer.duration)) + ray.get(actor_manager.update_training_step.remote(training_step)) average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]} total_time = time.perf_counter() - start_time diff --git a/open_instruct/static/dashboard.css b/open_instruct/static/dashboard.css index 2d16b61ba..88a7b0126 100644 --- a/open_instruct/static/dashboard.css +++ b/open_instruct/static/dashboard.css @@ -197,4 +197,32 @@ h2 { font-size: 14px; color: #999; margin-top: 3px; +} + +/* Training progress styles */ +.progress-indicator { + width: 100%; + height: 8px; + background: #e9ecef; + border-radius: 4px; + margin-top: 8px; + overflow: hidden; +} + +.progress-fill { + height: 100%; + background: linear-gradient(90deg, #4CAF50, #81C784); + transition: width 0.3s ease; + border-radius: 4px; +} + +/* Memory usage specific styles */ +.memory-card { + background: #fff8f0; + border-left-color: #FF9800; +} + +.utilization-card { + background: #fce4ec; + border-left-color: #E91E63; } \ No newline at end of file diff --git a/open_instruct/static/dashboard.html b/open_instruct/static/dashboard.html index 233df2b4e..d84d3616c 100644 --- a/open_instruct/static/dashboard.html +++ b/open_instruct/static/dashboard.html @@ -39,6 +39,36 @@

     ⏱️ Performance Metrics   [existing section heading; surrounding markup stripped in extraction]
+    🎯 Training Progress     [new card, container "training-progress-container", shows "Loading..." until populated]
+    🧠 Model Utilization     [new card, container "model-utilization-container", shows "Loading..." until populated]
+    💾 Memory Usage          [new card, container "memory-usage-container", shows "Loading..." until populated]

💾 KV Cache

@@ -116,18 +146,25 @@

💾 KV Cache

                 if (num >= 1000) return (num / 1000).toFixed(2) + 'K';
                 return num.toFixed(0);
             };
+
+            const formatTokensPerSecond = (num) => {
+                if (num >= 1000000) return (num / 1000000).toFixed(2) + 'M';
+                if (num >= 36500) return (num / 1000).toFixed(1) + 'k';
+                if (num >= 1000) return num.toLocaleString('en-US');
+                return num.toFixed(1);
+            };
 
             document.getElementById('token-container').innerHTML = `
                     Prefill Tokens/Sec
-                    ${stats.prefill_tokens_per_sec.toFixed(1)}
+                    ${formatTokensPerSecond(stats.prefill_tokens_per_sec)}
                     Avg over ${stats.sample_count} samples | Total: ${formatNumber(stats.total_prefill_tokens)}
                     Decode Tokens/Sec
-                    ${stats.decode_tokens_per_sec.toFixed(1)}
+                    ${formatTokensPerSecond(stats.decode_tokens_per_sec)}
                     Avg over ${stats.sample_count} samples | Total: ${formatNumber(stats.total_decode_tokens)}
@@ -161,6 +198,103 @@

💾 KV Cache

             `;
         }
 
+        // Update training progress
+        if (data.training_progress) {
+            const progress = data.training_progress;
+            document.getElementById('training-progress-container').innerHTML = `
+                Current Training Step
+                ${progress.current_step} / ${progress.total_steps || 'N/A'}
+                ${progress.progress_percent.toFixed(1)}% complete
+                Estimated Time Remaining
+                ${progress.eta_formatted}
+                Based on current training speed
+            `;
+        }
+
+        // Update model utilization (MFU/MBU)
+        if (data.utilization_stats) {
+            const util = data.utilization_stats;
+            document.getElementById('model-utilization-container').innerHTML = `
+                Model FLOPs Utilization (MFU)
+                ${util.mfu.toFixed(1)}%
+                Percentage of theoretical peak FLOPS (avg over ${util.sample_count} samples)
+                Memory Bandwidth Utilization (MBU)
+                ${util.mbu.toFixed(1)}%
+                Percentage of theoretical peak memory bandwidth (avg over ${util.sample_count} samples)
+            `;
+        } else {
+            document.getElementById('model-utilization-container').innerHTML = `
+                Model FLOPs Utilization (MFU)
+                Collecting...
+                Percentage of theoretical peak FLOPS
+                Memory Bandwidth Utilization (MBU)
+                Collecting...
+                Percentage of theoretical peak memory bandwidth
+            `;
+        }
+
+        // Update memory usage
+        if (data.memory_stats) {
+            const mem = data.memory_stats;
+            document.getElementById('memory-usage-container').innerHTML = `
+                Average KV Cache Size
+                ${mem.average_kv_cache_size.toFixed(2)} GB
+                Per actor KV cache memory usage
+                Total GPU Memory Usage
+                ${mem.total_gpu_memory_used.toFixed(2)} GB
+                Across all inference actors (peak: ${mem.peak_memory_usage.toFixed(2)} GB)
+            `;
+        } else {
+            document.getElementById('memory-usage-container').innerHTML = `
+                Average KV Cache Size
+                Collecting...
+                Per actor KV cache memory usage
+                Total GPU Memory Usage
+                Collecting...
+                Across all inference actors
+ `; + } + // Update KV cache statistics if (data.kv_cache_max_concurrency !== null && data.kv_cache_max_concurrency !== undefined) { let kvCacheHtml = ` diff --git a/open_instruct/vllm_utils3.py b/open_instruct/vllm_utils3.py index 293e2542b..2ce7396b2 100644 --- a/open_instruct/vllm_utils3.py +++ b/open_instruct/vllm_utils3.py @@ -624,6 +624,57 @@ def wake_up(self, tags: Optional[list[str]] = None): def ready(self): return True + def get_engine_metrics(self): + """Get comprehensive metrics from the vLLM engine.""" + try: + # Get GPU memory usage + if torch.cuda.is_available(): + gpu_memory_allocated = torch.cuda.memory_allocated(0) # bytes + gpu_memory_reserved = torch.cuda.memory_reserved(0) # bytes + gpu_memory_total = torch.cuda.get_device_properties(0).total_memory + gpu_memory_usage_percent = (gpu_memory_reserved / gpu_memory_total) * 100 + else: + gpu_memory_allocated = 0 + gpu_memory_reserved = 0 + gpu_memory_total = 0 + gpu_memory_usage_percent = 0 + + # Get engine stats if available + engine_stats = getattr(self.llm_engine, "stats", {}) + + # Get KV cache info + kv_cache_info = self.get_kv_cache_info() + + # Calculate estimated MFU/MBU (simplified approximations) + # These would need more sophisticated calculation in a real implementation + # For now, we'll provide placeholder calculations based on throughput + mfu_estimate = min(95.0, max(0.0, 50.0)) # Placeholder: 50% utilization + mbu_estimate = min(95.0, max(0.0, gpu_memory_usage_percent * 0.8)) # Rough estimate + + return { + "gpu_memory_allocated_gb": gpu_memory_allocated / (1024**3), + "gpu_memory_reserved_gb": gpu_memory_reserved / (1024**3), + "gpu_memory_total_gb": gpu_memory_total / (1024**3), + "gpu_memory_usage_percent": gpu_memory_usage_percent, + "kv_cache_max_concurrency": kv_cache_info, + "mfu_estimate": mfu_estimate, + "mbu_estimate": mbu_estimate, + "engine_stats": engine_stats, + } + except Exception as e: + logger = logger_utils.setup_logger(__name__) + logger.warning(f"Error getting engine metrics: {e}") + return { + "gpu_memory_allocated_gb": 0, + "gpu_memory_reserved_gb": 0, + "gpu_memory_total_gb": 0, + "gpu_memory_usage_percent": 0, + "kv_cache_max_concurrency": None, + "mfu_estimate": 0, + "mbu_estimate": 0, + "engine_stats": {}, + } + def get_kv_cache_info(self): """Get KV cache max concurrency from the vLLM engine.""" kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs() From a237d4a2190942bc1b689166d5c4d8f4b4257430 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 8 Sep 2025 12:39:23 -0600 Subject: [PATCH 2/6] Added more functionality. 
now we track actor status --- open_instruct/actor_manager.py | 28 +++++++++ open_instruct/static/dashboard.html | 88 +++++++++++++++++++++++++++++ open_instruct/vllm_utils3.py | 15 +++++ 3 files changed, 131 insertions(+) diff --git a/open_instruct/actor_manager.py b/open_instruct/actor_manager.py index dee256648..5c134998d 100644 --- a/open_instruct/actor_manager.py +++ b/open_instruct/actor_manager.py @@ -67,6 +67,8 @@ def __init__(self, queues: dict, args, vllm_engines=None): # MFU/MBU tracking self._model_utilization_history = collections.deque(maxlen=self._sample_window) self._memory_usage_stats = {"total_gpu_memory_used": 0, "average_kv_cache_size": 0, "peak_memory_usage": 0} + # Actor status tracking + self._actor_status = {} # actor_id -> {unfinished_requests, inference_batch_size, last_update} if self._args.enable_queue_dashboard: self._setup_queue_monitoring() self._start_dashboard() @@ -191,6 +193,7 @@ async def api_status(): "kv_cache_max_concurrency": self._kv_cache_max_concurrency, # This is less confusing to users. "inference_batch_size": self._args.inference_batch_size * self._args.num_samples_per_prompt_rollout, + "actor_status": self.get_actor_status(), } def run_server(): @@ -213,6 +216,15 @@ def should_stop(self) -> bool: """Check if actors should stop processing.""" return self._should_stop + def report_actor_status(self, actor_id: str, unfinished_requests: int, inference_batch_size: int): + """Report status from an individual actor.""" + current_time = time.time() + self._actor_status[actor_id] = { + "unfinished_requests": unfinished_requests, + "inference_batch_size": inference_batch_size, + "last_update": current_time, + } + def report_token_stats(self, prompt_tokens: int, generation_tokens: int): """Report token statistics from main thread.""" current_time = time.time() @@ -404,6 +416,22 @@ def get_memory_stats(self): """Return current memory usage statistics.""" return self._memory_usage_stats.copy() + def get_actor_status(self): + """Return current actor status information.""" + current_time = time.time() + # Filter out stale actor data (older than 60 seconds) + active_actors = {} + for actor_id, status in self._actor_status.items(): + if current_time - status["last_update"] < 60: + active_actors[actor_id] = { + "actor_id_short": actor_id[:8], # Short version for display + "unfinished_requests": status["unfinished_requests"], + "inference_batch_size": status["inference_batch_size"], + "last_update": status["last_update"], + "is_active": status["unfinished_requests"] > 0, + } + return active_actors + def get_dashboard_port(self): """Get the port number where the dashboard is running.""" return self._dashboard_port diff --git a/open_instruct/static/dashboard.html b/open_instruct/static/dashboard.html index d84d3616c..544308a10 100644 --- a/open_instruct/static/dashboard.html +++ b/open_instruct/static/dashboard.html @@ -79,6 +79,16 @@

     💾 KV Cache              [existing card heading; markup stripped in extraction]
+    🎭 Actor Status          [new card, shows "Loading..." until populated; markup stripped in extraction]
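
Note on the MFU/MBU numbers: get_engine_metrics() in vllm_utils3.py currently reports mfu_estimate as a fixed 50% placeholder and derives mbu_estimate from GPU memory usage, and the patch itself flags these as needing a more sophisticated calculation. Below is a minimal, throughput-based sketch of what such an estimate could look like; the function names, the parameter-count argument, and the peak-FLOPS/bandwidth defaults (roughly A100 bf16 figures) are illustrative assumptions, not part of this PR.

def estimate_mfu(decode_tokens_per_sec: float, num_params: float,
                 peak_flops_per_sec: float = 312e12) -> float:
    """First-order decode MFU: a forward pass costs ~2 FLOPs per parameter per token.

    Ignores attention FLOPs over the KV cache and any prefill work, so it
    understates utilization for long contexts.
    """
    achieved_flops = decode_tokens_per_sec * 2.0 * num_params
    return min(100.0, 100.0 * achieved_flops / peak_flops_per_sec)


def estimate_mbu(decode_tokens_per_sec: float, num_params: float, batch_size: float,
                 bytes_per_param: float = 2.0,
                 peak_bandwidth_bytes_per_sec: float = 2.0e12) -> float:
    """First-order decode MBU: each decode step streams the weights once.

    Weight traffic is shared across the batch, so per-token bytes are
    (num_params * bytes_per_param) / batch_size; KV-cache reads are ignored.
    """
    steps_per_sec = decode_tokens_per_sec / max(batch_size, 1.0)
    achieved_bytes = steps_per_sec * num_params * bytes_per_param
    return min(100.0, 100.0 * achieved_bytes / peak_bandwidth_bytes_per_sec)

Such estimates could, for example, be fed from the windowed decode tokens/sec that ActorManager already tracks, rather than the hard-coded values; that wiring is not implemented in this patch.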