From 9391b412af3e55e2a78efc43a3ecca4bb0962f77 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Sat, 25 May 2024 04:28:29 +0000 Subject: [PATCH 1/6] simple simulation without interpolation --- examples/simulation_mode.py | 157 ++++++++++++++++++++++++++++ vllm/engine/arg_utils.py | 7 ++ vllm/engine/async_llm_engine.py | 5 +- vllm/engine/llm_engine.py | 5 +- vllm/executor/gpu_executor.py | 17 +++ vllm/executor/simulated_executor.py | 105 +++++++++++++++++++ 6 files changed, 294 insertions(+), 2 deletions(-) create mode 100644 examples/simulation_mode.py create mode 100644 vllm/executor/simulated_executor.py diff --git a/examples/simulation_mode.py b/examples/simulation_mode.py new file mode 100644 index 0000000000000..3f192b05b5e4e --- /dev/null +++ b/examples/simulation_mode.py @@ -0,0 +1,157 @@ +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +import simpy +from dataclasses import dataclass +import time +import json + + +class SimulationProfile: + """Profiling data structure capturing the timing of a single forward pass.""" + + def __init__(self): + self.prefill_timing = {} + self.decode_timing = {} + + def record(self, prefill_tokens, decode_tokens, duration_ms): + # TODO: use histogram, and sampling + + assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" + + if prefill_tokens: + self.prefill_timing[str(prefill_tokens)] = duration_ms + else: + self.decode_timing[str(decode_tokens)] = duration_ms + + def get_estimate(self, prefill_tokens, decode_tokens): + # TODO: sample + + assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" + + if prefill_tokens: + return self.prefill_timing[str(prefill_tokens)] + else: + return self.decode_timing[str(decode_tokens)] + + def save(self, path): + with open(path, "w") as f: + json.dump( + { + "prefill_timing": self.prefill_timing, + "decode_timing": self.decode_timing + }, f) + + @classmethod + def load(cls, path): + o = cls() + with open(path, "r") as f: + data = json.load(f) + o.prefill_timing = data["prefill_timing"] + o.decode_timing = data["decode_timing"] + return o + + +# workload characteristics +@dataclass +class WorkloadRequest: + arrival_time: float + num_input_tokens: int + num_output_tokens: int + + # TODO: add fields for prefix sharing + + +workloads = [ + WorkloadRequest(0.0, 10, 10), + WorkloadRequest(0.0, 20, 20), + WorkloadRequest(0.0, 10, 10), + WorkloadRequest(0.0, 20, 20), + WorkloadRequest(0.05, 30, 30), + WorkloadRequest(0.1, 40, 40), + WorkloadRequest(0.2, 50, 50), +] + +# SIMULATION_MODE = True +SIMULATION_MODE = False + +engine = LLMEngine.from_engine_args( + EngineArgs(model="facebook/opt-125m", simulation_mode=SIMULATION_MODE)) +# env = simpy.Environment() +env = simpy.rt.RealtimeEnvironment(factor=1, strict=True) + +if not SIMULATION_MODE: + # enable profiling + profile = SimulationProfile() + engine.model_executor.simulation_profile = profile +else: + profile = SimulationProfile.load("profile.json") + + engine.model_executor.simulation_profile = profile + engine.model_executor.env = env + + time.time = lambda: env.now + +generator_finished = False +enqueued = simpy.Store(env, capacity=1) + + +def request_generator(env): + curr_time = env.now + # assume workloads are sorted by arrival time + for i, workload in enumerate(workloads): + if env.now != workload.arrival_time: + yield env.timeout(workload.arrival_time - curr_time) + + engine.add_request( + request_id=str(i), + prompt=None, + prompt_token_ids=[0] * workload.num_input_tokens, + 
params=SamplingParams(max_tokens=workload.num_output_tokens, + ignore_eos=True), + ) + + if len(enqueued.items) == 0: + # notify the engine that there is a new request + enqueued.put(i) + + global generator_finished + generator_finished = True + if len(enqueued.items) == 0: + # notify the engine that there is a new request + enqueued.put(i) + + +def engine_runner(env): + start_time = time.time() + while not generator_finished or engine.has_unfinished_requests(): + start = time.time() + outputs = engine.step() + print("---") + for output in outputs: + output_metrics = { + "arrival_time": output.metrics.arrival_time - start_time, + "last_token_time": output.metrics.last_token_time - start_time, + "first_scheduled_time": + output.metrics.first_scheduled_time - start_time, + "first_token_time": + output.metrics.first_token_time - start_time + } + print(output.request_id, output_metrics) + print("---") + end = time.time() + # TODO: use proper synchronization + passed = end - start + print(passed) + if not SIMULATION_MODE: + yield env.timeout(passed) + else: + if not engine.has_unfinished_requests(): + yield enqueued.get() + yield env.timeout(0.0006) # fixed scheduler overhead + + +env.process(request_generator(env)) +env.process(engine_runner(env)) +env.run() + +if not SIMULATION_MODE: + profile.save("profile.json") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 538e3427e37fb..4153a9514c9d2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -91,6 +91,9 @@ class EngineArgs: ngram_prompt_lookup_max: Optional[int] = None ngram_prompt_lookup_min: Optional[int] = None + # Simulation mode + simulation_mode: bool = False + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @@ -539,6 +542,10 @@ def add_cli_args( "prometheus metrics, if multiple names provided, metrics" "tag will take the first one.") + parser.add_argument('--simulation-mode', + action='store_true', + help='Enable simulation mode for the engine.') + return parser @classmethod diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a15ed67e3327..4b3f60587800b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -359,7 +359,10 @@ def from_engine_args( distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) - if engine_config.device_config.device_type == "neuron": + if engine_args.simulation_mode: + from vllm.executor.simulated_executor import SimulatedExecutorAsync + executor_class = SimulatedExecutorAsync + elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync elif engine_config.device_config.device_type == "cpu": diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0631c0de76822..3418e5880262c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -280,7 +280,10 @@ def from_engine_args( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. 
- if engine_config.device_config.device_type == "neuron": + if engine_args.simulation_mode: + from vllm.executor.simulated_executor import SimulatedExecutor + executor_class = SimulatedExecutor + elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor elif engine_config.device_config.device_type == "cpu": diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3ad201f4757ec..03b3acc9017f2 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,3 +1,4 @@ +import time from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase @@ -88,7 +89,23 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest ) -> List[Union[SamplerOutput, PoolerOutput]]: + prefill_tokens = sum( + seq_group.token_chunk_size + for seq_group in execute_model_req.seq_group_metadata_list + if seq_group.is_prompt) + decode_tokens = sum( + seq_group.token_chunk_size + for seq_group in execute_model_req.seq_group_metadata_list + if not seq_group.is_prompt) + + start = time.perf_counter_ns() output = self.driver_worker.execute_model(execute_model_req) + duration = (time.perf_counter_ns() - start) / 1e6 + + print( + f"prefill_tokens: {prefill_tokens}, decode_tokens: {decode_tokens}, duration_ms: {duration}" + ) + self.simulation_profile.record(prefill_tokens, decode_tokens, duration) return output def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/executor/simulated_executor.py b/vllm/executor/simulated_executor.py new file mode 100644 index 0000000000000..81b9c6e87db28 --- /dev/null +++ b/vllm/executor/simulated_executor.py @@ -0,0 +1,105 @@ +from typing import List, Set, Tuple, Optional + +import simpy + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput, Logprob +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput + +logger = init_logger(__name__) + + +class SimulatedExecutor(ExecutorBase): + model_config: ModelConfig + cache_config: CacheConfig + parallel_config: ParallelConfig + scheduler_config: SchedulerConfig + device_config: DeviceConfig + load_config: LoadConfig + lora_config: Optional[LoRAConfig] + vision_language_config: Optional[VisionLanguageConfig] + speculative_config: Optional[SpeculativeConfig] + + env = simpy.Environment + + def _init_executor(self) -> None: + pass + + def _init_worker(self): + pass + + def determine_num_available_blocks(self) -> Tuple[int, int]: + # TODO: make it realistic + return [int(1e5), int(1e5)] + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. 
+ """ + pass + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + prefill_tokens = sum( + seq_group.token_chunk_size + for seq_group in execute_model_req.seq_group_metadata_list + if seq_group.is_prompt) + decode_tokens = sum( + seq_group.token_chunk_size + for seq_group in execute_model_req.seq_group_metadata_list + if not seq_group.is_prompt) + + out = [] + ids = [] + for seq_group in execute_model_req.seq_group_metadata_list: + ids.append(seq_group.request_id) + + for seq_id, _data in seq_group.seq_data.items(): + out.append( + CompletionSequenceGroupOutput( + samples=[ + SequenceOutput( + parent_seq_id=seq_id, + output_token=0, + logprobs={ + 0: Logprob(logprob=1), + }, + ), + ], + prompt_logprobs=None, + )) + # print("processed requests: ", ids) + duration_ms = self.simulation_profile.get_estimate( + prefill_tokens, decode_tokens) + self.env.run(self.env.now + duration_ms / 1e3) + return [SamplerOutput(outputs=out)] + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError() + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError() + + def list_loras(self) -> Set[int]: + raise NotImplementedError() + + def check_health(self) -> None: + return + + +class SimulatedExecutorAsync(SimulatedExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + pass + + async def check_health_async(self) -> None: + return From 7175a28cc2537aa3e20791191fc1315889c5a85f Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 25 Jun 2024 18:53:29 +0000 Subject: [PATCH 2/6] wip refactoring --- examples/simulation_mode.py | 149 +----------------------- vllm/entrypoints/simulator/__init__.py | 4 + vllm/entrypoints/simulator/loadgen.py | 10 ++ vllm/entrypoints/simulator/profile.py | 46 ++++++++ vllm/entrypoints/simulator/simulator.py | 96 +++++++++++++++ 5 files changed, 159 insertions(+), 146 deletions(-) create mode 100644 vllm/entrypoints/simulator/__init__.py create mode 100644 vllm/entrypoints/simulator/loadgen.py create mode 100644 vllm/entrypoints/simulator/profile.py create mode 100644 vllm/entrypoints/simulator/simulator.py diff --git a/examples/simulation_mode.py b/examples/simulation_mode.py index 3f192b05b5e4e..e1ffe58ab7359 100644 --- a/examples/simulation_mode.py +++ b/examples/simulation_mode.py @@ -1,64 +1,4 @@ -from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams -import simpy -from dataclasses import dataclass -import time -import json - - -class SimulationProfile: - """Profiling data structure capturing the timing of a single forward pass.""" - - def __init__(self): - self.prefill_timing = {} - self.decode_timing = {} - - def record(self, prefill_tokens, decode_tokens, duration_ms): - # TODO: use histogram, and sampling - - assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" - - if prefill_tokens: - self.prefill_timing[str(prefill_tokens)] = duration_ms - else: - self.decode_timing[str(decode_tokens)] = duration_ms - - def get_estimate(self, prefill_tokens, decode_tokens): - # TODO: sample - - assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" - - if prefill_tokens: - return self.prefill_timing[str(prefill_tokens)] - else: - return self.decode_timing[str(decode_tokens)] - - def save(self, path): - with open(path, "w") as f: - json.dump( - { - "prefill_timing": self.prefill_timing, - "decode_timing": self.decode_timing - }, f) - - @classmethod - def load(cls, 
path): - o = cls() - with open(path, "r") as f: - data = json.load(f) - o.prefill_timing = data["prefill_timing"] - o.decode_timing = data["decode_timing"] - return o - - -# workload characteristics -@dataclass -class WorkloadRequest: - arrival_time: float - num_input_tokens: int - num_output_tokens: int - - # TODO: add fields for prefix sharing - +from vllm.entrypoints.simulator import vLLMSimulator, WorkloadRequest workloads = [ WorkloadRequest(0.0, 10, 10), @@ -70,88 +10,5 @@ class WorkloadRequest: WorkloadRequest(0.2, 50, 50), ] -# SIMULATION_MODE = True -SIMULATION_MODE = False - -engine = LLMEngine.from_engine_args( - EngineArgs(model="facebook/opt-125m", simulation_mode=SIMULATION_MODE)) -# env = simpy.Environment() -env = simpy.rt.RealtimeEnvironment(factor=1, strict=True) - -if not SIMULATION_MODE: - # enable profiling - profile = SimulationProfile() - engine.model_executor.simulation_profile = profile -else: - profile = SimulationProfile.load("profile.json") - - engine.model_executor.simulation_profile = profile - engine.model_executor.env = env - - time.time = lambda: env.now - -generator_finished = False -enqueued = simpy.Store(env, capacity=1) - - -def request_generator(env): - curr_time = env.now - # assume workloads are sorted by arrival time - for i, workload in enumerate(workloads): - if env.now != workload.arrival_time: - yield env.timeout(workload.arrival_time - curr_time) - - engine.add_request( - request_id=str(i), - prompt=None, - prompt_token_ids=[0] * workload.num_input_tokens, - params=SamplingParams(max_tokens=workload.num_output_tokens, - ignore_eos=True), - ) - - if len(enqueued.items) == 0: - # notify the engine that there is a new request - enqueued.put(i) - - global generator_finished - generator_finished = True - if len(enqueued.items) == 0: - # notify the engine that there is a new request - enqueued.put(i) - - -def engine_runner(env): - start_time = time.time() - while not generator_finished or engine.has_unfinished_requests(): - start = time.time() - outputs = engine.step() - print("---") - for output in outputs: - output_metrics = { - "arrival_time": output.metrics.arrival_time - start_time, - "last_token_time": output.metrics.last_token_time - start_time, - "first_scheduled_time": - output.metrics.first_scheduled_time - start_time, - "first_token_time": - output.metrics.first_token_time - start_time - } - print(output.request_id, output_metrics) - print("---") - end = time.time() - # TODO: use proper synchronization - passed = end - start - print(passed) - if not SIMULATION_MODE: - yield env.timeout(passed) - else: - if not engine.has_unfinished_requests(): - yield enqueued.get() - yield env.timeout(0.0006) # fixed scheduler overhead - - -env.process(request_generator(env)) -env.process(engine_runner(env)) -env.run() - -if not SIMULATION_MODE: - profile.save("profile.json") +engine = vLLMSimulator(model="facebook/opt-125m") +engine.profile(workloads) diff --git a/vllm/entrypoints/simulator/__init__.py b/vllm/entrypoints/simulator/__init__.py new file mode 100644 index 0000000000000..2d4439780a070 --- /dev/null +++ b/vllm/entrypoints/simulator/__init__.py @@ -0,0 +1,4 @@ +from vllm.entrypoints.simulator.simulator import vLLMSimulator +from vllm.entrypoints.simulator.loadgen import WorkloadRequest + +__all__ = ["vLLMSimulator", "WorkloadRequest"] diff --git a/vllm/entrypoints/simulator/loadgen.py b/vllm/entrypoints/simulator/loadgen.py new file mode 100644 index 0000000000000..2aa361cea27ba --- /dev/null +++ b/vllm/entrypoints/simulator/loadgen.py @@ -0,0 
+1,10 @@ +from dataclasses import dataclass + + +@dataclass +class WorkloadRequest: + arrival_time: float + num_input_tokens: int + num_output_tokens: int + + # TODO: add fields for prefix sharing diff --git a/vllm/entrypoints/simulator/profile.py b/vllm/entrypoints/simulator/profile.py new file mode 100644 index 0000000000000..6ba72accf606f --- /dev/null +++ b/vllm/entrypoints/simulator/profile.py @@ -0,0 +1,46 @@ +import json + + +class SimulationProfile: + """Profiling data structure capturing the timing of a single forward pass.""" + + def __init__(self): + self.prefill_timing = {} + self.decode_timing = {} + + def record(self, prefill_tokens, decode_tokens, duration_ms): + # TODO: use histogram, and sampling + + assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" + + if prefill_tokens: + self.prefill_timing[str(prefill_tokens)] = duration_ms + else: + self.decode_timing[str(decode_tokens)] = duration_ms + + def get_estimate(self, prefill_tokens, decode_tokens): + # TODO: sample + + assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" + + if prefill_tokens: + return self.prefill_timing[str(prefill_tokens)] + else: + return self.decode_timing[str(decode_tokens)] + + def save(self, path): + with open(path, "w") as f: + json.dump( + { + "prefill_timing": self.prefill_timing, + "decode_timing": self.decode_timing + }, f) + + @classmethod + def load(cls, path): + o = cls() + with open(path, "r") as f: + data = json.load(f) + o.prefill_timing = data["prefill_timing"] + o.decode_timing = data["decode_timing"] + return o diff --git a/vllm/entrypoints/simulator/simulator.py b/vllm/entrypoints/simulator/simulator.py new file mode 100644 index 0000000000000..52f4fdea250cf --- /dev/null +++ b/vllm/entrypoints/simulator/simulator.py @@ -0,0 +1,96 @@ +from dataclasses import dataclass +import time +import json + +import simpy + +from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams +from vllm.entrypoints.simulator.profile import SimulationProfile + + +class vLLMSimulator: + + def __init__(self, model="facebook/opt-125m") -> None: + self.profile_engine = LLMEngine.from_engine_args( + EngineArgs(model=model)) + self.simulated_engine = LLMEngine.from_engine_args( + EngineArgs(model=model, simulation_mode=True)) + + self._reset() + + def _reset(self): + self.env = simpy.rt.RealtimeEnvironment(factor=1, strict=True) + self.generator_finished = False + self.enqueued = simpy.Store(self.env, capacity=1) + + def _process_run_engine(self, is_profile): + start_time = time.time() + while not self.generator_finished or self.engine.has_unfinished_requests( + ): + start = time.time() + outputs = self.engine.step() + print("---") + for output in outputs: + output_metrics = { + "arrival_time": + output.metrics.arrival_time - start_time, + "last_token_time": + output.metrics.last_token_time - start_time, + "first_scheduled_time": + output.metrics.first_scheduled_time - start_time, + "first_token_time": + output.metrics.first_token_time - start_time + } + print(output.request_id, output_metrics) + print("---") + end = time.time() + # TODO: use proper synchronization + passed = end - start + print(passed) + if is_profile: + yield self.env.timeout(passed) + else: + if not self.engine.has_unfinished_requests(): + yield self.enqueued.get() + yield self.env.timeout(0.0006) # fixed scheduler overhead + + def _process_request_generator(self, workloads): + curr_time = self.env.now + # assume workloads are sorted by arrival time + for i, workload in enumerate(workloads): + 
if self.env.now != workload.arrival_time: + assert self.env.now < workload.arrival_time + yield self.env.timeout(workload.arrival_time - curr_time) + + self.engine.add_request( + request_id=str(i), + prompt=None, + prompt_token_ids=[0] * workload.num_input_tokens, + params=SamplingParams(max_tokens=workload.num_output_tokens, + ignore_eos=True), + ) + + if len(self.enqueued.items) == 0: + # notify the engine that there is a new request + self.enqueued.put(i) + + global generator_finished + generator_finished = True + if len(self.enqueued.items) == 0: + # notify the engine that there is a new request + self.enqueued.put(i) + + def profile(self, workloads): + profile = SimulationProfile() + self.profile_engine.model_executor.simulation_profile = profile + + self.env.process(self._process_request_generator(workloads)) + self.env.process(self._process_run_engine(is_profile=True)) + self.env.run() + + profile.save("profile.json") + + +# if __name__ == "__main__": +# from vllm.utils import FlexibleArgumentParser +# parser = FlexibleArgumentParser() From b797102f597440b3a54c7dbb44c136d3e15d45df Mon Sep 17 00:00:00 2001 From: simon-mo Date: Mon, 1 Jul 2024 22:12:52 -0700 Subject: [PATCH 3/6] working sim --- vllm/entrypoints/simulator/simulator.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/simulator/simulator.py b/vllm/entrypoints/simulator/simulator.py index 52f4fdea250cf..d4f19e3d376a4 100644 --- a/vllm/entrypoints/simulator/simulator.py +++ b/vllm/entrypoints/simulator/simulator.py @@ -23,12 +23,12 @@ def _reset(self): self.generator_finished = False self.enqueued = simpy.Store(self.env, capacity=1) - def _process_run_engine(self, is_profile): + def _process_run_engine(self, engine, is_profile): start_time = time.time() while not self.generator_finished or self.engine.has_unfinished_requests( ): start = time.time() - outputs = self.engine.step() + outputs = engine.step() print("---") for output in outputs: output_metrics = { @@ -54,7 +54,7 @@ def _process_run_engine(self, is_profile): yield self.enqueued.get() yield self.env.timeout(0.0006) # fixed scheduler overhead - def _process_request_generator(self, workloads): + def _process_request_generator(self, engine, workloads): curr_time = self.env.now # assume workloads are sorted by arrival time for i, workload in enumerate(workloads): @@ -62,10 +62,9 @@ def _process_request_generator(self, workloads): assert self.env.now < workload.arrival_time yield self.env.timeout(workload.arrival_time - curr_time) - self.engine.add_request( + engine.add_request( request_id=str(i), - prompt=None, - prompt_token_ids=[0] * workload.num_input_tokens, + inputs=dict(prompt_token_ids=[0] * workload.num_input_tokens), params=SamplingParams(max_tokens=workload.num_output_tokens, ignore_eos=True), ) @@ -84,8 +83,8 @@ def profile(self, workloads): profile = SimulationProfile() self.profile_engine.model_executor.simulation_profile = profile - self.env.process(self._process_request_generator(workloads)) - self.env.process(self._process_run_engine(is_profile=True)) + self.env.process(self._process_request_generator(self.profile_engine, workloads)) + self.env.process(self._process_run_engine(self.profile_engine, is_profile=True)) self.env.run() profile.save("profile.json") From 19da9e16342bac524940b4db68579524397a17f7 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 1 Aug 2024 17:11:36 -0700 Subject: [PATCH 4/6] fix merge --- vllm/config.py | 1 + vllm/engine/arg_utils.py | 1 + vllm/engine/llm_engine.py | 10 
+++++----- vllm/entrypoints/simulator/simulator.py | 7 +++++-- vllm/executor/simulated_executor.py | 26 ++++++++++++++++++++----- 5 files changed, 33 insertions(+), 12 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index ef56e2b6395be..da78f723ea472 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1613,6 +1613,7 @@ class EngineConfig: decoding_config: Optional[DecodingConfig] observability_config: Optional[ObservabilityConfig] prompt_adapter_config: Optional[PromptAdapterConfig] + simulation_mode: bool = False def __post_init__(self): """Verify configs are valid & consistent with each other. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ac7860b856da4..5091c7ca0a222 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -881,6 +881,7 @@ def create_engine_config(self, ) -> EngineConfig: decoding_config=decoding_config, observability_config=observability_config, prompt_adapter_config=prompt_adapter_config, + simulation_mode=self.simulation_mode, ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1d3fccb53311f..1d1253afca617 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -171,6 +171,7 @@ def __init__( log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + simulation_mode: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -352,6 +353,8 @@ def __init__( self.get_tokenizer_for_seq, ), )) + + self.simulation_mode = simulation_mode def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -381,12 +384,10 @@ def _get_executor_cls(cls, distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) # Initialize the cluster and specify the executor class. 
-<<<<<<< HEAD - if engine_args.simulation_mode: + if engine_config.simulation_mode: from vllm.executor.simulated_executor import SimulatedExecutor executor_class = SimulatedExecutor -======= - if isinstance(distributed_executor_backend, type): + elif isinstance(distributed_executor_backend, type): if not issubclass(distributed_executor_backend, ExecutorBase): raise TypeError( "distributed_executor_backend must be a subclass of " @@ -394,7 +395,6 @@ def _get_executor_cls(cls, if distributed_executor_backend.uses_ray: # type: ignore initialize_ray_cluster(engine_config.parallel_config) executor_class = distributed_executor_backend ->>>>>>> 6a11fdfbb8d6701c7ad38648aead23d8cbe6aac5 elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor diff --git a/vllm/entrypoints/simulator/simulator.py b/vllm/entrypoints/simulator/simulator.py index d4f19e3d376a4..87b1ade538e93 100644 --- a/vllm/entrypoints/simulator/simulator.py +++ b/vllm/entrypoints/simulator/simulator.py @@ -29,6 +29,8 @@ def _process_run_engine(self, engine, is_profile): ): start = time.time() outputs = engine.step() + if len(outputs) == 0: + continue print("---") for output in outputs: output_metrics = { @@ -73,8 +75,7 @@ def _process_request_generator(self, engine, workloads): # notify the engine that there is a new request self.enqueued.put(i) - global generator_finished - generator_finished = True + self.generator_finished = True if len(self.enqueued.items) == 0: # notify the engine that there is a new request self.enqueued.put(i) @@ -88,6 +89,8 @@ def profile(self, workloads): self.env.run() profile.save("profile.json") + + # if __name__ == "__main__": diff --git a/vllm/executor/simulated_executor.py b/vllm/executor/simulated_executor.py index 81b9c6e87db28..c4e76f452ef53 100644 --- a/vllm/executor/simulated_executor.py +++ b/vllm/executor/simulated_executor.py @@ -6,9 +6,10 @@ from vllm.sequence import CompletionSequenceGroupOutput, SequenceOutput, Logprob from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + SpeculativeConfig) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput logger = init_logger(__name__) @@ -22,7 +23,6 @@ class SimulatedExecutor(ExecutorBase): device_config: DeviceConfig load_config: LoadConfig lora_config: Optional[LoRAConfig] - vision_language_config: Optional[VisionLanguageConfig] speculative_config: Optional[SpeculativeConfig] env = simpy.Environment @@ -81,13 +81,29 @@ def execute_model( return [SamplerOutput(outputs=out)] def add_lora(self, lora_request: LoRARequest) -> bool: - raise NotImplementedError() + raise NotImplementedError def remove_lora(self, lora_id: int) -> bool: - raise NotImplementedError() + raise NotImplementedError + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError # type: ignore def list_loras(self) -> Set[int]: - raise NotImplementedError() + raise NotImplementedError + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError # type: ignore + + def 
list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError def check_health(self) -> None: return From 352163d46f7b444d84d76b45c4892fe30ff0c6c6 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Fri, 2 Aug 2024 11:36:19 -0700 Subject: [PATCH 5/6] add 8b sim mode --- examples/simulation_mode.py | 77 +++++++++++++++++++++---- vllm/entrypoints/simulator/profile.py | 13 +++-- vllm/entrypoints/simulator/simulator.py | 42 +++++++++++++- 3 files changed, 113 insertions(+), 19 deletions(-) diff --git a/examples/simulation_mode.py b/examples/simulation_mode.py index e1ffe58ab7359..1b098cdb39184 100644 --- a/examples/simulation_mode.py +++ b/examples/simulation_mode.py @@ -1,14 +1,67 @@ from vllm.entrypoints.simulator import vLLMSimulator, WorkloadRequest -workloads = [ - WorkloadRequest(0.0, 10, 10), - WorkloadRequest(0.0, 20, 20), - WorkloadRequest(0.0, 10, 10), - WorkloadRequest(0.0, 20, 20), - WorkloadRequest(0.05, 30, 30), - WorkloadRequest(0.1, 40, 40), - WorkloadRequest(0.2, 50, 50), -] - -engine = vLLMSimulator(model="facebook/opt-125m") -engine.profile(workloads) +# workloads = [ +# WorkloadRequest(0.0, 10, 10), +# WorkloadRequest(0.0, 20, 20), +# WorkloadRequest(0.0, 10, 10), +# WorkloadRequest(0.0, 20, 20), +# WorkloadRequest(0.05, 30, 30), +# WorkloadRequest(0.1, 40, 40), +# WorkloadRequest(0.2, 50, 50), +# ] + +# engine = vLLMSimulator(model="facebook/opt-125m") +# engine.profile(workloads) + +# Profile workload batch with prefill 2->2048 tokens and decode 1->256 tokens, using power of 2 ranges, but with one mid point value in between. +# For decode workload, sample the input tokens to 64. + +import numpy as np +prefill_sizes = np.logspace(np.log2(2), np.log2(8192), num=64, base=2).round().astype(int) +decode_sizes = np.logspace(np.log2(2), np.log2(512), num=64, base=2).round().astype(int) + +workload_batch = [] +# first measure prefill +for i in prefill_sizes: + workload_batch.append([WorkloadRequest(0, i, 0)]) +# then measure decode +for i in decode_sizes: + # TODO: check ctx size > 1, and how it varies (it should not) + # workload_batch.append([WorkloadRequest(0, 1, 2) for _ in range(i)]) + workload_batch.append([WorkloadRequest(0, 1, 2) for _ in range(i)]) + +print(prefill_sizes) +print(decode_sizes) + +# workload_batch = [ +# [WorkloadRequest(0, 100, 0), WorkloadRequest(0, 100, 0)], # 200 prefil tokens +# [WorkloadRequest(0, 200, 4), WorkloadRequest(0, 200, 4)], # 400 prefill tokens, 2 decode token per iter +# ] + +# engine = vLLMSimulator(model="facebook/opt-125m") +engine = vLLMSimulator(model="meta-llama/Meta-Llama-3-8B-Instruct") +profile = engine.profile_tokens_curve(workload_batch, n_trials=5) + +# print(profile.prefill_timing) # this is a defaultdict(list) +# print(profile.decode_timing) # this is a defaultdict(list) + +# turn this into a pandas dataframe + +import pandas as pd + +data = [] +for k, v in profile.prefill_timing.items(): + for i in v: + data.append({"size": k, "time_ms": i, "op": "prefill"}) + +for k, v in profile.decode_timing.items(): + for i in v: + data.append({"size": k, "time_ms": i, "op": "decode"}) + + +df = pd.DataFrame(data) + +import sys +print("----") +df.to_csv(sys.stdout, index=False) +df.to_csv("profile.csv", index=False) \ No newline at end of file diff --git a/vllm/entrypoints/simulator/profile.py b/vllm/entrypoints/simulator/profile.py index 6ba72accf606f..ad0dbe42cc7d9 100644 --- a/vllm/entrypoints/simulator/profile.py +++ b/vllm/entrypoints/simulator/profile.py @@ -1,12 +1,13 @@ import json +from collections import defaultdict class 
SimulationProfile: """Profiling data structure capturing the timing of a single forward pass.""" def __init__(self): - self.prefill_timing = {} - self.decode_timing = {} + self.prefill_timing = defaultdict(list) + self.decode_timing = defaultdict(list) def record(self, prefill_tokens, decode_tokens, duration_ms): # TODO: use histogram, and sampling @@ -14,9 +15,9 @@ def record(self, prefill_tokens, decode_tokens, duration_ms): assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" if prefill_tokens: - self.prefill_timing[str(prefill_tokens)] = duration_ms + self.prefill_timing[str(prefill_tokens)].append(duration_ms) else: - self.decode_timing[str(decode_tokens)] = duration_ms + self.decode_timing[str(decode_tokens)].append(duration_ms) def get_estimate(self, prefill_tokens, decode_tokens): # TODO: sample @@ -24,9 +25,9 @@ def get_estimate(self, prefill_tokens, decode_tokens): assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo" if prefill_tokens: - return self.prefill_timing[str(prefill_tokens)] + return self.prefill_timing[str(prefill_tokens)][0] else: - return self.decode_timing[str(decode_tokens)] + return self.decode_timing[str(decode_tokens)][0] def save(self, path): with open(path, "w") as f: diff --git a/vllm/entrypoints/simulator/simulator.py b/vllm/entrypoints/simulator/simulator.py index 87b1ade538e93..d09fdb25cbf43 100644 --- a/vllm/entrypoints/simulator/simulator.py +++ b/vllm/entrypoints/simulator/simulator.py @@ -12,7 +12,7 @@ class vLLMSimulator: def __init__(self, model="facebook/opt-125m") -> None: self.profile_engine = LLMEngine.from_engine_args( - EngineArgs(model=model)) + EngineArgs(model=model, enforce_eager=True)) self.simulated_engine = LLMEngine.from_engine_args( EngineArgs(model=model, simulation_mode=True)) @@ -91,6 +91,46 @@ def profile(self, workloads): profile.save("profile.json") + def profile_tokens_curve(self, workload_batch, n_trials=1): + profile = SimulationProfile() + self.profile_engine.model_executor.simulation_profile = profile + idx = -1 + + # warmup + for batch in workload_batch: + for workload in batch: + idx += 1 + self.profile_engine.add_request( + request_id=str(idx), + inputs=dict(prompt_token_ids=[0] * workload.num_input_tokens), + params=SamplingParams(max_tokens=max(workload.num_output_tokens, 1), + ignore_eos=True), + ) + while self.profile_engine.has_unfinished_requests(): + self.profile_engine.step() + + # real run + + profile = SimulationProfile() + self.profile_engine.model_executor.simulation_profile = profile + + + for batch in workload_batch: + for _ in range(n_trials): + for workload in batch: + idx += 1 + self.profile_engine.add_request( + request_id=str(idx), + inputs=dict(prompt_token_ids=[0] * workload.num_input_tokens), + params=SamplingParams(max_tokens=max(workload.num_output_tokens, 1), + ignore_eos=True), + ) + while self.profile_engine.has_unfinished_requests(): + self.profile_engine.step() + + return profile + + # if __name__ == "__main__": From a98b71cf26b65597e2ff9f47de4e10715390f2b8 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Mon, 5 Aug 2024 10:49:21 -0700 Subject: [PATCH 6/6] profile 70b --- examples/simulation_mode.py | 17 ++++++++----- vllm/entrypoints/simulator/profile.py | 29 +++++++++++++++++++++++ vllm/entrypoints/simulator/simulator.py | 23 +++++++++++++----- vllm/executor/distributed_gpu_executor.py | 5 +++- vllm/executor/gpu_executor.py | 21 ++++------------ 5 files changed, 65 insertions(+), 30 deletions(-) diff --git a/examples/simulation_mode.py 
b/examples/simulation_mode.py index 1b098cdb39184..71372fe0a62f5 100644 --- a/examples/simulation_mode.py +++ b/examples/simulation_mode.py @@ -17,8 +17,11 @@ # For decode workload, sample the input tokens to 64. import numpy as np -prefill_sizes = np.logspace(np.log2(2), np.log2(8192), num=64, base=2).round().astype(int) -decode_sizes = np.logspace(np.log2(2), np.log2(512), num=64, base=2).round().astype(int) + +prefill_sizes = np.logspace(np.log2(2), np.log2(8192), num=64, + base=2).round().astype(int) +decode_sizes = np.logspace(np.log2(2), np.log2(512), num=64, + base=2).round().astype(int) workload_batch = [] # first measure prefill @@ -39,8 +42,10 @@ # ] # engine = vLLMSimulator(model="facebook/opt-125m") -engine = vLLMSimulator(model="meta-llama/Meta-Llama-3-8B-Instruct") -profile = engine.profile_tokens_curve(workload_batch, n_trials=5) +# engine = vLLMSimulator(model="meta-llama/Meta-Llama-3-8B-Instruct") +engine = vLLMSimulator(model="meta-llama/Meta-Llama-3-70B-Instruct", + tensor_parallel_size=4) +profile = engine.profile_tokens_curve(workload_batch, n_trials=3) # print(profile.prefill_timing) # this is a defaultdict(list) # print(profile.decode_timing) # this is a defaultdict(list) @@ -58,10 +63,10 @@ for i in v: data.append({"size": k, "time_ms": i, "op": "decode"}) - df = pd.DataFrame(data) import sys + print("----") df.to_csv(sys.stdout, index=False) -df.to_csv("profile.csv", index=False) \ No newline at end of file +df.to_csv("profile-70b.csv", index=False) diff --git a/vllm/entrypoints/simulator/profile.py b/vllm/entrypoints/simulator/profile.py index ad0dbe42cc7d9..9e75711267f37 100644 --- a/vllm/entrypoints/simulator/profile.py +++ b/vllm/entrypoints/simulator/profile.py @@ -1,5 +1,30 @@ import json from collections import defaultdict +from contextlib import contextmanager +import time + + +@contextmanager +def profile_hook(execute_model_req, profile): + prefill_tokens = sum( + seq_group.token_chunk_size + for seq_group in execute_model_req.seq_group_metadata_list + if seq_group.is_prompt) + decode_tokens = sum( + seq_group.token_chunk_size + for seq_group in execute_model_req.seq_group_metadata_list + if not seq_group.is_prompt) + + start = time.perf_counter_ns() + + yield + + duration = (time.perf_counter_ns() - start) / 1e6 + + print( + f"prefill_tokens: {prefill_tokens}, decode_tokens: {decode_tokens}, duration_ms: {duration}" + ) + profile.record(prefill_tokens, decode_tokens, duration) class SimulationProfile: @@ -9,6 +34,10 @@ def __init__(self): self.prefill_timing = defaultdict(list) self.decode_timing = defaultdict(list) + def clear(self): + self.prefill_timing.clear() + self.decode_timing.clear() + def record(self, prefill_tokens, decode_tokens, duration_ms): # TODO: use histogram, and sampling diff --git a/vllm/entrypoints/simulator/simulator.py b/vllm/entrypoints/simulator/simulator.py index d09fdb25cbf43..31674e7f33ffe 100644 --- a/vllm/entrypoints/simulator/simulator.py +++ b/vllm/entrypoints/simulator/simulator.py @@ -10,11 +10,11 @@ class vLLMSimulator: - def __init__(self, model="facebook/opt-125m") -> None: + def __init__(self, model="facebook/opt-125m", **engine_args) -> None: self.profile_engine = LLMEngine.from_engine_args( - EngineArgs(model=model, enforce_eager=True)) + EngineArgs(model=model, enforce_eager=True, **engine_args)) self.simulated_engine = LLMEngine.from_engine_args( - EngineArgs(model=model, simulation_mode=True)) + EngineArgs(model=model, simulation_mode=True, **engine_args)) self._reset() @@ -98,7 +98,11 @@ def 
profile_tokens_curve(self, workload_batch, n_trials=1): # warmup for batch in workload_batch: + num_prefill_toks = 0 + num_decode_toks = 0 for workload in batch: + num_prefill_toks += workload.num_input_tokens + num_decode_toks += workload.num_output_tokens idx += 1 self.profile_engine.add_request( request_id=str(idx), @@ -106,18 +110,22 @@ def profile_tokens_curve(self, workload_batch, n_trials=1): params=SamplingParams(max_tokens=max(workload.num_output_tokens, 1), ignore_eos=True), ) + print(f"Warmup {num_prefill_toks=} {num_decode_toks=}") while self.profile_engine.has_unfinished_requests(): self.profile_engine.step() # real run - profile = SimulationProfile() - self.profile_engine.model_executor.simulation_profile = profile + self.profile_engine.model_executor.simulation_profile.clear() for batch in workload_batch: - for _ in range(n_trials): + for trial_idx in range(n_trials): + num_prefill_toks = 0 + num_decode_toks = 0 for workload in batch: + num_prefill_toks += workload.num_input_tokens + num_decode_toks += workload.num_output_tokens idx += 1 self.profile_engine.add_request( request_id=str(idx), @@ -125,9 +133,12 @@ def profile_tokens_curve(self, workload_batch, n_trials=1): params=SamplingParams(max_tokens=max(workload.num_output_tokens, 1), ignore_eos=True), ) + print(f"Trial {trial_idx+1}/{n_trials} {num_prefill_toks=} {num_decode_toks=}") while self.profile_engine.has_unfinished_requests(): self.profile_engine.step() + print("Finished profiling") + return profile diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 4df54a09e5e8c..2472ef89581b9 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -7,6 +7,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.entrypoints.simulator.profile import profile_hook logger = init_logger(__name__) @@ -73,8 +74,10 @@ def execute_model( **self.extra_execute_model_run_workers_kwargs) # Only the driver worker returns the sampling results. 
- driver_outputs = self._driver_execute_model(execute_model_req) + with profile_hook(execute_model_req, self.simulation_profile): + driver_outputs = self._driver_execute_model(execute_model_req) assert driver_outputs is not None + return driver_outputs def stop_remote_worker_execution_loop(self) -> None: diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 13e51100a38e6..f6fd44c8420fb 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -10,6 +10,8 @@ make_async) from vllm.worker.worker_base import WorkerWrapperBase +from vllm.entrypoints.simulator.profile import profile_hook + logger = init_logger(__name__) @@ -108,23 +110,8 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: - prefill_tokens = sum( - seq_group.token_chunk_size - for seq_group in execute_model_req.seq_group_metadata_list - if seq_group.is_prompt) - decode_tokens = sum( - seq_group.token_chunk_size - for seq_group in execute_model_req.seq_group_metadata_list - if not seq_group.is_prompt) - - start = time.perf_counter_ns() - output = self.driver_worker.execute_model(execute_model_req) - duration = (time.perf_counter_ns() - start) / 1e6 - - print( - f"prefill_tokens: {prefill_tokens}, decode_tokens: {decode_tokens}, duration_ms: {duration}" - ) - self.simulation_profile.record(prefill_tokens, decode_tokens, duration) + with profile_hook(execute_model_req, self.simulation_profile): + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool:
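
Note on the profile lookup: `SimulationProfile.get_estimate` above only answers for token counts that were measured exactly, which is what the first commit message flags as "without interpolation". Below is a minimal, illustrative sketch of an interpolating lookup over the `profile.csv` emitted by `examples/simulation_mode.py` (columns `size`, `time_ms`, `op`). The `InterpolatedProfile` name and its use as a drop-in for `SimulationProfile.get_estimate` are assumptions for illustration only, not part of the patches above.

    # Illustrative sketch only (not part of the patch series): piecewise-linear
    # interpolation over the profiled timings, assuming a profile.csv with
    # columns size,time_ms,op as written by examples/simulation_mode.py.
    import numpy as np
    import pandas as pd


    class InterpolatedProfile:

        def __init__(self, csv_path="profile.csv"):
            df = pd.read_csv(csv_path)
            # Collapse repeated trials into a median duration per batch size.
            self._curves = {}
            for op in ("prefill", "decode"):
                grouped = (df[df["op"] == op]
                           .groupby("size")["time_ms"]
                           .median()
                           .sort_index())
                self._curves[op] = (grouped.index.to_numpy(dtype=float),
                                    grouped.to_numpy(dtype=float))

        def get_estimate(self, prefill_tokens, decode_tokens):
            # Mirrors SimulationProfile.get_estimate: exactly one of the two
            # token counts is non-zero (chunked prefill still TODO).
            assert not all([prefill_tokens, decode_tokens]), \
                "chunked prefill todo"
            op, tokens = (("prefill", prefill_tokens) if prefill_tokens else
                          ("decode", decode_tokens))
            sizes, times = self._curves[op]
            # np.interp clamps outside the measured range rather than
            # extrapolating beyond the largest profiled batch.
            return float(np.interp(tokens, sizes, times))

If wired into `SimulatedExecutor.execute_model` in place of the exact-match lookup, this would let the simulated clock advance for batch sizes between profiled points; batches beyond the largest profiled size would still need an explicit extrapolation model (e.g. a linear fit in token count), which the histogram/sampling TODOs in `SimulationProfile` leave open.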