core: add multi-step scheduling support for the synchronous engine (#914)
AlpinDale authored Dec 18, 2024
1 parent 7996677 commit b1492c1
Showing 4 changed files with 170 additions and 80 deletions.
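
The change teaches the synchronous AphroditeEngine.step() to reuse a cached schedule across several decode steps instead of calling the scheduler on every iteration, mirroring what the async engine already did. As a rough orientation before the diff, here is a minimal, self-contained sketch of that control flow; ToySeqGroup, ToyCache, and toy_step are hypothetical stand-ins for the engine's SequenceGroupMetadata, SchedulerOutputState, and step(), not Aphrodite APIs.

from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToySeqGroup:
    """Stand-in for a scheduled sequence group with a multi-step budget."""
    remaining_steps: int

    def finish_step(self) -> None:
        self.remaining_steps -= 1


@dataclass
class ToyCache:
    """Stand-in for SchedulerOutputState: the schedule cached between steps."""
    seq_groups: Optional[List[ToySeqGroup]] = None


def toy_step(cache: ToyCache,
             schedule: Callable[[], List[ToySeqGroup]],
             execute: Callable[[List[ToySeqGroup]], list],
             num_scheduler_steps: int) -> list:
    groups = cache.seq_groups
    # Only invoke the scheduler when the previous multi-step batch is done.
    if not groups or all(g.remaining_steps == 0 for g in groups):
        groups = schedule()
        if num_scheduler_steps > 1:
            cache.seq_groups = groups  # reuse this schedule on later steps
    output = execute(groups)
    for g in groups:
        g.finish_step()
    if all(g.remaining_steps == 0 for g in groups):
        cache.seq_groups = None  # batch finished; schedule fresh next time
    return output


# Example: one schedule serves four consecutive engine steps.
cache = ToyCache()
for _ in range(4):
    toy_step(cache,
             schedule=lambda: [ToySeqGroup(remaining_steps=4)],
             execute=lambda gs: ["token"] * len(gs),
             num_scheduler_steps=4)
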
126 changes: 115 additions & 11 deletions aphrodite/engine/aphrodite_engine.py
@@ -1,9 +1,11 @@
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Type, TypeVar, Union

import torch
from loguru import logger
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
@@ -74,6 +76,14 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
Optional[MultiModalDataDict]]


@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None


class AphroditeEngine:
"""An LLM engine that receives requests and generates texts.
@@ -197,7 +207,8 @@ def __init__(
"KV Cache DataType": cache_config.cache_dtype,
"Device": device_config.device,
"Rope Scaling": model_config.rope_scaling,
"Guided Decoding Backend": decoding_config
"Guided Decoding Backend": decoding_config,
"Scheduler Steps": scheduler_config.num_scheduler_steps,
}

logger.info("-" * 85)
@@ -326,6 +337,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> PreTrainedTokenizer:
),
))

self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]

def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@@ -1189,16 +1205,36 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
raise NotImplementedError(
"Pipeline parallelism is only supported through AsyncAphrodite "
"as performance will be severely degraded otherwise.")
if self.scheduler_config.num_scheduler_steps > 1:
raise NotImplementedError(
"Multiple scheduler steps (multi-step) are only supported "
"through AsyncAphrodite.")
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()

# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[0]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()
if (self.scheduler_config.is_multi_step
and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
0, seq_group_metadata_list, scheduler_outputs)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
0].get_and_reset_finished_requests_ids()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(0)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
@@ -1207,15 +1243,31 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
)
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(0, output)
else:
output = []

request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
request_outputs = self._process_model_outputs(
output, scheduler_outputs.scheduled_seq_groups,
scheduler_outputs.ignored_seq_groups, seq_group_metadata_list)
else:
request_outputs = []

# Log stats.
self.do_log_stats(scheduler_outputs, output)
@@ -1230,6 +1282,58 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

return request_outputs

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO(will) this is a sanity check for now to make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0

def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None

def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output

def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None

def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if logger_name in self.stat_loggers:
raise KeyError(f"Logger with name {logger_name} already exists.")
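
The _has_remaining_steps helper added above doubles as a sanity check: before the scheduler is skipped, every running sequence group must sit on the same remaining step count. A stripped-down illustration of that invariant, using plain integers in place of SequenceGroupMetadata.state.remaining_steps, might look like this:

from typing import Sequence


def has_remaining_steps(remaining_steps: Sequence[int]) -> bool:
    """Return True only if all groups agree on a positive remaining-step count."""
    if not remaining_steps:
        return False
    ref = remaining_steps[0]
    if any(steps != ref for steps in remaining_steps[1:]):
        raise AssertionError("All running sequence groups should "
                             "have the same remaining steps.")
    return ref > 0


assert has_remaining_steps([3, 3, 3]) is True   # mid-batch: skip the scheduler
assert has_remaining_steps([0, 0]) is False     # batch done: schedule again
assert has_remaining_steps([]) is False         # nothing scheduled yet
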
72 changes: 3 additions & 69 deletions aphrodite/engine/async_aphrodite.py
@@ -1,11 +1,9 @@
import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union)

import torch
from loguru import logger
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never
@@ -17,11 +15,11 @@
from aphrodite.common.outputs import EmbeddingRequestOutput, RequestOutput
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.engine.aphrodite_engine import (AphroditeEngine,
DecoderPromptComponents,
PromptComponents)
PromptComponents,
SchedulerOutputState)
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_timeout import asyncio_timeout
from aphrodite.engine.metrics_types import StatLoggerBase
@@ -255,24 +253,11 @@ def has_new_requests(self):
return not self._new_requests.empty()


@dataclass
class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output: Optional[SamplerOutput] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None


class _AsyncAphrodite(AphroditeEngine):
"""Extension of AphroditeEngine to add async methods."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
pipeline_parallel_size = \
self.parallel_config.pipeline_parallel_size
self.cached_scheduler_outputs = [
SchedulerOutputState() for _ in range(pipeline_parallel_size)
]

async def step_async(
self, virtual_engine: int
@@ -356,57 +341,6 @@ async def step_async(
self.do_log_stats(scheduler_outputs, output)

return request_outputs

def _has_remaining_steps(
self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
) -> bool:
if (not self.scheduler_config.is_multi_step
or not seq_group_metadata_list):
return False
# TODO: this is a sanity check for now to make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps
if any([
seq_group.state.remaining_steps != ref_remaining_steps
for seq_group in seq_group_metadata_list[1:]
]):
raise AssertionError(("All running sequence groups should "
"have the same remaining steps."))
return ref_remaining_steps > 0

def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs) -> None:
self.cached_scheduler_outputs[
virtual_engine].seq_group_metadata_list = seq_group_metadata_list
self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \
scheduler_outputs
self.cached_scheduler_outputs[virtual_engine].last_output = None
def _get_last_sampled_token_ids(
self, virtual_engine: int) -> Optional[torch.Tensor]:
cached_last_output = self.cached_scheduler_outputs[
virtual_engine].last_output
if (self.scheduler_config.is_multi_step
and self.parallel_config.pipeline_parallel_size > 1
and cached_last_output is not None
and cached_last_output.sampled_token_ids_cpu is not None):
return cached_last_output.sampled_token_ids_cpu
return None

def _update_cached_scheduler_output(
self, virtual_engine: int,
output: List[Optional[SamplerOutput]]) -> None:
if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0
and output[0] is not None):
last_output = output[-1]
assert last_output is not None
assert last_output.sampled_token_ids_cpu is not None
assert last_output.sampled_token_ids is None
assert last_output.sampled_token_probs is None
self.cached_scheduler_outputs[
virtual_engine].last_output = last_output

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
@@ -2,6 +2,7 @@
from typing import List

import pytest
import torch

from ..utils import RemoteOpenAIServer

@@ -59,6 +60,9 @@ async def test_multi_step(
num_scheduler_steps: int,
num_prompts: int,
):
if (tp_size > 1 or pp_size > 1) and torch.cuda.device_count() < 2:
pytest.skip("Skipping multi-GPU tests on single GPU system")

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
48 changes: 48 additions & 0 deletions tests/multi_step/test_correctness_llm.py
@@ -0,0 +1,48 @@
# Test the LLMEngine with multi-step-decoding
import pytest

from ..models.utils import check_outputs_equal

MODELS = [
"JackFram/llama-160m",
]
NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps
NUM_PROMPTS = [10]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
def test_multi_step_llm(hf_runner, aphrodite_runner, example_prompts,
model: str, dtype: str, tp_size: int, max_tokens: int,
enforce_eager: bool, num_scheduler_steps: int,
num_prompts: int) -> None:

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts

with aphrodite_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7,
tensor_parallel_size=tp_size,
use_v2_block_manager=True,
num_scheduler_steps=num_scheduler_steps
) as aphrodite_model:
aphrodite_outputs = aphrodite_model.generate_greedy(prompts, max_tokens)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=aphrodite_outputs,
name_0="hf",
name_1="aphrodite",
)
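
Both correctness tests pad example_prompts up to num_prompts by repeating and truncating the list. A worked instance of that arithmetic, assuming (hypothetically) that the fixture supplies 8 prompts and num_prompts is 10:

prompts = [f"prompt {i}" for i in range(8)]     # assumed fixture size
num_prompts = 10
if len(prompts) < num_prompts:
    prompts = prompts * ((num_prompts // len(prompts)) + 1)  # 10 // 8 + 1 = 2 copies -> 16
prompts = prompts[:num_prompts]                              # truncated back to 10
assert len(prompts) == num_prompts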
