From 5f0b9933e63839e816b9736a65a3c55005df2cfe Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 17 Jul 2024 12:40:10 -0700 Subject: [PATCH 1/2] [Bugfix] Fix Ray Metrics API usage (#6354) --- tests/metrics/test_metrics.py | 54 +++++++++++ vllm/engine/async_llm_engine.py | 19 ++++ vllm/engine/llm_engine.py | 2 + vllm/engine/metrics.py | 160 ++++++++++++++++++++++++-------- 4 files changed, 195 insertions(+), 40 deletions(-) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 5694061e17e02..42b15cd6c458e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,11 +1,13 @@ from typing import List import pytest +import ray from prometheus_client import REGISTRY from vllm import EngineArgs, LLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.metrics import RayPrometheusStatLogger from vllm.sampling_params import SamplingParams MODELS = [ @@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool, labels) assert ( metric_value == num_requests), "Metrics should be collected" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [16]) +def test_engine_log_metrics_ray( + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + # This test is quite weak - it only checks that we can use + # RayPrometheusStatLogger without exceptions. + # Checking whether the metrics are actually emitted is unfortunately + # non-trivial. + + # We have to run in a Ray task for Ray metrics to be emitted correctly + @ray.remote(num_gpus=1) + def _inner(): + + class _RayPrometheusStatLogger(RayPrometheusStatLogger): + + def __init__(self, *args, **kwargs): + self._i = 0 + super().__init__(*args, **kwargs) + + def log(self, *args, **kwargs): + self._i += 1 + return super().log(*args, **kwargs) + + engine_args = EngineArgs( + model=model, + dtype=dtype, + disable_log_stats=False, + ) + engine = LLMEngine.from_engine_args(engine_args) + logger = _RayPrometheusStatLogger( + local_interval=0.5, + labels=dict(model_name=engine.model_config.served_model_name), + max_model_len=engine.model_config.max_model_len) + engine.add_logger("ray", logger) + for i, prompt in enumerate(example_prompts): + engine.add_request( + f"request-id-{i}", + prompt, + SamplingParams(max_tokens=max_tokens), + ) + while engine.has_unfinished_requests(): + engine.step() + assert logger._i > 0, ".log must be called at least once" + + ray.get(_inner.remote()) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 93bf8793dae33..0e63506e7c367 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,6 +12,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine +from vllm.engine.metrics import StatLoggerBase from vllm.executor.ray_utils import initialize_ray_cluster, ray from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger @@ -389,6 +390,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "AsyncLLMEngine": """Creates an async LLM engine from the engine arguments.""" # Create the engine configs. 
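# ---- Illustrative sketch (not part of the diff) ----
# How the stat_loggers hook and add_logger() introduced above might be used to
# attach a Ray-backed logger, mirroring the new test. The model name is a
# placeholder, and, as the test notes, Ray metrics are only exported correctly
# when this runs inside a Ray task.
from vllm import EngineArgs, LLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger

engine = LLMEngine.from_engine_args(
    EngineArgs(model="facebook/opt-125m", disable_log_stats=False))
ray_logger = RayPrometheusStatLogger(
    local_interval=0.5,
    labels=dict(model_name=engine.model_config.served_model_name),
    max_model_len=engine.model_config.max_model_len)
engine.add_logger("ray", ray_logger)
# Equivalently, pass stat_loggers={"ray": ray_logger} to
# LLMEngine.from_engine_args(...) or AsyncLLMEngine.from_engine_args(...).
# ---- end sketch ----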
@@ -451,6 +453,7 @@ def from_engine_args( max_log_len=engine_args.max_log_len, start_engine_loop=start_engine_loop, usage_context=usage_context, + stat_loggers=stat_loggers, ) return engine @@ -957,3 +960,19 @@ async def is_tracing_enabled(self) -> bool: ) else: return self.engine.is_tracing_enabled() + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if self.engine_use_ray: + ray.get( + self.engine.add_logger.remote( # type: ignore + logger_name=logger_name, logger=logger)) + else: + self.engine.add_logger(logger_name=logger_name, logger=logger) + + def remove_logger(self, logger_name: str) -> None: + if self.engine_use_ray: + ray.get( + self.engine.remove_logger.remote( # type: ignore + logger_name=logger_name)) + else: + self.engine.remove_logger(logger_name=logger_name) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622221d2dd13e..68ca9a97a3c61 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -379,6 +379,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. @@ -423,6 +424,7 @@ def from_engine_args( executor_class=executor_class, log_stats=not engine_args.disable_log_stats, usage_context=usage_context, + stat_loggers=stat_loggers, ) return engine diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 9e187b2fbc33b..4ed7da2377111 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -30,55 +30,55 @@ # begin-metrics-definitions class Metrics: labelname_finish_reason = "finished_reason" - _base_library = prometheus_client + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram def __init__(self, labelnames: List[str], max_model_len: int): # Unregister any existing vLLM collectors self._unregister_vllm_metrics() # Config Information - self.info_cache_config = prometheus_client.Info( - name='vllm:cache_config', - documentation='information of cache_config') + self._create_info_cache_config() # System stats # Scheduler State - self.gauge_scheduler_running = self._base_library.Gauge( + self.gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames) - self.gauge_scheduler_waiting = self._base_library.Gauge( + self.gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames) - self.gauge_scheduler_swapped = self._base_library.Gauge( + self.gauge_scheduler_swapped = self._gauge_cls( name="vllm:num_requests_swapped", documentation="Number of requests swapped to CPU.", labelnames=labelnames) # KV Cache Usage in % - self.gauge_gpu_cache_usage = self._base_library.Gauge( + self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames) - self.gauge_cpu_cache_usage = self._base_library.Gauge( + self.gauge_cpu_cache_usage = self._gauge_cls( name="vllm:cpu_cache_usage_perc", documentation="CPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames) # Iteration stats - self.counter_num_preemption = self._base_library.Counter( + self.counter_num_preemption = self._counter_cls( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", labelnames=labelnames) - self.counter_prompt_tokens = self._base_library.Counter( + self.counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", labelnames=labelnames) - self.counter_generation_tokens = self._base_library.Counter( + self.counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", labelnames=labelnames) - self.histogram_time_to_first_token = self._base_library.Histogram( + self.histogram_time_to_first_token = self._histogram_cls( name="vllm:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", labelnames=labelnames, @@ -86,7 +86,7 @@ def __init__(self, labelnames: List[str], max_model_len: int): 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0 ]) - self.histogram_time_per_output_token = self._base_library.Histogram( + self.histogram_time_per_output_token = self._histogram_cls( name="vllm:time_per_output_token_seconds", documentation="Histogram of time per output token in seconds.", labelnames=labelnames, @@ -97,83 +97,157 @@ def __init__(self, labelnames: List[str], max_model_len: int): # Request stats # Latency - self.histogram_e2e_time_request = self._base_library.Histogram( + self.histogram_e2e_time_request = self._histogram_cls( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) # Metadata - self.histogram_num_prompt_tokens_request = self._base_library.Histogram( + self.histogram_num_prompt_tokens_request = self._histogram_cls( name="vllm:request_prompt_tokens", documentation="Number of prefill tokens processed.", labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) self.histogram_num_generation_tokens_request = \ - self._base_library.Histogram( + self._histogram_cls( name="vllm:request_generation_tokens", documentation="Number of generation tokens processed.", labelnames=labelnames, buckets=build_1_2_5_buckets(max_model_len), ) - self.histogram_best_of_request = self._base_library.Histogram( + self.histogram_best_of_request = self._histogram_cls( name="vllm:request_params_best_of", documentation="Histogram of the best_of request parameter.", labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) - self.histogram_n_request = self._base_library.Histogram( + self.histogram_n_request = self._histogram_cls( name="vllm:request_params_n", documentation="Histogram of the n request parameter.", labelnames=labelnames, buckets=[1, 2, 5, 10, 20], ) - self.counter_request_success = self._base_library.Counter( + self.counter_request_success = self._counter_cls( name="vllm:request_success_total", documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) # Speculatie decoding stats - self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge( + self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( name="vllm:spec_decode_draft_acceptance_rate", documentation="Speulative token acceptance rate.", labelnames=labelnames) - 
self.gauge_spec_decode_efficiency = self._base_library.Gauge( + self.gauge_spec_decode_efficiency = self._gauge_cls( name="vllm:spec_decode_efficiency", documentation="Speculative decoding system efficiency.", labelnames=labelnames) - self.counter_spec_decode_num_accepted_tokens = ( - self._base_library.Counter( - name="vllm:spec_decode_num_accepted_tokens_total", - documentation="Number of accepted tokens.", - labelnames=labelnames)) - self.counter_spec_decode_num_draft_tokens = self._base_library.Counter( + self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = self._counter_cls( name="vllm:spec_decode_num_draft_tokens_total", documentation="Number of draft tokens.", labelnames=labelnames) - self.counter_spec_decode_num_emitted_tokens = ( - self._base_library.Counter( - name="vllm:spec_decode_num_emitted_tokens_total", - documentation="Number of emitted tokens.", - labelnames=labelnames)) + self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) # Deprecated in favor of vllm:prompt_tokens_total - self.gauge_avg_prompt_throughput = self._base_library.Gauge( + self.gauge_avg_prompt_throughput = self._gauge_cls( name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, ) # Deprecated in favor of vllm:generation_tokens_total - self.gauge_avg_generation_throughput = self._base_library.Gauge( + self.gauge_avg_generation_throughput = self._gauge_cls( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", labelnames=labelnames, ) + def _create_info_cache_config(self) -> None: + # Config Information + self.info_cache_config = prometheus_client.Info( + name='vllm:cache_config', + documentation='information of cache_config') + def _unregister_vllm_metrics(self) -> None: - for collector in list(self._base_library.REGISTRY._collector_to_names): + for collector in list(prometheus_client.REGISTRY._collector_to_names): if hasattr(collector, "_name") and "vllm" in collector._name: - self._base_library.REGISTRY.unregister(collector) + prometheus_client.REGISTRY.unregister(collector) + + +# end-metrics-definitions + + +class _RayGaugeWrapper: + """Wraps around ray.util.metrics.Gauge to provide same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._gauge = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._gauge.set_default_tags(labels) + return self + + def set(self, value: Union[int, float]): + return self._gauge.set(value) + + +class _RayCounterWrapper: + """Wraps around ray.util.metrics.Counter to provide same API as + prometheus_client.Counter""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._counter = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._counter.set_default_tags(labels) + return self + + def inc(self, value: Union[int, 
float] = 1.0): + if value == 0: + return + return self._counter.inc(value) + + +class _RayHistogramWrapper: + """Wraps around ray.util.metrics.Histogram to provide same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + buckets: Optional[List[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._histogram = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=buckets) + + def labels(self, **labels): + self._histogram.set_default_tags(labels) + return self + + def observe(self, value: Union[int, float]): + return self._histogram.observe(value) class RayMetrics(Metrics): @@ -181,7 +255,9 @@ class RayMetrics(Metrics): RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. Provides the same metrics as Metrics but uses Ray's util.metrics library. """ - _base_library = ray_metrics + _gauge_cls = _RayGaugeWrapper + _counter_cls = _RayCounterWrapper + _histogram_cls = _RayHistogramWrapper def __init__(self, labelnames: List[str], max_model_len: int): if ray_metrics is None: @@ -192,8 +268,9 @@ def _unregister_vllm_metrics(self) -> None: # No-op on purpose pass - -# end-metrics-definitions + def _create_info_cache_config(self) -> None: + # No-op on purpose + pass def build_1_2_5_buckets(max_value: int) -> List[int]: @@ -498,3 +575,6 @@ def log(self, stats: Stats): class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" _metrics_cls = RayMetrics + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + return None From e76466dde2bc9525d55165ceaa600d298c7bf773 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:30:28 -0400 Subject: [PATCH 2/2] [Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338) --- CMakeLists.txt | 1 + csrc/ops.h | 5 + csrc/prepare_inputs/advance_step.cu | 131 +++++++++ csrc/prepare_inputs/advance_step.cuh | 19 ++ csrc/torch_bindings.cpp | 4 + tests/spec_decode/e2e/conftest.py | 1 + tests/spec_decode/test_multi_step_worker.py | 48 +++ vllm/_custom_ops.py | 12 + vllm/model_executor/layers/sampler.py | 147 +++++++--- vllm/model_executor/sampling_metadata.py | 10 + vllm/spec_decode/draft_model_runner.py | 305 +++++++++++++++----- vllm/spec_decode/multi_step_worker.py | 15 +- 12 files changed, 568 insertions(+), 130 deletions(-) create mode 100644 csrc/prepare_inputs/advance_step.cu create mode 100644 csrc/prepare_inputs/advance_step.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index ced73ca03bfbc..335623bd2677d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,6 +151,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" + "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/csrc/ops.h b/csrc/ops.h index f9feb3deff5e4..1e94a9f45ef08 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input); void gelu_quick(torch::Tensor& out, torch::Tensor& input); +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables); + #ifndef USE_ROCM 
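# ---- Illustrative sketch (not part of the diff) ----
# A rough pure-Python reference of what the new advance_step op declared above
# computes for each decoding query; the actual implementation is the CUDA
# kernel added below, which also validates tensor shapes and dtypes.
import torch

def advance_step_reference(num_queries: int, block_size: int,
                           input_tokens: torch.Tensor,       # (num_seqs,) long
                           sampled_token_ids: torch.Tensor,  # (num_queries, 1)
                           input_positions: torch.Tensor,    # (num_seqs,) long
                           seq_lens: torch.Tensor,           # (num_seqs,) int
                           slot_mapping: torch.Tensor,       # (num_seqs,) long
                           block_tables: torch.Tensor) -> None:  # (num_seqs, N)
    for i in range(num_queries):
        # The token sampled in the previous step becomes the next input token.
        input_tokens[i] = sampled_token_ids[i, 0]
        # The sequence grows by one token; the new position is its last index.
        next_seq_len = int(seq_lens[i]) + 1
        next_pos = next_seq_len - 1
        seq_lens[i] = next_seq_len
        input_positions[i] = next_pos
        # Map the new position to a physical KV-cache slot via the block table.
        block = int(block_tables[i, next_pos // block_size])
        slot_mapping[i] = block * block_size + next_pos % block_size
# ---- end sketch ----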
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu new file mode 100644 index 0000000000000..0e537ddd6c4cd --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cu @@ -0,0 +1,131 @@ +/* + * The goal of this GPU kernel is to advance input tensors on the GPU directly + * PR: https://github.com/vllm-project/vllm/pull/6338 + * Current restrictions: + * 1. Specialized for DraftModelRunner + * 2. Supports flash_attn only + */ + +#include "advance_step.cuh" + +namespace prepare_inputs { + +// +template +__global__ void advance_step_kernel(int num_seqs, int num_queries, + int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, + long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, + int64_t const block_tables_stride) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x >= num_query_blocks) { + return; + } + + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id >= num_queries) { + return; + } + + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; +} + +inline void verify_tensor(std::string const& name, torch::Tensor& t, + int64_t const size_0, int64_t const size_1, + c10::ScalarType const type) { + bool size_0_cond = true; + if (size_0 != -1) { + size_0_cond = t.size(0) == size_0; + } + + bool size_1_cond = true; + if (size_1 != -1) { + size_1_cond = t.size(1) == size_1; + } + + bool is_contiguous = t.is_contiguous(); + bool same_type = t.dtype() == type; + + bool pass = size_0_cond && size_1_cond && is_contiguous && same_type; + if (!pass) { + TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(), + " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(), + " is not as expected: shape = [", size_0, ", ", size_1, + "], type = ", type); + } +} + +void advance_step(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int + + if (logging) { + printf("advance_step:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + 
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + advance_step_kernel<<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +} // namespace prepare_inputs + +void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables) { + prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, + sampled_token_ids, input_positions, seq_lens, + slot_mapping, block_tables); +} \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cuh b/csrc/prepare_inputs/advance_step.cuh new file mode 100644 index 0000000000000..f21574681b1ab --- /dev/null +++ b/csrc/prepare_inputs/advance_step.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace prepare_inputs { + +static constexpr int max_threads = 256; +static constexpr bool logging = false; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +} // namespace prepare_inputs diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9dc7cefc404ca..ff9875e0e17a3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_quick(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_quick", torch::kCUDA, &gelu_quick); + // prepare_inputs advance_step + ops.def("advance_step", &advance_step); + ops.impl("advance_step", torch::kCUDA, &advance_step); + // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 34a6c9a393a58..da72f6d503c11 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -227,6 +227,7 @@ def get_output_from_llm_generator( maybe_assert_ngram_worker(llm) outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + token_ids = [output.outputs[0].token_ids for output in outputs] tokens = [output.outputs[0].text for output in outputs] diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 9832d4f267e8a..442e40f07f0bb 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k(): assert proposals.proposal_lens.tolist() == [ k for _ in range(expected_num_proposal_seqs - 1) ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] + + +@torch.inference_mode() +def test_use_draft_model_runner_advance_step(): + """Verify that draft model runner triggers advance step + when applicable. 
+ """ + seed = 100 + model_name = 'JackFram/llama-68m' + + k = 5 + batch_size = 32 + block_size = 32 + num_gpu_blocks = 2048 // block_size + worker = create_worker( + MultiStepWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + model_runner_cls=TP1DraftModelRunner, + ) + + # Mock "_gpu_advance_step" to raise an exception when called. + exception_secret = "artificial stop" + worker.model_runner._gpu_advance_step = MagicMock() + worker.model_runner._gpu_advance_step.side_effect = ValueError( + exception_secret) + + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + + # Fallback (should not call) when num_steps=1. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=1) + worker.execute_model(execute_model_req=execute_model_req) + + # Expect exception if _gpu_advance_step is called. + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k, + num_steps=k) + + with pytest.raises(ValueError, match=exception_secret): + worker.execute_model(execute_model_req=execute_model_req) + call_args_list = worker.model_runner._gpu_advance_step.call_args_list + assert len(call_args_list) == 1 diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4ca67224a91b8..143957f7b65f0 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def advance_step(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, seq_lens: torch.Tensor, + slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return torch.ops._C.advance_step(num_seqs, num_queries, block_size, + input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, + block_tables) + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6d00ea64f7cb8..5c376797a054f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -47,6 +47,32 @@ def __init__(self): # speculative decoding. self.include_gpu_probs_tensor = False + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. + """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. 
+ self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + def forward( self, logits: torch.Tensor, @@ -60,12 +86,23 @@ def forward( assert logits is not None _, vocab_size = logits.shape - logits = _apply_min_tokens_penalty(logits, sampling_metadata) - # Prepare sampling tensors with pinned memory to avoid blocking. - (sampling_tensors, do_penalties, do_top_p_top_k, - do_min_p) = SamplingTensors.from_sampling_metadata( - sampling_metadata, vocab_size, logits.device, logits.dtype) + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) # Apply presence and frequency penalties. if do_penalties: @@ -77,7 +114,7 @@ def forward( # Apply temperature scaling. # Use in-place division to avoid creating a new tensor. - logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) if do_top_p_top_k: logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, @@ -109,13 +146,19 @@ def forward( on_device_tensors = None # Get the logprobs query results. - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) - return _build_sampler_output(sample_results, - sampling_metadata, - prompt_logprobs, - sample_logprobs, - on_device_tensors=on_device_tensors) + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + + return _build_sampler_output( + sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) @property def _should_modify_greedy_probs_inplace(self) -> bool: @@ -535,24 +578,29 @@ def _sample_with_torch( # GPU<->CPU sync happens in the loop below. # This also converts the sample output to Python objects. 
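        # When skip_sampler_cpu_output is set (as in multi-step draft
        # decoding), the conversion below is skipped entirely and the sampled
        # token ids are later read directly from the on-device tensor.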
- for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): - sample_results = _random_sample(seq_groups, - multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) + if not sampling_metadata.skip_sampler_cpu_output: + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, + SamplingType.RANDOM_SEED): + sample_results = _random_sample( + seq_groups, multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + sample_results = [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + else: + sample_results = [] - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] return sample_results, sampled_token_ids_tensor @@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( sample_results: SampleResultType, sampling_metadata: SamplingMetadata, - prompt_logprobs: List[Optional[PromptLogprobs]], - sample_logprobs: List[SampleLogprobs], + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]], + skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1010,22 +1059,26 @@ def _build_sampler_output( allows post-processing without copies to CPU/serialization, e.g. in speculative decoding rejection sampling. 
""" - sampler_output: List[CompletionSequenceGroupOutput] = [] - for (seq_group, sample_result, group_prompt_logprobs, - group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): - seq_ids = seq_group.seq_ids - next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] - for parent_id, next_token_id, logprobs in zip(parent_ids, - next_token_ids, - group_sample_logprobs): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) - sampler_output.append( - CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + if not skip_sampler_cpu_output: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: List[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index c346cd0562867..29b077cf6d912 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -87,6 +87,12 @@ def sample(logits): The first tuple is [1, 2] (sampled index within original logit), and the second tuple is [0, 1] (sampled index within pruned logit). num_prompts: Number of prompt sequence groups in seq_groups. + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + serialization of token outputs. + reuse_sampling_tensors: Indicates if we want to reuse sampling + tensors that are part of the sampler forward pass. Currently, + it is mainly used for multi-step decode. 
+ """ def __init__( @@ -95,11 +101,15 @@ def __init__( selected_token_indices: torch.Tensor, categorized_sample_indices: Dict[SamplingType, torch.Tensor], num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, ) -> None: self.seq_groups = seq_groups self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices self.num_prompts = num_prompts + self.skip_sampler_cpu_output = skip_sampler_cpu_output + self.reuse_sampling_tensors = reuse_sampling_tensors @staticmethod def prepare( diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 90bba96ee8acb..3cb7ec58da4c1 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,17 +2,22 @@ import torch +from vllm import _custom_ops as ops +from vllm.attention.backends.flash_attn import FlashAttentionMetadata from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SamplerOutput) from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) logger = init_logger(__name__) +debug_advance_input = False +enable_gpu_advance_step = True + class TP1DraftModelRunner(ModelRunner): """Specialized model runner for speculative decoding draft model. @@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner): we could get rid of most CPU-GPU synchronization and data transfer overheads by keeping model input and output tensors on GPU all the time. - This runner is still under development so there's no performance gain - at this moment. Currently we adopt a temporary solution that caches the - seq_group_metadata_list for multi-step execution, so that we can - leverage existing prepare_model_input to be compatible with the current - execution flow, but we plan to remove this cache and avoid calling - prepare_model_input in execute_model at all. - - The detail development plan includes: - 1. Use "update_model_input" to update existing model_input without - creating a new one. - 2. Improve the performance of "update_model_input" with a GPU kernel. - 3. Support TP > 1 (this requires some designs because we do not expect + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect any broadcasting inside execute_model). """ @@ -71,51 +67,156 @@ def __init__( return_hidden_states=return_hidden_states, ) - # TODO: Remove this cache when we are able to update model_input - # directly in advance_step. - self.cached_seq_group_metadata_list: Optional[ - List[SequenceGroupMetadata]] = None + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, + num_queries): + assert isinstance(attn_metadata, FlashAttentionMetadata) - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForGPUWithSamplingMetadata: - """A temporary solution that caches the seq_group_metadata_list - for multi-step execution. - TODO: In-place update model_input and remove this function. 
- """ - self.cached_seq_group_metadata_list = seq_group_metadata_list - return super().prepare_model_input( - seq_group_metadata_list, - finished_requests_ids=finished_requests_ids) + if num_seqs != num_queries: + assert num_seqs > num_queries + assert attn_metadata.use_cuda_graph + + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + assert attn_metadata.num_decode_tokens == num_seqs + assert attn_metadata.slot_mapping.shape == (num_seqs, ) + + assert len(attn_metadata.seq_lens) == num_seqs + assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) + assert attn_metadata.max_query_len == 1 + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) + + assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) + assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) + + assert attn_metadata.context_lens_tensor.shape == (num_queries, ) + + assert attn_metadata.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + attn_metadata.seq_lens[i] += 1 + attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) - def update_model_input( + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, last_output: SamplerOutput ) -> ModelInputForGPUWithSamplingMetadata: - """Prepare the model inputs for the next step. - TODO: In-place update model_input instead of calling - prepare_model_input. 
+ # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step(num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + multi_modal_kwargs=model_input.multi_modal_kwargs, + sampling_metadata=model_input.sampling_metadata, + is_prompt=False, + ) + + # Ensure we skip CPU samples + assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True + # We can reuse sampling tensors since every decode iteration is the same + new_model_input.sampling_metadata.reuse_sampling_tensors = True + + if debug_advance_input: + logger.debug("NEW INPUT: ") + logger.debug(" input_tokens = %s", new_model_input.input_tokens) + logger.debug(" input_positions = %s", + new_model_input.input_positions) + logger.debug(" seq_lens = %d", new_model_input.seq_lens) + logger.debug(" query_lens = %d", new_model_input.query_lens) + logger.debug(" attn_metadata:") + logger.debug(" seq_lens_tensor: %s", + attn_metadata.seq_lens_tensor) + logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) + logger.debug(" block_tables: %s", attn_metadata.block_tables) + + return new_model_input + + def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): + """Determines if draft_model_runner GPU multi-step can be used. + Currently required conditions are: + 1. Only decodes + 2. Only flash-attn + 3. No LORA + 4. No prompt_adapter_config """ + if not enable_gpu_advance_step: + return False - # Append the output token to the sequence data. 
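# ---- Illustrative sketch (not part of the diff) ----
# Condensed view of the multi-step decode flow that _gpu_advance_step enables;
# the real execute_model below also handles CUDA graph selection, multi-modal
# kwargs and the num_steps == 1 fallback.
def run_multi_step(runner, model_input, kv_caches, num_steps: int):
    outputs = []
    for step in range(num_steps):
        hidden_states = runner.model(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata)
        logits = runner.model.compute_logits(hidden_states,
                                             model_input.sampling_metadata)
        outputs.append(
            runner.model.sample(
                logits=logits,
                sampling_metadata=model_input.sampling_metadata))
        if step != num_steps - 1:
            # Inputs for the next step are advanced on the GPU; no CPU round
            # trip and no re-run of prepare_model_input.
            model_input = runner._gpu_advance_step(model_input, outputs[-1])
    return outputs
# ---- end sketch ----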
- assert self.cached_seq_group_metadata_list is not None - for seq_group_metadata, sequence_group_outputs in zip( - self.cached_seq_group_metadata_list, last_output.outputs): - seq_group_metadata.is_prompt = False + # We allow multi-step GPU only in decode mode + for seq_group in execute_model_req.seq_group_metadata_list: + if seq_group.is_prompt: + return False - for seq_output in sequence_group_outputs.samples: - seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + # TODO: Add support for other attn backends + if self.attn_backend.get_name() != "flash-attn": + return False - token_id = seq_output.output_token - token_logprob = seq_output.logprobs[token_id] + # TODO: Add support for LORA + if self.lora_config: + return False - seq.append_token_id(token_id, token_logprob.logprob) - seq.update_num_computed_tokens(1) + # TODO: Add soft-tuning prompt adapter support + if self.prompt_adapter_config: + return False - return self.prepare_model_input(self.cached_seq_group_metadata_list) + return True @torch.inference_mode() def execute_model( @@ -125,42 +226,86 @@ def execute_model( intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, ) -> Optional[List[SamplerOutput]]: - # Since we do not broadcast data inside execute_model anymore, - # we need to figure out the best way to support TP > 1 in this - # case, because we will at least need to broadcast the sampled - # tokens to all workers. - if not self.is_driver_worker: - raise ValueError("TP1DraftModelRunner only supports TP=1.") + """Executes num_steps forward passes with advacement of input tensors + on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) + Optimizations used: + 1. Input tensors are updated on the GPU directly + 2. Skips GPU=>CPU serialization of sampler outputs (we don't need + them since we do batch expansion later that uses GPU outputs) + 3. Reuses sampling tensors (since we run only decodes and they have + a repeating sampling logic) + """ - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) + # When num_steps == 1, we execute the fallback here for the GPU + # advance_step, which runs prepare_inputs on CPU and for each spec + # iteration invokes this function only once + # (Look at multi-step-worker code) + is_fallback = num_steps == 1 + if not is_fallback: + # Since we do not broadcast data inside execute_model anymore, + # we need to figure out the best way to support TP > 1 in this + # case, because we will at least need to broadcast the sampled + # tokens to all workers. 
+ if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + else: + model_executable = self.model - virtual_engine = model_input.virtual_engine outputs: List[SamplerOutput] = [] for step in range(num_steps): - # Currently cuda graph is only supported by the decode phase. - assert model_input.attn_metadata is not None - prefill_meta = model_input.attn_metadata.prefill_metadata - decode_meta = model_input.attn_metadata.decode_metadata - if prefill_meta is None and decode_meta.use_cuda_graph: - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[virtual_engine][graph_batch_size]) - else: - model_executable = self.model - multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + # Run model hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -181,8 +326,8 @@ def execute_model( sampling_metadata=model_input.sampling_metadata, )) - # Prepare the inputs for the next step. + # Prepare inputs for the next step if step != num_steps - 1: - model_input = self.update_model_input(model_input, outputs[-1]) + model_input = self._gpu_advance_step(model_input, outputs[-1]) return outputs diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 11e99882e3f0b..91689324557b5 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -67,14 +67,23 @@ def sampler_output( expanded_request, indices_of_seq_with_bonus_tokens =\ self._expand_execute_model_request( execute_model_req, seq_ids_with_bonus_token_in_last_step) + # Run model sample_len times. 
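        # With the GPU multi-step path below, "run sample_len times" becomes a
        # single execute_model call whose per-step inputs are advanced on the
        # GPU; the per-step CPU loop remains only as the fallback branch.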
model_outputs: List[SamplerOutput] = [] - if isinstance(self.model_runner, TP1DraftModelRunner): + if isinstance( + self.model_runner, TP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly expanded_request.num_steps = sample_len model_outputs = self.execute_model( execute_model_req=expanded_request) else: - # TODO: Remove this branch once DraftModelRunner supports TP>1. + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) for _ in range(sample_len): model_output: List[SamplerOutput] = super().execute_model( execute_model_req=expanded_request) @@ -171,7 +180,7 @@ def _filter_model_output( outputs=[ expanded_batch_output.outputs[i] for i in output_indices_to_retain - ], + ] if len(expanded_batch_output.outputs) > 0 else [], sampled_token_probs=( expanded_batch_output. sampled_token_probs[output_indices_to_retain]