Add optional torch profiler support to vllm/worker/cpu_worker.py.

When envs.VLLM_TORCH_PROFILER_DIR is truthy, __init__ logs the trace
directory and builds a torch.profiler.profile that records CPU activity
with stack capture (with_stack=True) and passes finished traces to
torch.profiler.tensorboard_trace_handler(<dir>, use_gzip=True).
When that value is falsy, self.profiler is set to None.

Two methods are added, start_profile() and stop_profile(). Each raises
RuntimeError("Profiler is not enabled.") when self.profiler is None and
otherwise calls self.profiler.start() or self.profiler.stop().

NOTE(review): this patch was recovered from a copy collapsed onto a single
line. Leading indentation below is reconstructed from the code structure;
confirm it against the target file before applying.

diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index d9b1d18da156c..52d1806018f51 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -179,6 +179,32 @@ def __init__(
         self.cache_engine: List[CPUCacheEngine]
         self.cpu_cache: List[List[torch.Tensor]]
 
+        # Torch profiler. Enabled and configured through env vars:
+        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            logger.info("Profiling enabled. Traces will be saved to: %s",
+                        torch_profiler_trace_dir)
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                ],
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, use_gzip=True))
+        else:
+            self.profiler = None
+
+    def start_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.start()
+
+    def stop_profile(self):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        self.profiler.stop()
+
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
             torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)