Skip to content

Commit

Permalink
[V1] EngineCore supports profiling (#10564)
Browse files Browse the repository at this point in the history
Signed-off-by: Abatom <[email protected]>
  • Loading branch information
Abatom authored Nov 23, 2024
1 parent 28598f3 commit d345f40
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 9 deletions.
6 changes: 6 additions & 0 deletions vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,16 @@ class EngineCoreOutputs(msgspec.Struct,
outputs: List[EngineCoreOutput]


@dataclass
class EngineCoreProfile:
is_start: bool


class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
without separate encoding step.
"""
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'
4 changes: 2 additions & 2 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,10 @@ async def check_health(self) -> None:
logger.debug("Called check_health.")

async def start_profile(self) -> None:
raise ValueError("Not supported on V1 yet.")
await self.engine_core.profile(True)

async def stop_profile(self) -> None:
raise ValueError("Not supported on V1 yet.")
await self.engine_core.profile(False)

@property
def is_running(self) -> bool:
Expand Down
14 changes: 12 additions & 2 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import multiprocessing
import pickle
import queue
import threading
import time
Expand All @@ -16,7 +17,8 @@
from vllm.usage.usage_lib import UsageContext
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType)
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapper
from vllm.v1.executor.gpu_executor import GPUExecutor
from vllm.v1.request import Request, RequestStatus
Expand Down Expand Up @@ -126,6 +128,9 @@ def step(self) -> List[EngineCoreOutput]:
scheduler_output, output)
return engine_core_outputs

def profile(self, is_start=True):
self.model_executor.worker.profile(is_start)


class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
Expand Down Expand Up @@ -312,11 +317,14 @@ def _log_stats(self):
self._last_logging_time = now

def _handle_client_request(
self, request: Union[EngineCoreRequest, List[str]]) -> None:
self, request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""

if isinstance(request, EngineCoreRequest):
self.add_request(request)
elif isinstance(request, EngineCoreProfile):
self.model_executor.worker.profile(request.is_start)
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
Expand All @@ -341,6 +349,8 @@ def process_input_socket(self, input_path: str):
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type == EngineCoreRequestType.PROFILE.value:
request = pickle.loads(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")

Expand Down
28 changes: 23 additions & 5 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from vllm.logger import init_logger
from vllm.utils import get_open_zmq_ipc_path
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
EngineCoreRequest, EngineCoreRequestType)
EngineCoreProfile, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.serial_utils import PickleEncoder

Expand Down Expand Up @@ -58,6 +59,9 @@ def get_output(self) -> List[EngineCoreOutput]:
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError

async def profile(self, is_start=True) -> None:
raise NotImplementedError

def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError

Expand Down Expand Up @@ -95,6 +99,9 @@ def add_request(self, request: EngineCoreRequest) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
self.engine_core.abort_requests(request_ids)

async def profile(self, is_start=True) -> None:
self.engine_core.profile(is_start)


class MPClient(EngineCoreClient):
"""
Expand Down Expand Up @@ -177,8 +184,10 @@ def get_output(self) -> List[EngineCoreOutput]:
engine_core_outputs = self.decoder.decode(frame.buffer).outputs
return engine_core_outputs

def _send_input(self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]]) -> None:
def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:

# (RequestType, SerializedRequest)
msg = (request_type.value, self.encoder.encode(request))
Expand All @@ -190,6 +199,10 @@ def add_request(self, request: EngineCoreRequest) -> None:
def abort_requests(self, request_ids: List[str]) -> None:
self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile(self, is_start=True) -> None:
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))


class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
Expand All @@ -205,8 +218,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]:
return engine_core_outputs

async def _send_input(
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, List[str]]) -> None:
self, request_type: EngineCoreRequestType,
request: Union[EngineCoreRequest, EngineCoreProfile,
List[str]]) -> None:

msg = (request_type.value, self.encoder.encode(request))
await self.input_socket.send_multipart(msg, copy=False)
Expand All @@ -217,3 +231,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
async def abort_requests_async(self, request_ids: List[str]) -> None:
if len(request_ids) > 0:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)

async def profile(self, is_start=True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
25 changes: 25 additions & 0 deletions vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.distributed

import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
Expand Down Expand Up @@ -56,6 +57,22 @@ def __init__(
init_cached_hf_modules()

self.model_runner = GPUModelRunner(vllm_config)
# Torch profiler. Enabled and configured through env vars:
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
if envs.VLLM_TORCH_PROFILER_DIR:
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
torch_profiler_trace_dir)
self.profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None

def initialize(self):
if self.device_config.device.type == "cuda":
Expand Down Expand Up @@ -184,6 +201,14 @@ def execute_model(
# TODO(woosuk): Send the output to the engine process.
return output

def profile(self, is_start=True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
self.profiler.start()
else:
self.profiler.stop()


def init_worker_distributed_environment(
parallel_config: ParallelConfig,
Expand Down

0 comments on commit d345f40

Please sign in to comment.