wip refactoring
simon-mo committed Jun 25, 2024
1 parent 535b555 commit 7175a28
Showing 5 changed files with 159 additions and 146 deletions.
149 changes: 3 additions & 146 deletions examples/simulation_mode.py
@@ -1,64 +1,4 @@
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
import simpy
from dataclasses import dataclass
import time
import json


class SimulationProfile:
"""Profiling data structure capturing the timing of a single forward pass."""

def __init__(self):
self.prefill_timing = {}
self.decode_timing = {}

def record(self, prefill_tokens, decode_tokens, duration_ms):
# TODO: use histogram, and sampling

assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo"

if prefill_tokens:
self.prefill_timing[str(prefill_tokens)] = duration_ms
else:
self.decode_timing[str(decode_tokens)] = duration_ms

def get_estimate(self, prefill_tokens, decode_tokens):
# TODO: sample

assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo"

if prefill_tokens:
return self.prefill_timing[str(prefill_tokens)]
else:
return self.decode_timing[str(decode_tokens)]

def save(self, path):
with open(path, "w") as f:
json.dump(
{
"prefill_timing": self.prefill_timing,
"decode_timing": self.decode_timing
}, f)

@classmethod
def load(cls, path):
o = cls()
with open(path, "r") as f:
data = json.load(f)
o.prefill_timing = data["prefill_timing"]
o.decode_timing = data["decode_timing"]
return o


# workload characteristics
@dataclass
class WorkloadRequest:
arrival_time: float
num_input_tokens: int
num_output_tokens: int

# TODO: add fields for prefix sharing

from vllm.entrypoints.simulator import vLLMSimulator, WorkloadRequest

workloads = [
WorkloadRequest(0.0, 10, 10),
@@ -70,88 +10,5 @@ class WorkloadRequest:
WorkloadRequest(0.2, 50, 50),
]

# SIMULATION_MODE = True
SIMULATION_MODE = False

engine = LLMEngine.from_engine_args(
EngineArgs(model="facebook/opt-125m", simulation_mode=SIMULATION_MODE))
# env = simpy.Environment()
env = simpy.rt.RealtimeEnvironment(factor=1, strict=True)

if not SIMULATION_MODE:
# enable profiling
profile = SimulationProfile()
engine.model_executor.simulation_profile = profile
else:
profile = SimulationProfile.load("profile.json")

engine.model_executor.simulation_profile = profile
engine.model_executor.env = env

time.time = lambda: env.now

generator_finished = False
enqueued = simpy.Store(env, capacity=1)


def request_generator(env):
curr_time = env.now
# assume workloads are sorted by arrival time
for i, workload in enumerate(workloads):
if env.now != workload.arrival_time:
yield env.timeout(workload.arrival_time - curr_time)

engine.add_request(
request_id=str(i),
prompt=None,
prompt_token_ids=[0] * workload.num_input_tokens,
params=SamplingParams(max_tokens=workload.num_output_tokens,
ignore_eos=True),
)

if len(enqueued.items) == 0:
# notify the engine that there is a new request
enqueued.put(i)

global generator_finished
generator_finished = True
if len(enqueued.items) == 0:
# notify the engine that there is a new request
enqueued.put(i)


def engine_runner(env):
start_time = time.time()
while not generator_finished or engine.has_unfinished_requests():
start = time.time()
outputs = engine.step()
print("---")
for output in outputs:
output_metrics = {
"arrival_time": output.metrics.arrival_time - start_time,
"last_token_time": output.metrics.last_token_time - start_time,
"first_scheduled_time":
output.metrics.first_scheduled_time - start_time,
"first_token_time":
output.metrics.first_token_time - start_time
}
print(output.request_id, output_metrics)
print("---")
end = time.time()
# TODO: use proper synchronization
passed = end - start
print(passed)
if not SIMULATION_MODE:
yield env.timeout(passed)
else:
if not engine.has_unfinished_requests():
yield enqueued.get()
yield env.timeout(0.0006) # fixed scheduler overhead


env.process(request_generator(env))
env.process(engine_runner(env))
env.run()

if not SIMULATION_MODE:
profile.save("profile.json")
engine = vLLMSimulator(model="facebook/opt-125m")
engine.profile(workloads)
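
For reference, here is roughly what the refactored example looks like once the hunks above are stitched together; the middle workload entries are collapsed in the diff, so the trace below is illustrative:

from vllm.entrypoints.simulator import vLLMSimulator, WorkloadRequest

# WorkloadRequest(arrival_time_s, num_input_tokens, num_output_tokens)
workloads = [
    WorkloadRequest(0.0, 10, 10),
    WorkloadRequest(0.2, 50, 50),
]

engine = vLLMSimulator(model="facebook/opt-125m")
engine.profile(workloads)  # runs the real engine and writes profile.json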
4 changes: 4 additions & 0 deletions vllm/entrypoints/simulator/__init__.py
@@ -0,0 +1,4 @@
from vllm.entrypoints.simulator.simulator import vLLMSimulator
from vllm.entrypoints.simulator.loadgen import WorkloadRequest

__all__ = ["vLLMSimulator", "WorkloadRequest"]
10 changes: 10 additions & 0 deletions vllm/entrypoints/simulator/loadgen.py
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class WorkloadRequest:
arrival_time: float
num_input_tokens: int
num_output_tokens: int

# TODO: add fields for prefix sharing
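
Since WorkloadRequest only carries an absolute arrival time and token counts, synthetic traces can be generated programmatically. A minimal sketch; the fixed-interval helper below is an illustration, not part of this commit:

from vllm.entrypoints.simulator import WorkloadRequest


def uniform_trace(num_requests: int, interval_s: float,
                  input_tokens: int, output_tokens: int):
    # one request every interval_s seconds, all with the same shape;
    # the result is already sorted by arrival time, as the simulator assumes
    return [
        WorkloadRequest(arrival_time=i * interval_s,
                        num_input_tokens=input_tokens,
                        num_output_tokens=output_tokens)
        for i in range(num_requests)
    ]


workloads = uniform_trace(num_requests=10, interval_s=0.1,
                          input_tokens=50, output_tokens=50)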
46 changes: 46 additions & 0 deletions vllm/entrypoints/simulator/profile.py
@@ -0,0 +1,46 @@
import json


class SimulationProfile:
"""Profiling data structure capturing the timing of a single forward pass."""

def __init__(self):
self.prefill_timing = {}
self.decode_timing = {}

def record(self, prefill_tokens, decode_tokens, duration_ms):
# TODO: use histogram, and sampling

assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo"

if prefill_tokens:
self.prefill_timing[str(prefill_tokens)] = duration_ms
else:
self.decode_timing[str(decode_tokens)] = duration_ms

def get_estimate(self, prefill_tokens, decode_tokens):
# TODO: sample

assert not all([prefill_tokens, decode_tokens]), "chunked prefill todo"

if prefill_tokens:
return self.prefill_timing[str(prefill_tokens)]
else:
return self.decode_timing[str(decode_tokens)]

def save(self, path):
with open(path, "w") as f:
json.dump(
{
"prefill_timing": self.prefill_timing,
"decode_timing": self.decode_timing
}, f)

@classmethod
def load(cls, path):
o = cls()
with open(path, "r") as f:
data = json.load(f)
o.prefill_timing = data["prefill_timing"]
o.decode_timing = data["decode_timing"]
return o
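
SimulationProfile keys timings by token count and currently stores a single duration per key (the TODOs point at histograms and sampling later). A small round-trip sketch of the API as defined above, with made-up timing values:

from vllm.entrypoints.simulator.profile import SimulationProfile

profile = SimulationProfile()
# e.g. a 512-token prefill step took 35 ms, a single-token decode step 9 ms
profile.record(prefill_tokens=512, decode_tokens=0, duration_ms=35.0)
profile.record(prefill_tokens=0, decode_tokens=1, duration_ms=9.0)

profile.save("profile.json")
restored = SimulationProfile.load("profile.json")
assert restored.get_estimate(512, 0) == 35.0
assert restored.get_estimate(0, 1) == 9.0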
96 changes: 96 additions & 0 deletions vllm/entrypoints/simulator/simulator.py
@@ -0,0 +1,96 @@
import time

import simpy

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.entrypoints.simulator.profile import SimulationProfile


class vLLMSimulator:
"""Profiles a model on a real LLMEngine and holds a second engine in simulation mode for replay."""

def __init__(self, model="facebook/opt-125m") -> None:
# real engine, used to measure forward-pass timings
self.profile_engine = LLMEngine.from_engine_args(
EngineArgs(model=model))
# engine in simulation mode, to be driven by the recorded profile
self.simulated_engine = LLMEngine.from_engine_args(
EngineArgs(model=model, simulation_mode=True))

self._reset()

def _reset(self):
self.env = simpy.rt.RealtimeEnvironment(factor=1, strict=True)
self.generator_finished = False
self.enqueued = simpy.Store(self.env, capacity=1)

def _process_run_engine(self, is_profile):
start_time = time.time()
while not self.generator_finished or self.engine.has_unfinished_requests():
start = time.time()
outputs = self.engine.step()
print("---")
for output in outputs:
output_metrics = {
"arrival_time":
output.metrics.arrival_time - start_time,
"last_token_time":
output.metrics.last_token_time - start_time,
"first_scheduled_time":
output.metrics.first_scheduled_time - start_time,
"first_token_time":
output.metrics.first_token_time - start_time
}
print(output.request_id, output_metrics)
print("---")
end = time.time()
# TODO: use proper synchronization
passed = end - start
print(passed)
if is_profile:
yield self.env.timeout(passed)
else:
if not self.engine.has_unfinished_requests():
yield self.enqueued.get()
yield self.env.timeout(0.0006) # fixed scheduler overhead

def _process_request_generator(self, workloads):
# assume workloads are sorted by (absolute) arrival time
for i, workload in enumerate(workloads):
if self.env.now != workload.arrival_time:
assert self.env.now < workload.arrival_time
# sleep until this request's arrival time
yield self.env.timeout(workload.arrival_time - self.env.now)

self.engine.add_request(
request_id=str(i),
prompt=None,
prompt_token_ids=[0] * workload.num_input_tokens,
params=SamplingParams(max_tokens=workload.num_output_tokens,
ignore_eos=True),
)

if len(self.enqueued.items) == 0:
# notify the engine that there is a new request
self.enqueued.put(i)

self.generator_finished = True
if len(self.enqueued.items) == 0:
# wake the engine loop so it can see that generation is finished
self.enqueued.put(i)

def profile(self, workloads):
profile = SimulationProfile()
# profiling runs against the real engine
self.engine = self.profile_engine
self.engine.model_executor.simulation_profile = profile

self.env.process(self._process_request_generator(workloads))
self.env.process(self._process_run_engine(is_profile=True))
self.env.run()

profile.save("profile.json")


# if __name__ == "__main__":
# from vllm.utils import FlexibleArgumentParser
# parser = FlexibleArgumentParser()
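
The commented-out block above hints at a CLI entry point. A possible shape, assuming vllm.utils.FlexibleArgumentParser behaves like a standard argparse parser and reusing the profile() flow; the flag name and built-in trace are illustrative, not part of this commit:

if __name__ == "__main__":
    from vllm.utils import FlexibleArgumentParser
    from vllm.entrypoints.simulator.loadgen import WorkloadRequest

    parser = FlexibleArgumentParser()
    parser.add_argument("--model", default="facebook/opt-125m")
    args = parser.parse_args()

    # tiny built-in trace; a real CLI would load one from a file
    workloads = [WorkloadRequest(0.0, 10, 10), WorkloadRequest(0.2, 50, 50)]

    simulator = vLLMSimulator(model=args.model)
    simulator.profile(workloads)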
