Simulator Tracking #2

Open · wants to merge 9 commits into main
72 changes: 72 additions & 0 deletions examples/simulation_mode.py
@@ -0,0 +1,72 @@
from vllm.entrypoints.simulator import vLLMSimulator, WorkloadRequest

# workloads = [
# WorkloadRequest(0.0, 10, 10),
# WorkloadRequest(0.0, 20, 20),
# WorkloadRequest(0.0, 10, 10),
# WorkloadRequest(0.0, 20, 20),
# WorkloadRequest(0.05, 30, 30),
# WorkloadRequest(0.1, 40, 40),
# WorkloadRequest(0.2, 50, 50),
# ]

# engine = vLLMSimulator(model="facebook/opt-125m")
# engine.profile(workloads)

# Profile a workload batch sweeping prefill sizes from 2 to 8192 tokens and
# decode batch sizes from 2 to 512 requests, each over 64 log-spaced (base-2) points.
# Each decode entry is a batch of single-prompt-token requests, so the batch size
# determines the number of decode tokens per step.

import numpy as np

prefill_sizes = np.logspace(np.log2(2), np.log2(8192), num=64,
base=2).round().astype(int)
decode_sizes = np.logspace(np.log2(2), np.log2(512), num=64,
base=2).round().astype(int)

workload_batch = []
# first measure prefill
for i in prefill_sizes:
workload_batch.append([WorkloadRequest(0, i, 0)])
# then measure decode
for i in decode_sizes:
# TODO: check ctx size > 1, and how it varies (it should not)
workload_batch.append([WorkloadRequest(0, 1, 2) for _ in range(i)])

print(prefill_sizes)
print(decode_sizes)

# workload_batch = [
# [WorkloadRequest(0, 100, 0), WorkloadRequest(0, 100, 0)], # 200 prefill tokens
# [WorkloadRequest(0, 200, 4), WorkloadRequest(0, 200, 4)], # 400 prefill tokens, 2 decode tokens per iter
# ]

# engine = vLLMSimulator(model="facebook/opt-125m")
# engine = vLLMSimulator(model="meta-llama/Meta-Llama-3-8B-Instruct")
engine = vLLMSimulator(model="meta-llama/Meta-Llama-3-70B-Instruct",
tensor_parallel_size=4)
profile = engine.profile_tokens_curve(workload_batch, n_trials=3)

# print(profile.prefill_timing) # this is a defaultdict(list)
# print(profile.decode_timing) # this is a defaultdict(list)

# Turn the timing dictionaries into a pandas DataFrame.

import pandas as pd

data = []
for k, v in profile.prefill_timing.items():
for i in v:
data.append({"size": k, "time_ms": i, "op": "prefill"})

for k, v in profile.decode_timing.items():
for i in v:
data.append({"size": k, "time_ms": i, "op": "decode"})

df = pd.DataFrame(data)

import sys

print("----")
df.to_csv(sys.stdout, index=False)
df.to_csv("profile-70b.csv", index=False)
1 change: 1 addition & 0 deletions vllm/config.py
@@ -1613,6 +1613,7 @@ class EngineConfig:
decoding_config: Optional[DecodingConfig]
observability_config: Optional[ObservabilityConfig]
prompt_adapter_config: Optional[PromptAdapterConfig]
simulation_mode: bool = False

def __post_init__(self):
"""Verify configs are valid & consistent with each other.
11 changes: 11 additions & 0 deletions vllm/engine/arg_utils.py
@@ -110,13 +110,19 @@ class EngineArgs:
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None

spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None

otlp_traces_endpoint: Optional[str] = None

# Simulation mode
simulation_mode: bool = False


def __post_init__(self):
if self.tokenizer is None:
@@ -661,6 +667,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=None,
help='Target URL to which OpenTelemetry traces will be sent.')

parser.add_argument('--simulation-mode',
action='store_true',
help='Enable simulation mode for the engine.')

return parser

@classmethod
@@ -871,6 +881,7 @@ def create_engine_config(self, ) -> EngineConfig:
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
simulation_mode=self.simulation_mode,
)


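For orientation, a small sketch of exercising the new flag programmatically; it uses only names visible in this diff (EngineArgs, simulation_mode, create_engine_config, EngineConfig), and the model name is a placeholder, not something this PR prescribes.

from vllm.engine.arg_utils import EngineArgs

# Build engine args with the new flag; everything else keeps its defaults.
args = EngineArgs(model="facebook/opt-125m", simulation_mode=True)

# create_engine_config() threads simulation_mode into EngineConfig (see the
# config.py change above); the engines use it to pick the simulated executor.
engine_config = args.create_engine_config()
assert engine_config.simulation_mode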
6 changes: 5 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -395,7 +395,11 @@ def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if isinstance(distributed_executor_backend, type):

if engine_config.simulation_mode:
from vllm.executor.simulated_executor import SimulatedExecutorAsync
executor_class = SimulatedExecutorAsync
elif isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorAsyncBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
8 changes: 7 additions & 1 deletion vllm/engine/llm_engine.py
@@ -171,6 +171,7 @@ def __init__(
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
simulation_mode: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@@ -352,6 +353,8 @@ def __init__(
self.get_tokenizer_for_seq,
),
))

self.simulation_mode = simulation_mode

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -381,7 +384,10 @@ def _get_executor_cls(cls,
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
# Initialize the cluster and specify the executor class.
if isinstance(distributed_executor_backend, type):
if engine_config.simulation_mode:
from vllm.executor.simulated_executor import SimulatedExecutor
executor_class = SimulatedExecutor
elif isinstance(distributed_executor_backend, type):
if not issubclass(distributed_executor_backend, ExecutorBase):
raise TypeError(
"distributed_executor_backend must be a subclass of "
4 changes: 4 additions & 0 deletions vllm/entrypoints/simulator/__init__.py
@@ -0,0 +1,4 @@
from vllm.entrypoints.simulator.simulator import vLLMSimulator
from vllm.entrypoints.simulator.loadgen import WorkloadRequest

__all__ = ["vLLMSimulator", "WorkloadRequest"]
10 changes: 10 additions & 0 deletions vllm/entrypoints/simulator/loadgen.py
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class WorkloadRequest:
    """One simulated request: arrival offset plus prompt and output token counts."""
    arrival_time: float
    num_input_tokens: int
    num_output_tokens: int

# TODO: add fields for prefix sharing
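
As a usage sketch, WorkloadRequest instances could be generated with randomized arrivals and lengths; the Poisson arrival model and the rate/length parameters below are illustrative assumptions, not something this PR defines.

import numpy as np

from vllm.entrypoints.simulator import WorkloadRequest


def make_open_loop_workload(num_requests: int = 100,
                            requests_per_sec: float = 2.0,
                            seed: int = 0) -> list:
    # Poisson process: exponential inter-arrival gaps, cumulative arrival times.
    rng = np.random.default_rng(seed)
    gaps = rng.exponential(1.0 / requests_per_sec, size=num_requests)
    arrivals = np.cumsum(gaps)
    return [
        WorkloadRequest(arrival_time=float(t),
                        num_input_tokens=int(rng.integers(16, 512)),
                        num_output_tokens=int(rng.integers(16, 256)))
        for t in arrivals
    ]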
76 changes: 76 additions & 0 deletions vllm/entrypoints/simulator/profile.py
@@ -0,0 +1,76 @@
import json
import time
from collections import defaultdict
from contextlib import contextmanager


@contextmanager
def profile_hook(execute_model_req, profile):
    """Time the wrapped forward pass and record it in `profile` by token counts."""
prefill_tokens = sum(
seq_group.token_chunk_size
for seq_group in execute_model_req.seq_group_metadata_list
if seq_group.is_prompt)
decode_tokens = sum(
seq_group.token_chunk_size
for seq_group in execute_model_req.seq_group_metadata_list
if not seq_group.is_prompt)

start = time.perf_counter_ns()

yield

duration = (time.perf_counter_ns() - start) / 1e6

print(
f"prefill_tokens: {prefill_tokens}, decode_tokens: {decode_tokens}, duration_ms: {duration}"
)
profile.record(prefill_tokens, decode_tokens, duration)


class SimulationProfile:
"""Profiling data structure capturing the timing of a single forward pass."""

def __init__(self):
self.prefill_timing = defaultdict(list)
self.decode_timing = defaultdict(list)

def clear(self):
self.prefill_timing.clear()
self.decode_timing.clear()

def record(self, prefill_tokens, decode_tokens, duration_ms):
# TODO: use histogram, and sampling

assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo"

if prefill_tokens:
self.prefill_timing[str(prefill_tokens)].append(duration_ms)
else:
self.decode_timing[str(decode_tokens)].append(duration_ms)

def get_estimate(self, prefill_tokens, decode_tokens):
# TODO: sample

assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo"

if prefill_tokens:
return self.prefill_timing[str(prefill_tokens)][0]
else:
return self.decode_timing[str(decode_tokens)][0]

def save(self, path):
with open(path, "w") as f:
json.dump(
{
"prefill_timing": self.prefill_timing,
"decode_timing": self.decode_timing
}, f)

@classmethod
def load(cls, path):
o = cls()
with open(path, "r") as f:
data = json.load(f)
o.prefill_timing = data["prefill_timing"]
o.decode_timing = data["decode_timing"]
return o
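
A short round-trip sketch of the SimulationProfile API above, using only the methods defined in this file; the timings and the file path are made-up illustration values.

profile = SimulationProfile()

# Record a few prefill-only and decode-only step timings (milliseconds).
profile.record(prefill_tokens=128, decode_tokens=0, duration_ms=4.2)
profile.record(prefill_tokens=128, decode_tokens=0, duration_ms=4.5)
profile.record(prefill_tokens=0, decode_tokens=64, duration_ms=1.8)

# get_estimate currently returns the first recorded sample for a token count.
print(profile.get_estimate(prefill_tokens=128, decode_tokens=0))  # 4.2
print(profile.get_estimate(prefill_tokens=0, decode_tokens=64))   # 1.8

# Persist and restore (path is illustrative).
profile.save("/tmp/sim_profile.json")
restored = SimulationProfile.load("/tmp/sim_profile.json")
print(restored.prefill_timing["128"])  # [4.2, 4.5]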