diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md
index c1aebaf5b3bbe..fbf41eb10a392 100644
--- a/.buildkite/nightly-benchmarks/README.md
+++ b/.buildkite/nightly-benchmarks/README.md
@@ -34,17 +34,18 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan
Performance benchmark will be triggered when:
- A PR being merged into vllm.
-- Every commit for those PRs with `perf-benchmarks` label.
+- Every commit for those PRs with both the `perf-benchmarks` and `ready` labels.
Nightly benchmark will be triggered when:
-- Every commit for those PRs with `nightly-benchmarks` label.
+- Every commit for those PRs with both the `perf-benchmarks` and `nightly-benchmarks` labels.
## Performance benchmark details
-See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
+
+See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
#### Latency test
@@ -68,7 +69,7 @@ Here is an example of one test inside `latency-tests.json`:
In this example:
- The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`.
-- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
+- The `parameters` attribute controls the command-line arguments used for `benchmark_latency.py`. Note that you should use an underscore `_` instead of a dash `-` when specifying the command-line arguments; `run-performance-benchmarks.sh` converts the underscores to dashes when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command-line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15`
Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly.
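The README hunk above says that keys inside `parameters` use underscores and that `run-performance-benchmarks.sh` turns them into dashed CLI flags for `benchmark_latency.py`. A minimal Python sketch of that mapping (editorial illustration only; the test name and the helper function are hypothetical, and the real script also handles boolean flags and other cases):

```python
# Hypothetical sketch of how a latency-tests.json entry's "parameters" become
# benchmark_latency.py arguments; not the actual conversion code in
# run-performance-benchmarks.sh.
entry = {
    "test_name": "latency_llama8B_tp1",  # hypothetical; must start with "latency_"
    "parameters": {
        "model": "meta-llama/Meta-Llama-3-8B",
        "tensor_parallel_size": 1,
        "load_format": "dummy",
        "num_iters_warmup": 5,
        "num_iters": 15,
    },
}

def to_cli_args(parameters: dict) -> str:
    # Underscores in keys become dashes in flags; values are passed through.
    return " ".join(f"--{key.replace('_', '-')} {value}"
                    for key, value in parameters.items())

print(to_cli_args(entry["parameters"]))
# --model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy
# --num-iters-warmup 5 --num-iters 15
```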
diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 8490c9f1da221..2b70e2da5d87c 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -21,7 +21,7 @@ steps: containers: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT command: - - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: limits: nvidia.com/gpu: 8 diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md similarity index 100% rename from .buildkite/nightly-benchmarks/tests/descriptions.md rename to .buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 534ecf17930e9..f90e464288cf1 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -174,8 +174,8 @@ def results_to_json(latency, throughput, serving): # document the result with open(results_folder / "benchmark_results.md", "w") as f: - results = read_markdown( - "../.buildkite/nightly-benchmarks/tests/descriptions.md") + results = read_markdown("../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md") results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh similarity index 90% rename from .buildkite/nightly-benchmarks/run-benchmarks-suite.sh rename to .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index f6e41fcfdd0be..a0b9a409b758d 100644 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -37,9 +37,9 @@ check_hf_token() { ensure_sharegpt_downloaded() { local FILE=ShareGPT_V3_unfiltered_cleaned_split.json if [ ! -f "$FILE" ]; then - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE else - echo "$FILE already exists." + echo "$FILE already exists." fi } @@ -68,11 +68,29 @@ wait_for_server() { done' && return 0 || return 1 } +kill_processes_launched_by_current_bash() { + # Kill all python processes launched from current bash script + current_shell_pid=$$ + processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}') + if [ -n "$processes" ]; then + echo "Killing the following processes matching '$1':" + echo "$processes" + echo "$processes" | xargs kill -9 + else + echo "No processes found matching '$1'." + fi +} + kill_gpu_processes() { - # kill all processes on GPU. 
- ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 - ps -e | grep pt_main_thread | awk '{print $1}' | xargs kill -9 + ps -aux + lsof -t -i:8000 | xargs -r kill -9 + pkill -f pt_main_thread + # this line doesn't work now + # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 + pkill -f python3 + pkill -f /usr/bin/python3 + # wait until GPU memory usage smaller than 1GB while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do @@ -82,11 +100,6 @@ kill_gpu_processes() { # remove vllm config file rm -rf ~/.config/vllm - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" } upload_to_buildkite() { @@ -104,7 +117,7 @@ upload_to_buildkite() { fi # Use the determined command to annotate and upload artifacts - $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md + $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" } @@ -156,7 +169,7 @@ run_latency_tests() { latency_command: $latency, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$latency_command" @@ -166,7 +179,6 @@ run_latency_tests() { done } - run_throughput_tests() { # run throughput tests using `benchmark_throughput.py` # $1: a json file specifying throughput test cases @@ -214,7 +226,7 @@ run_throughput_tests() { throughput_command: $command, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$throughput_command" @@ -246,7 +258,6 @@ run_serving_tests() { continue fi - # get client and server arguments server_params=$(echo "$params" | jq -r '.server_parameters') client_params=$(echo "$params" | jq -r '.client_parameters') @@ -324,7 +335,7 @@ run_serving_tests() { client_command: $client, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" done @@ -341,6 +352,7 @@ main() { # dependencies (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) + (which lsof) || (apt-get update && apt-get install -y lsof) # get the current IP address, required by benchmark_serving.py export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') @@ -359,7 +371,6 @@ main() { run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json - # postprocess benchmarking results pip install tabulate pandas python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 7babffc62f431..d583610a78655 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -311,6 +311,15 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py +- label: Multi-step Tests (4 GPUs) # 10min + working_dir: 
"/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/ + - tests/multi_step/test_correctness.py + commands: + - pytest -v -s multi_step/test_correctness.py + - label: Pipeline Parallelism Test # 23min working_dir: "/vllm-workspace/tests" num_gpus: 4 diff --git a/collect_env.py b/collect_env.py index 76df97b099b1b..839d54172e775 100644 --- a/collect_env.py +++ b/collect_env.py @@ -66,6 +66,8 @@ "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } DEFAULT_PIP_PATTERNS = { @@ -79,6 +81,8 @@ "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 9a5964ec65b99..6a8d99635b8f0 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -3,6 +3,7 @@ sphinx-book-theme==1.0.1 sphinx-copybutton==0.5.2 myst-parser==2.0.0 sphinx-argparse==0.4.0 +msgspec # packages to install to build the documentation pydantic diff --git a/requirements-test.txt b/requirements-test.txt index 95909d37e2c94..cdbc3e50cc9ec 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -29,4 +29,5 @@ matplotlib # required for qwen-vl test aiohttp # quantization -bitsandbytes==0.42.0 \ No newline at end of file +bitsandbytes==0.42.0 +buildkite-test-collector==0.1.8 \ No newline at end of file diff --git a/setup.py b/setup.py index 9e34433eff0d8..ef599b613667b 100644 --- a/setup.py +++ b/setup.py @@ -279,7 +279,7 @@ def _build_custom_ops() -> bool: def _build_core_ext() -> bool: - return not _is_neuron() and not _is_tpu() and not _is_openvino() + return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu()) def get_hipcc_rocm_version(): diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5fb8ec06cfa03..c2226870c2e83 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -682,6 +682,32 @@ def test_eviction_order(num_blocks: int, block_size: int, seed: int): assert new_block[0].block_id == last_block_id + # Test case for cache mertics + @staticmethod + def test_metric(): + block_size = 16 + allocator = PrefixCachingBlockAllocator(num_blocks=4, + block_size=block_size) + # Test when no query (0/0) + assert allocator.get_prefix_cache_hit_rate() == 0.0 + + token_ids = list(range(block_size)) + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 0/1 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.0 + + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + # Test 1/2 hit rate + assert allocator.get_prefix_cache_hit_rate() == 0.5 + + # Test more than one block + for _ in range(2, 1005): + allocator.allocate_immutable_block(prev_block=None, + token_ids=token_ids) + assert allocator.get_prefix_cache_hit_rate() > 0.99 + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/models/test_gguf.py b/tests/models/test_gguf.py index 5971179f01211..196cd88e039a1 100644 --- a/tests/models/test_gguf.py +++ b/tests/models/test_gguf.py @@ -7,6 +7,7 @@ import pytest from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported @@ -20,7 +21,7 @@ MODELS = [ ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q4_0.gguf")), + filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")), ("TinyLlama/TinyLlama-1.1B-Chat-v1.0", 
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), @@ -39,22 +40,36 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("tp_size", [1, 2]) def test_models( + num_gpus_available, vllm_runner, example_prompts, model, dtype: str, max_tokens: int, num_logprobs: int, + tp_size: int, ) -> None: + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + original_model, gguf_model = model + tokenizer = AutoTokenizer.from_pretrained(original_model) + messages = [[{ + 'role': 'user', + 'content': prompt + }] for prompt in example_prompts] + example_prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + # Run unquantized model. with vllm_runner(model_name=original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, - tensor_parallel_size=1) as original_model: + tensor_parallel_size=tp_size) as original_model: original_outputs = original_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) @@ -63,8 +78,7 @@ def test_models( with vllm_runner(model_name=gguf_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, - enforce_eager=True, - tensor_parallel_size=1) as gguf_model: + tensor_parallel_size=tp_size) as gguf_model: gguf_outputs = gguf_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) diff --git a/tests/multi_step/__init__.py b/tests/multi_step/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/multi_step/test_correctness.py b/tests/multi_step/test_correctness.py new file mode 100644 index 0000000000000..bc14311c66424 --- /dev/null +++ b/tests/multi_step/test_correctness.py @@ -0,0 +1,85 @@ +# Test the AsyncLLMEngine with multi-step-decoding + +from typing import List + +import pytest + +from ..utils import RemoteOpenAIServer + +MODELS = [ + "JackFram/llama-160m", +] +NUM_SCHEDULER_STEPS = [8] # Multi-step decoding steps +NUM_PROMPTS = [10] + +DEFAULT_SERVER_ARGS: List[str] = [ + "--disable-log-requests", + "--use-v2-block-manager", + "--worker-use-ray", + "--gpu-memory-utilization", + "0.85", + "--swap-space", + "16", +] + + +async def completions_with_server_args(prompts: List[str], model_name: str, + server_cli_args: List[str]): + + outputs = None + with RemoteOpenAIServer(model_name, server_cli_args) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5) + assert outputs is not None + + return outputs + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize(("tp_size, pp_size"), [ + (1, 1), + (2, 2), +]) +@pytest.mark.parametrize("eager_mode", [False, True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.asyncio +async def test_multi_step(example_prompts, model: str, tp_size: int, + pp_size: int, eager_mode: int, + num_scheduler_steps: int, num_prompts: int): + + prompts = example_prompts + if len(prompts) < num_prompts: + prompts = prompts * ((num_prompts // len(prompts)) + 1) + prompts = prompts[:num_prompts] + assert len(prompts) == num_prompts + + server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"] + ms_server_args = DEFAULT_SERVER_ARGS + \ + ["--num-scheduler-steps", f"{num_scheduler_steps}"] + 
+ if eager_mode: + ms_server_args.append("--enforce-eager") + + distributed_args = [ + "--tensor-parallel-size", + str(tp_size), + "--pipeline-parallel-size", + str(pp_size), + ] + + ref_completions = await completions_with_server_args( + prompts, model, server_args + distributed_args) + test_completions = await completions_with_server_args( + prompts, model, ms_server_args + distributed_args) + + def get_text_generations(completions): + return [x.text for x in completions.choices] + + ref_generations = get_text_generations(ref_completions) + test_generations = get_text_generations(test_completions) + assert ref_generations == test_generations diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 9821dbd066a59..2dff84b812b89 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -34,6 +34,9 @@ def test_block_allocator( assert (first_block == second_block) assert (second_block.ref_count == 2) + # Check metric: 1 hit of 2 queries + assert block_allocator.get_prefix_cache_hit_rate() == 0.5 + # Free the first_block and confirm that the ref_count is correctly # decremented on the second block block_allocator.free(first_block) @@ -48,6 +51,10 @@ def test_block_allocator( assert (first_block == second_block) assert (first_block.block_hash == block_hash) + # Allocate one more time to get 3/4 hit rate for easy checking + block_allocator.allocate(block_hash, 0) + assert block_allocator.get_prefix_cache_hit_rate() == 0.75 + @pytest.mark.parametrize("num_blocks", [16]) def test_eviction(num_blocks: int, ): diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 2126fafb2323b..a57fdac803e42 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -10,6 +10,7 @@ from vllm.worker.embedding_model_runner import ( ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from vllm.worker.multi_step_model_runner import StatefulModelInput class MockAttentionBackend(AttentionBackend): @@ -154,3 +155,79 @@ def test_embedding_model_runner_input(): None) == getattr(attn_metadata, field.name, None) # Pooling metadata is not broadcast. assert received_model_input.pooling_metadata is None + + +def test_multi_step_model_runner_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + frozen_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + is_last_step=True, + is_first_multi_step=False, + current_step=4, + last_sampled_token_ids=torch.ones((10, 1)), + is_multi_step=True, + num_queries=8, + num_seqs=5, + cached_outputs=[], + ) + + assert isinstance(model_input, StatefulModelInput) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + + receieved_frozen_input = received_model_input.frozen_model_input + + # Check that received copy has correct values. 
+ assert isinstance(received_model_input, StatefulModelInput) + assert receieved_frozen_input.input_tokens is not None + assert (receieved_frozen_input.input_tokens == + frozen_model_input.input_tokens).all() + assert receieved_frozen_input.input_positions is not None + assert (receieved_frozen_input.input_positions == + frozen_model_input.input_positions).all() + assert receieved_frozen_input.multi_modal_kwargs is None + assert (frozen_model_input.multi_modal_kwargs == + frozen_model_input.multi_modal_kwargs) + assert receieved_frozen_input.lora_requests is None + assert (receieved_frozen_input.lora_requests == + frozen_model_input.lora_requests) + assert receieved_frozen_input.lora_mapping is None + assert ( + receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping) + for field in dataclasses.fields(AttentionMetadata): + assert getattr(receieved_frozen_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. + assert (receieved_frozen_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert receieved_frozen_input.sampling_metadata.seq_groups is None + + # check non frozen fields + assert received_model_input.is_last_step == model_input.is_last_step + assert (received_model_input.is_first_multi_step == + model_input.is_first_multi_step) + assert received_model_input.current_step == model_input.current_step + assert (received_model_input.last_sampled_token_ids == + model_input.last_sampled_token_ids).all() + assert received_model_input.is_multi_step == model_input.is_multi_step diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index 1e808e21b72e5..eb190adfbe802 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,4 +1,5 @@ from collections import deque +from dataclasses import dataclass from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from vllm.core.block.interfaces import Block, BlockAllocator @@ -282,6 +283,58 @@ def ids(self) -> List[int]: return self._block_ids +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. 
+ if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + def get_all_blocks_recursively(last_block: Block) -> List[Block]: """Retrieves all the blocks in a sequence starting from the last block. diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 5287cd9c1bfb3..c6330df2a485a 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -323,6 +323,11 @@ def get_common_computed_block_ids( def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: """Returns and clears the mapping of source to destination block IDs. Will be called after every swapping operations for now, and after every diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index ab39832bc1f6e..f26bc761c9967 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -186,6 +186,11 @@ def get_num_blocks_touched(self, num_lookahead_slots: int = 0) -> int: pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class NoFreeBlocksError(ValueError): pass @@ -278,3 +283,8 @@ def allocate_or_get_null_block(self) -> Block: There is at most one null block per allocator. """ pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 14a62c2e7190e..1643fd69c58ab 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -341,6 +341,9 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index e145eeba2d66e..432a6651ab07a 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,9 +1,8 @@ """Token blocks.""" - from os.path import commonprefix from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple -from vllm.core.block.common import (CopyOnWriteTracker, +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import (BlockPool, NaiveBlock, @@ -107,6 +106,8 @@ def __init__( self._cow_tracker = CopyOnWriteTracker( refcounter=self._refcounter.as_readonly()) + self.metric_data = CacheMetricData() + # Implements Block.Factory. def _create_block( self, @@ -155,9 +156,11 @@ def allocate_immutable_block(self, cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: + self.metric_data.query(hit=True) block.block_id = cached_block_id self._incr_refcount_cached_block(block) return block + self.metric_data.query(hit=False) self._block_pool.free_block(block) # No cached block => Allocate a new block @@ -404,6 +407,9 @@ def get_physical_block_id(self, absolute_id: int) -> int: def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + def is_block_cached(self, block: Block) -> bool: assert block.content_hash is not None if block.content_hash in self._cached_blocks: diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index ad26d3c516ff0..0af04399a4b31 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,6 +8,7 @@ from typing import Set, Tuple from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.common import CacheMetricData from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -60,6 +61,11 @@ def contains_block(self, block_hash: int) -> bool: def update_hash(self, block_hash: int, block: PhysicalTokenBlock): pass + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + class CachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. 
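For context on the hit-rate bookkeeping wired in here: the `CacheMetricData` dataclass added in `vllm/core/block/common.py` above aggregates the prefix-cache hit rate at block granularity (one folded-in rate per `block_size` queries) so the counters stay bounded, and the allocators feed it via `query(hit=...)`. A small standalone sketch of that scheme (the class name and the comparison harness are mine, not vLLM's API), checked against a naive running hit rate:

```python
# Editorial sketch of CacheMetricData-style block-granularity accounting.
# block_size=1000 matches the dataclass default above; 4 is used in the demo
# below so the rollover path is exercised quickly.
class MiniCacheMetric:

    def __init__(self, block_size: int = 1000):
        self.block_size = block_size
        self.num_completed_blocks = 0   # nB
        self.completed_hit_rate = 0.0   # HR, averaged over completed blocks
        self.cur_queries = 0            # Q (< block_size)
        self.cur_hits = 0               # H (< block_size)

    def query(self, hit: bool) -> None:
        self.cur_queries += 1
        self.cur_hits += 1 if hit else 0
        if self.cur_queries == self.block_size:
            # Fold the completed block into the running average and reset.
            block_rate = self.cur_hits / self.cur_queries
            self.completed_hit_rate = (
                self.completed_hit_rate * self.num_completed_blocks +
                block_rate) / (self.num_completed_blocks + 1)
            self.num_completed_blocks += 1
            self.cur_queries = self.cur_hits = 0

    def get_hit_rate(self) -> float:
        # hit rate = (HR * nB + H / BS) / (nB + Q / BS), as in the docstring.
        partial = self.cur_queries / self.block_size
        total = self.num_completed_blocks + partial
        if total == 0:
            return 0.0
        completed = self.completed_hit_rate * self.num_completed_blocks
        current = self.cur_hits / self.block_size
        return (completed + current) / total


metric, hits, queries = MiniCacheMetric(block_size=4), 0, 0
for i in range(10):
    hit = (i % 2 == 0)
    metric.query(hit)
    hits, queries = hits + hit, queries + 1
    # The block-granularity estimate matches an exact running hit rate here.
    assert abs(metric.get_hit_rate() - hits / queries) < 1e-9
```

The exact-match property in the final assertion holds because every completed block contains exactly `block_size` queries, which is how the rollover in `query()` is defined; the payoff of the scheme is that the stored state stays small no matter how many queries are observed.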
@@ -85,6 +91,8 @@ def __init__(self, self.default_hash_ctr = count() + self.cache_metric_data = CacheMetricData() + def allocate_block(self, block_hash: int, num_hashed_tokens: int) -> PhysicalTokenBlock: if self.current_num_blocks == self.num_blocks: @@ -105,15 +113,17 @@ def allocate(self, num_hashed_tokens: int = 0) -> PhysicalTokenBlock: if block_hash is None: block_hash = next(self.default_hash_ctr) + if block_hash in self.evictor: assert block_hash not in self.cached_blocks block = self.evictor.remove(block_hash) assert block.ref_count == 0 self.cached_blocks[block_hash] = block - block.ref_count += 1 - assert block.block_hash == block_hash - return block - if block_hash not in self.cached_blocks: + + if block_hash in self.cached_blocks: + self.cache_metric_data.query(hit=True) + else: + self.cache_metric_data.query(hit=False) self.cached_blocks[block_hash] = self.allocate_block( block_hash, num_hashed_tokens) block = self.cached_blocks[block_hash] @@ -150,6 +160,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): del self.cached_blocks[old_hash] self.cached_blocks[block_hash] = block + def get_prefix_cache_hit_rate(self) -> float: + return self.cache_metric_data.get_hit_rate() + class UncachedBlockAllocator(BlockAllocatorBase): """Manages free physical token blocks for a device. @@ -209,6 +222,9 @@ def update_hash(self, block_hash: int, block: PhysicalTokenBlock): raise NotImplementedError( "Invalid codepath for uncached block allocator.") + def get_prefix_cache_hit_rate(self) -> float: + return -1 + class BlockSpaceManagerV1(BlockSpaceManager): """Manages the mapping between logical and physical token blocks.""" @@ -705,3 +721,10 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup): if self.enable_caching: for seq in seq_group.get_seqs(): self.compute_full_blocks_in_seq(seq) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + if device == Device.GPU: + return self.gpu_allocator.get_prefix_cache_hit_rate() + if device == Device.CPU: + return self.cpu_allocator.get_prefix_cache_hit_rate() + raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b48ea1b19b82a..b7d9451f18067 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -441,6 +441,9 @@ def get_num_free_gpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int: return self.block_allocator.get_num_free_blocks(Device.CPU) + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + def _can_swap(self, seq_group: SequenceGroup, device: Device, diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f2d67306d7ceb..3d864a73f91d0 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -2,6 +2,7 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device class EmbeddingModelBlockSpaceManager(BlockSpaceManager): @@ -81,3 +82,6 @@ def get_common_computed_block_ids(self, def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py index 5b1a208b7c866..0b943e6e65f1c 100644 --- a/vllm/core/evictor_v2.py +++ b/vllm/core/evictor_v2.py @@ -85,19 +85,21 @@ def evict(self) -> Tuple[int, int]: if 
len(self.free_table) == 0: raise ValueError("No usable cache memory left") - evicted_block = next(iter(self.free_table.values())) - evicted_block_id = next(iter(self.free_table.keys())) + evicted_block, evicted_block_id = None, None # The blocks with the lowest timestamps should be placed consecutively # at the start of OrderedDict. Loop through all these blocks to # find the one with maximum number of hashed tokens. for _id, block in self.free_table.items(): + if evicted_block is None: + evicted_block, evicted_block_id = block, _id + continue if evicted_block.last_accessed < block.last_accessed: break - if (evicted_block.last_accessed == block.last_accessed and - evicted_block.num_hashed_tokens < block.num_hashed_tokens): - evicted_block = block - evicted_block_id = _id + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: + evicted_block, evicted_block_id = block, _id + assert evicted_block is not None + assert evicted_block_id is not None self.free_table.pop(evicted_block_id) return evicted_block_id, evicted_block.content_hash @@ -110,7 +112,6 @@ def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, def update(self, block_id: int, last_accessed: float): self.free_table[block_id].last_accessed = last_accessed - self.free_table.move_to_end(block_id) def remove(self, block_id: int): if block_id not in self.free_table: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 8759ee06795b8..becd0d2e7f849 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -5,6 +5,7 @@ from typing import Tuple from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device class AllocStatus(enum.Enum): @@ -116,3 +117,8 @@ def get_common_computed_block_ids( @abstractmethod def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 802359d2283f7..3b716e32032c1 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -14,7 +14,7 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStatus) -from vllm.utils import PyObjectCache +from vllm.utils import Device, PyObjectCache logger = init_logger(__name__) @@ -447,6 +447,9 @@ def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + def get_num_unfinished_seq_groups(self) -> int: return len(self.waiting) + len(self.running) + len(self.swapped) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8fca2cc049958..7f45c3d06375a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, Union) +import torch + import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -113,7 +115,7 @@ class EngineArgs: fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None - lora_dtype: str = 'auto' + lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 @@ -662,8 +664,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--disable-logprobs-during-spec-decoding', - type=bool, + action=StoreBoolean, default=EngineArgs.disable_logprobs_during_spec_decoding, + nargs="?", + const="True", help='If set to True, token log probabilities are not returned ' 'during speculative decoding. If set to False, log probabilities ' 'are returned according to the settings in SamplingParams. If ' @@ -851,6 +855,12 @@ def create_engine_config(self, ) -> EngineConfig: "in low performance due to small KV cache space. 
Consider " "setting --max-model-len to a smaller value.", max_model_len) + if self.num_scheduler_steps > 1 and not self.use_v2_block_manager: + self.use_v2_block_manager = True + logger.warning( + "Enabled BlockSpaceManagerV2 because it is " + "required for multi-step (--num-scheduler-steps > 1)") + speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -879,7 +889,6 @@ def create_engine_config(self, ) -> EngineConfig: ) if self.num_scheduler_steps > 1: - raise NotImplementedError("Multi-step is not yet supported.") if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index dced804fccca9..6385d3ca2297e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,9 +1,11 @@ import asyncio import time +from dataclasses import dataclass from functools import partial from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +import torch from transformers import PreTrainedTokenizer from typing_extensions import assert_never @@ -27,7 +29,8 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -249,9 +252,25 @@ def has_new_requests(self): return not self._new_requests.empty() +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + last_output: Optional[SamplerOutput] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pipeline_parallel_size = \ + self.parallel_config.pipeline_parallel_size + self.cached_scheduler_outputs = [ + SchedulerOutputState() for _ in range(pipeline_parallel_size) + ] + async def step_async( self, virtual_engine: int ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: @@ -264,13 +283,39 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. 
+ if not self._has_remaining_steps(seq_group_metadata_list): + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + virtual_engine].schedule() + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - # Execute the model. finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, @@ -279,15 +324,35 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, output) else: output = [] - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + else: + request_outputs = [] # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -297,6 +362,60 @@ async def step_async( return request_outputs + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if (not self.scheduler_config.is_multi_step + or not seq_group_metadata_list): + return False + + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. 
+ ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs) -> None: + self.cached_scheduler_outputs[ + virtual_engine].seq_group_metadata_list = seq_group_metadata_list + self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + scheduler_outputs + self.cached_scheduler_outputs[virtual_engine].last_output = None + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fcf45a38b9425..36cb6ce795f3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -47,7 +47,7 @@ AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) -from vllm.utils import Counter +from vllm.utils import Counter, Device from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -1390,6 +1390,13 @@ def _get_stats( for scheduler in self.scheduler) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + # Iteration stats num_prompt_tokens_iter = 0 num_generation_tokens_iter = 0 @@ -1498,6 +1505,9 @@ def _get_stats( # KV Cache Usage in % gpu_cache_usage_sys=gpu_cache_usage_sys, cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, # Iteration stats num_prompt_tokens_iter=num_prompt_tokens_iter, diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 1071786c27cd6..74277cae7c8ef 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -71,6 +71,17 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="CPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames, multiprocess_mode="sum") + # Prefix caching block hit rate + self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls( + name="vllm:cpu_prefix_cache_hit_rate", + documentation="CPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls( + name="vllm:gpu_prefix_cache_hit_rate", + documentation="GPU prefix cache block hit rate.", + labelnames=labelnames, + multiprocess_mode="sum") # Iteration stats self.counter_num_preemption = self._counter_cls( @@ -351,7 +362,13 @@ def log(self, stats: Stats) -> None: stats.gpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100, ) - + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + logger.info( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) if self.spec_decode_metrics is not None: logger.info( self._format_spec_decode_metrics_str( @@ -423,6 +440,10 @@ def _log_prometheus(self, stats: Stats) -> None: stats.gpu_cache_usage_sys) self._log_gauge(self.metrics.gauge_cpu_cache_usage, stats.cpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate, + stats.cpu_prefix_cache_hit_rate) + self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, + stats.gpu_prefix_cache_hit_rate) # Iteration level data self._log_counter(self.metrics.counter_num_preemption, diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 7449aafc5aecb..1eccb23593408 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -32,6 +32,9 @@ class Stats: # KV Cache Usage in % gpu_cache_usage_sys: float cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float # Iteration stats (should have _iter suffix) num_prompt_tokens_iter: int diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 55976f430254c..7d40607e81791 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -69,13 +69,19 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if self.speculative_config is None: - worker_kwargs.update(worker_module_name="vllm.worker.worker", - worker_class_name="Worker") - else: + + if self.scheduler_config.is_multi_step: + worker_kwargs.update( + worker_module_name="vllm.worker.multi_step_worker", + worker_class_name="MultiStepWorker") + elif self.speculative_config: worker_kwargs.update( worker_module_name="vllm.spec_decode.spec_decode_worker", worker_class_name="create_spec_worker") + else: + worker_kwargs.update(worker_module_name="vllm.worker.worker", + worker_class_name="Worker") + return worker_kwargs def _create_worker(self, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 3a08ab4dbfd44..4c38cd1cbd546 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -94,6 +94,9 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: if self.speculative_config is not None: worker_module_name = "vllm.spec_decode.spec_decode_worker" worker_class_name = "create_spec_worker" + elif self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" else: worker_module_name = "vllm.worker.worker" 
worker_class_name = "Worker" diff --git a/vllm/logger.py b/vllm/logger.py index 3c6bf0803a624..77dddbfb60965 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -43,6 +43,7 @@ }, }, "version": 1, + "disable_existing_loggers": False } diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b4cc6daa3c41e..3824ed3570aeb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -507,11 +507,16 @@ def weight_loader(self, loaded_shard_id if is_gguf_weight: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * \ - loaded_shard_id + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) param_data = param_data.narrow(output_dim, shard_offset, shard_size) @@ -863,8 +868,13 @@ def weight_loader(self, param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size param.shard_id.append(loaded_shard_id) - param.shard_size[loaded_shard_id] = loaded_weight.shape + param.shard_size[loaded_shard_id] = shard_shape + input_dim = getattr(param, "input_dim", None) input_size = loaded_weight.shape[input_dim] param_data = param_data.narrow(input_dim, 0, input_size) @@ -976,6 +986,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) # Special case for GGUF @@ -986,7 +997,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Materialize GGUF UninitializedParameter if is_gguf_weight and isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None: diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index a4e0a4d509608..a6a1ed5b0dee5 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -5,7 +5,6 @@ from torch.nn.parameter import Parameter, UninitializedParameter from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -39,9 +38,6 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": - if get_tensor_model_parallel_world_size() > 1: - raise ValueError( - "GGUF quantization hasn't supported tensor parallelism yet.") return cls() def get_quant_method(self, 
layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index e6ee2b967c8da..0562b71aa7493 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -93,11 +93,6 @@ def __init__( def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" - # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. - # However, we use `torch.arange(..., dtype=torch.float)` instead to - # avoid numerical issues with large base values (e.g., 10000000). - # This may cause a slight numerical difference between the HF - # implementation and ours. # NOTE(woosuk): To exactly match the HF implementation, we need to # use CPU to compute the cache and then move it to GPU. However, we # create the cache on GPU for faster initialization. This may cause @@ -724,16 +719,6 @@ def forward( return query, key -class GemmaRotaryEmbedding(RotaryEmbedding): - - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 - inv_freq = 1.0 / (base**( - torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() / - self.rotary_dim)) - return inv_freq - - class Llama3RotaryEmbedding(RotaryEmbedding): def __init__( diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 74e534aa76a9d..28f69cfbc46bd 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -414,6 +414,8 @@ def __init__(self, config.hidden_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.unpadded_vocab_size = config.vocab_size diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index a11c7663263c6..73711d8eb5185 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -331,6 +331,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index ef988532ce126..f78400b0df7b3 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -821,6 +821,8 @@ def __init__(self, lora_config: Optional[LoRAConfig] = None): super().__init__() + # currently all existing BART models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.config = config self.model = BartModel(config, cache_config, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 8cfd3c2672568..20dda2a67820d 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -494,6 +494,9 @@ def __init__(self, super().__init__() + # currently all existing BLIP-2 models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 282a0f84eacb1..07ee0e3c531d0 100644 --- 
a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -276,7 +276,12 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = BloomModel(config, cache_config, quant_config) - self.lm_head = self.transformer.word_embeddings + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index b29ebe2f59e7b..4949d0232fabb 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -356,6 +356,9 @@ def __init__( self.max_position_embeddings = getattr(config, "max_sequence_length", 8192) self.transformer = ChatGLMModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.transformer.output_layer.weight = ( + self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 0894f750e5fbf..f63cf246e510a 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -321,6 +321,9 @@ def __init__( ) -> None: super().__init__() self.config = config + # currently all existing command R models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 7ebeca1a359ef..dca959798e8b2 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -362,6 +362,9 @@ def __init__( ): super().__init__() self.config = config + if config.tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models.") self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size self.transformer = DbrxModel(config, cache_config, quant_config) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index f10977ed2c90d..7a27e1388e987 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -380,6 +380,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 14d1578863e5e..e1041edf81b0a 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -148,14 +148,12 @@ def __init__(self, quant_config=quant_config, ) - # TODO(woosuk): Use the `get_rope` interface. - self.rotary_emb = GemmaRotaryEmbedding( + self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim, - max_position_embeddings=max_position_embeddings, + max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), ) self.attn = Attention(self.num_heads, self.head_dim, @@ -333,6 +331,8 @@ def __init__( super().__init__() self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index aa9cff02283c0..5e0f8b70d4b80 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -32,7 +32,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -130,14 +130,12 @@ def __init__(self, bias=config.attention_bias, quant_config=quant_config, ) - # TODO(woosuk): Use the `get_rope` interface. - self.rotary_emb = GemmaRotaryEmbedding( + self.rotary_emb = get_rope( self.head_dim, - self.head_dim, - max_position_embeddings, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), ) # FIXME(woosuk): While Gemma 2 uses sliding window attention for every @@ -325,6 +323,8 @@ def __init__( del lora_config # Unused. 
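The `tie_word_embeddings` handling added across the model files above follows one pattern: when the checkpoint ties its input and output embeddings, the LM head reuses the embedding matrix instead of allocating a second one; otherwise a separate `ParallelLMHead` is created. A minimal plain-PyTorch sketch of that idea (the `TinyCausalLM` class is illustrative only, not a vLLM API):

```python
import torch
import torch.nn as nn


class TinyCausalLM(nn.Module):
    """Toy model showing an LM head that can share weights with the embedding."""

    def __init__(self, vocab_size: int, hidden_size: int,
                 tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie_word_embeddings:
            # Share the parameter object: one matrix serves as both the
            # token embedding and the output projection.
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.embed_tokens(token_ids)  # stand-in for the decoder stack
        return self.lm_head(hidden)            # logits over the vocabulary


model = TinyCausalLM(vocab_size=128, hidden_size=16, tie_word_embeddings=True)
assert model.lm_head.weight is model.embed_tokens.weight
print(model(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 128])
```

Because the two modules share one `Parameter`, the checkpoint stores a single matrix and any update to the embeddings is immediately reflected in the logits projection.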
super().__init__() self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor( diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 4f2fe0c42a3ff..bfc231282952a 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -249,7 +249,11 @@ def __init__( cache_config, quant_config, prefix="transformer") - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b30af3599aa4d..b93fb8d69b2d7 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -259,7 +259,13 @@ def __init__( self.quant_config = quant_config self.transformer = GPTBigCodeModel(config, cache_config, quant_config, lora_config) - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead( + self.transformer.vocab_size, + self.transformer.embed_dim, + org_num_embeddings=self.config.vocab_size) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index e61b4448981e8..2adecf7fa9ef8 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -230,7 +230,7 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, - config, + config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -243,6 +243,8 @@ def __init__( config.hidden_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 216458465513a..887a353df972c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -264,6 +264,8 @@ def __init__( self.output = ParallelLMHead(config.vocab_size, config.hidden_size, 
quant_config=quant_config) + if self.config.tie_word_embeddings: + self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index ec6bea920cc3a..a550f7e6c97a1 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -37,7 +37,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -291,7 +291,11 @@ def __init__( self.config = config self.quant_config = quant_config self.transformer = JAISModel(config, cache_config, quant_config) - self.lm_head = self.transformer.wte + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale else: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 46db364895b13..6433ea380cbfe 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -313,7 +313,7 @@ def forward( 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends additional image tokens (denoted as `32000`), resulting in: `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, @@ -331,7 +331,7 @@ def forward( input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values: The pixels in each input image. - + See also: :class:`LlavaImageInputs` """ diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c1277359182e4..c7cb243fa84da 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -545,7 +545,7 @@ def forward( 9047, 13566, 29901]`. To reserve space in KV cache, we have to insert placeholder tokens - before they are inputted to the model, so the input processor prepends + before they are inputted to the model, so the input processor prepends additional image tokens (denoted as `32000`), resulting in: `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, @@ -566,7 +566,7 @@ def forward( batch. pixel_values: The pixels in each grid patch for each input image. image_sizes: The original `(height, width)` for each input image. 
- + See also: :class:`LlavaNextImageInputs` """ diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 729bd27c334d5..99a3c5dab39e4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -496,6 +496,10 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ): super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # and config class self.config = config self.multimodal_config = multimodal_config diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 587d2f26a2d5e..34f581ac78582 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -359,6 +359,8 @@ def __init__( if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 812dce5d04771..8bdd52b343175 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -347,6 +347,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index b05f799e4dd2b..c0d2d537e731f 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -36,7 +36,7 @@ QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput @@ -307,7 +307,11 @@ def __init__( self.config = config self.quant_config = quant_config self.model = OPTModel(config, cache_config, quant_config) - self.lm_head = self.model.decoder.embed_tokens + if self.config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 6923e11e288be..fab35f0b882a7 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -262,6 +262,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 54f4dd2fcde0a..f31b5162aac96 100644 --- a/vllm/model_executor/models/phi.py +++ 
b/vllm/model_executor/models/phi.py @@ -260,6 +260,8 @@ def __init__( super().__init__() self.config = config + # lm_head use bias, cannot share word embeddings + assert not config.tie_word_embeddings self.lora_config = lora_config self.quant_config = quant_config diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 98e344d483e29..df01bfa3d8e6e 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -368,6 +368,8 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -449,4 +451,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1c8bb8a837c86..328f4e6fa827c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -477,6 +477,8 @@ def __init__(self, self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a7485bcb489a0..b7d017d5f3ea6 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -252,6 +252,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index e160c9a320820..6f838947fbf27 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -385,6 +385,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c98226d61a8a0..decbf89d27c7c 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -243,6 +243,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index e9bf67d314d0a..c0bafa9367e43 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -313,6 +313,8 @@ def __init__( self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + if self.config.tie_word_embeddings: + 
self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 99ba940e5d2ab..958f6c516a2f8 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -1,23 +1,54 @@ -import torch - from .interface import Platform, PlatformEnum, UnspecifiedPlatform current_platform: Platform +# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because +# they only indicate the build configuration, not the runtime environment. +# For example, people can install a cuda build of pytorch but run on tpu. + +is_tpu = False +try: + import torch_xla.core.xla_model as xm + xm.xla_device(devkind="TPU") + is_tpu = True +except Exception: + pass + +is_cuda = False + +try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + is_cuda = True + finally: + pynvml.nvmlShutdown() +except Exception: + pass + +is_rocm = False + try: - import libtpu -except ImportError: - libtpu = None + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + finally: + amdsmi.amdsmi_shut_down() +except Exception: + pass -if libtpu is not None: +if is_tpu: # people might install pytorch built with cuda but run on tpu # so we need to check tpu first from .tpu import TpuPlatform current_platform = TpuPlatform() -elif torch.version.cuda is not None: +elif is_cuda: from .cuda import CudaPlatform current_platform = CudaPlatform() -elif torch.version.hip is not None: +elif is_rocm: from .rocm import RocmPlatform current_platform = RocmPlatform() else: diff --git a/vllm/sequence.py b/vllm/sequence.py index f6c4a5a50ffc0..2fe8ae9d7b270 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,7 +9,6 @@ Tuple, Union, cast) import msgspec -import numpy import torch from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs @@ -1082,7 +1081,10 @@ class SamplerOutput( # On-device tensor containing the sampled token ids. sampled_token_ids: Optional[torch.Tensor] = None - sampled_token_ids_numpy: Optional[numpy.ndarray] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None # Spec decode metrics populated by workers. 
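The rewritten `vllm/platforms/__init__.py` above stops inferring the platform from `torch.version.cuda` / `torch.version.hip`, which only describe the build, and instead probes the runtime: TPU first via `torch_xla`, then NVIDIA GPUs via `pynvml`, then AMD GPUs via `amdsmi`. A condensed sketch of that probing order (the `detect_platform` helper is illustrative; the module itself assigns `current_platform` directly):

```python
def detect_platform() -> str:
    """Probe the machine (not the torch build) for an accelerator, TPU first."""
    try:
        import torch_xla.core.xla_model as xm
        xm.xla_device(devkind="TPU")  # raises if no TPU is attached
        return "tpu"
    except Exception:
        pass

    try:
        import pynvml
        pynvml.nvmlInit()
        try:
            if pynvml.nvmlDeviceGetCount() > 0:
                return "cuda"
        finally:
            pynvml.nvmlShutdown()
    except Exception:
        pass

    try:
        import amdsmi
        amdsmi.amdsmi_init()
        try:
            if len(amdsmi.amdsmi_get_processor_handles()) > 0:
                return "rocm"
        finally:
            amdsmi.amdsmi_shut_down()
    except Exception:
        pass

    return "unspecified"


print(detect_platform())
```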
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None @@ -1308,9 +1310,7 @@ def is_last_step(self) -> bool: assert len(self.seq_group_metadata_list) > 0 first_seq_group = self.seq_group_metadata_list[0] assert first_seq_group.state is not None - num_steps = first_seq_group.state.num_steps - current_step = first_seq_group.state.current_step - return num_steps - current_step == 1 + return first_seq_group.state.remaining_steps == 1 @property def current_step(self) -> int: diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 46ac16b504bf4..90c39407d7266 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -14,7 +14,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelRunnerInputBase") +T = TypeVar('T', bound="BroadcastableModelInput") def _add_attn_metadata_broadcastable_dict( @@ -81,18 +81,26 @@ def _add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(ABC): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: """ + Helper method to initialize a frozen ModelInput based on broadcastable + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + +class BroadcastableModelInput(ABC): + + @abstractmethod def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some @@ -109,11 +117,25 @@ def from_broadcasted_tensor_dict( ) -> T: """ Pop fields from the given tensor_dict and populate a new instance of - ModelRunnerInputBase. + BroadcastableModelInput. """ raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + class ModelRunnerInputBuilderBase(ABC, Generic[T]): """A builder to create ModelRunnerInputBase objects. 
""" diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 0000000000000..521205eca05af --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,453 @@ +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata +except ModuleNotFoundError: + # vllm_flash_attn is not installed, use the identical ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) + +import torch + +from vllm import _custom_ops as ops +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceOutput) +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + _init_frozen_model_input_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be + ready. + """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + + def pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output(input_metadata, copy_stream, + pinned_sampled_token_buffer, True) + self.pythonized = True + + def maybe_pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output( + input_metadata, copy_stream, pinned_sampled_token_buffer, + False) + + def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool) -> bool: + """ + If blocking is set, will block until the forward pass for the output is + ready and pythonize the output. 
+ """ + assert self.sampled_token_ids is not None + if not blocking and not self.sampler_output_ready_event.query(): + return False + + if blocking: + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + return True + + +@dataclass(frozen=False) +class StatefulModelInput(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + + # list of model outputs for each step, may not be all pythonized + cached_outputs: List[ModelOutput] = field(default_factory=list) + + # used to pass sampled token ids from the last step to the current step for + # TP workers. Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + # ping-pong data structures for multi-step to wait on the previous step + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "StatefulModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. + self.step_cuda_events[self.current_step & 1] = \ + torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + def wait_previous_step(self): + # These cuda events are an explicit synchronization to ensure that + # advance_step() (for other attn backends that may be supported in the + # future) do not clobber any data structures that is also used by any + # enqueued forwards steps. 
For distributed case, only a single event is
+        # needed, but for single GPU case, since we can let the CPU run much
+        # further ahead, two events allow us to overlap the advance_step with
+        # the previous forward (i.e. using two DecodeWrappers for flashinfer
+        # backend)
+        self.step_cuda_events[(self.current_step + 1) & 1].wait()
+
+    def add_sampler_output(self,
+                           sampler_output: SamplerOutput,
+                           sampled_token_ids: Optional[torch.Tensor] = None):
+        self.cached_outputs.append(
+            ModelOutput(sampler_output=sampler_output,
+                        sampler_output_ready_event=None,
+                        sampled_token_ids=sampled_token_ids,
+                        pythonized=False))
+
+
+# MutableModelInputForGPUWithMultiStepMetadata is not subclass of
+# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step
+# metadata
+# mypy: disable-error-code=type-var
+class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
+    # mypy: enable-error-code=type-var
+
+    def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # uses the base model runner to execute the model and wraps it with
+        # multi-step logic
+        self._base_model_runner: GPUModelRunnerBase = base_model_runner
+
+        self.is_multi_step = self.scheduler_config.is_multi_step
+        # used to copy tensors from GPU to CPU asynchronously
+        self._copy_stream = torch.cuda.Stream()
+        self.pinned_sampled_token_ids: Optional[torch.Tensor] = None
+
+    def make_model_input_from_broadcasted_tensor_dict(
+            self, tensor_dict: Dict[str, Any]) -> StatefulModelInput:
+        model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
+            tensor_dict,
+            attn_backend=self.attn_backend,
+        ))
+        return model_input
+
+    def prepare_model_input(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
+        finished_requests_ids: Optional[List[str]] = None
+    ) -> StatefulModelInput:
+        frozen_model_input = self._base_model_runner.prepare_model_input(
+            seq_group_metadata_list, virtual_engine, finished_requests_ids)
+
+        model_input = StatefulModelInput(
+            frozen_model_input=frozen_model_input,
+            num_seqs=len(frozen_model_input.seq_lens),
+            num_queries=len(frozen_model_input.query_lens),
+        )
+        return model_input
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: StatefulModelInput,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+        """
+        Execute the model for a single step and update multi-step
+        metadata
+        """
+        assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1"
+        frozen_model_input = model_input.frozen_model_input
+        assert frozen_model_input is not None
+
+        # path for warm up runs
+        if not model_input.is_multi_step:
+            return self._base_model_runner.execute_model(
+                frozen_model_input, kv_caches, intermediate_tensors, num_steps)
+
+        # make sure we skip the sampler on the last rank and only pythonize
+        # if CPU is ahead.
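For reference, the two-event ping-pong that `record_step_event` / `wait_previous_step` implement above can be exercised in isolation; a self-contained sketch, guarded so it only touches CUDA when a device is present (the `run_steps` helper is illustrative):

```python
import torch


def run_steps(num_steps: int = 4) -> None:
    """Two CUDA events in a ping-pong buffer: step N waits on step N - 1."""
    if not torch.cuda.is_available():
        print("CUDA not available; skipping the event demo.")
        return

    stream = torch.cuda.current_stream()
    step_events = [torch.cuda.Event(blocking=True) for _ in range(2)]
    for step in range(num_steps):
        if step > 0:
            # (step + 1) & 1 == (step - 1) & 1, so this waits on the event
            # recorded by the previous step while leaving the other slot free.
            step_events[(step + 1) & 1].wait()
        # ... enqueue this step's GPU work on `stream` here ...
        step_events[step & 1] = torch.cuda.Event(blocking=True)
        step_events[step & 1].record(stream)


run_steps()
```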
+ if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of GPU + + current_stream = torch.cuda.current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use. + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs extra CPU operations on + # input metadata we may need to synchronize any CPU operations that + # might clobber enqueued forwards. (prevents CPU from running too + # far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.cached_outputs[-1].sampler_output) + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(current_stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. May be able to be combined with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(current_stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.cached_outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False)) + # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is + # ready. 
+ for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = [] + for output in model_input.cached_outputs: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + return outputs + + # should be [SamplerOutput] + return output + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step(self, model_input: StatefulModelInput, + out: SamplerOutput) -> StatefulModelInput: + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + assert num_seqs > 0 + assert num_queries > 0 + assert num_seqs >= num_queries + + attn_metadata = frozen_model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + attn_metadata.advance_step(num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=frozen_model_input.input_tokens, + sampled_token_ids=model_input.cached_outputs[-1].sampled_token_ids, + input_positions=frozen_model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + if frozen_model_input.seq_lens is not None: + for i in range(num_queries): + frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + + return model_input + + def load_model(self) -> None: + return self._base_model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +def _pythonize_sampler_output(model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: 
torch.Tensor, + sampled_token_ids: torch.Tensor) -> None: + """ This function is only called when the output tensors are ready. + See ModelOutput + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + sampling_metadata = frozen_model_input.sampling_metadata + + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + samples_list): + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + seq_outputs: List[SequenceOutput] = [] + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + for parent_id, next_token_id in zip(parent_ids, next_token_ids): + # TODO(will): support logprobs + # Hard coded logprob + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + {next_token_id: Logprob(logprob=-1)})) + output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000000000..6a6caba9371eb --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,189 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, + StatefulModelInput) +from vllm.worker.worker import Worker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: StatefulModelInput + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelRunner( + base_model_runner, + base_model_runner.model_config, + base_model_runner.parallel_config, + base_model_runner.scheduler_config, + base_model_runner.device_config, + base_model_runner.cache_config, + load_config=base_model_runner.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + prompt_adapter_config=base_model_runner.prompt_adapter_config, + observability_config=base_model_runner.observability_config, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput]: + """ + Get the driver input and broadcast it to other workers. 
+ """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: StatefulModelInput = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached decode metadata so that it can be recomputed on + # the workers + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + return model_input, worker_input + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: StatefulModelInput, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.cached_outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.cached_outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.cached_outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. + # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. 
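The driver-side flow above reduces to: build the worker and model inputs once on the first step of a multi-step sequence, cache them per virtual engine, and on later steps feed back only the previous step's sampled token ids. A stripped-down sketch of that caching shape, with the engine objects stubbed out as plain dicts (`StepState` and `TinyMultiStepDriver` are illustrative names, not vLLM APIs):

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class StepState:
    """Stand-in for the cached WorkerInput / StatefulModelInput pair."""
    worker_input: dict
    model_input: dict
    last_sampled_token_ids: Optional[List[int]] = None


class TinyMultiStepDriver:

    def __init__(self, pipeline_parallel_size: int = 1):
        # One cached state per virtual engine, as in MultiStepWorker.
        self.states: Dict[int, Optional[StepState]] = {
            ve: None
            for ve in range(pipeline_parallel_size)
        }

    def prepare(self,
                virtual_engine: int,
                is_first_multi_step: bool,
                sampled_token_ids: Optional[List[int]] = None) -> StepState:
        if is_first_multi_step:
            # Full input preparation happens once per multi-step sequence.
            state = StepState(worker_input={"ve": virtual_engine},
                              model_input={"step": 0})
            self.states[virtual_engine] = state
        else:
            # Later steps reuse the cache; only the sampled ids change.
            state = self.states[virtual_engine]
            assert state is not None
            state.last_sampled_token_ids = sampled_token_ids
        return state


driver = TinyMultiStepDriver()
first = driver.prepare(0, is_first_multi_step=True)
second = driver.prepare(0, is_first_multi_step=False, sampled_token_ids=[42])
assert first is second and second.last_sampled_token_ids == [42]
```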
+ for output in model_input.cached_outputs[:-1]: + output.sampled_token_ids = None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]: + """ + Depending on the current state of the request and multi step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead used cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + model_input, worker_input = self._get_driver_input_and_broadcast( + execute_model_req) + assert isinstance(model_input, StatefulModelInput) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input = broadcast_data + assert isinstance(model_input, StatefulModelInput) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # TODO(will) Can cache the worker input and model input for the + # next steps. See below for details + else: + # TODO(will) possible to also cache and reuse the cached worker + # input and model input. The idea is essentially the delta + # optimization for model_inputs. Where the TP workers can cache + # the model input states and we only broadcast the delta need + # for the next step (sampled_token_ids from the previous step) + + assert isinstance(model_input, StatefulModelInput) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run inplace + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 14f14e40b4c0b..01daa64b5a32f 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -516,27 +516,19 @@ def execute_model( raise ValueError( "TPUModelRunner does not support multi-step execution.") - def _execute_model(*args, clone: bool = False) -> torch.Tensor: + def _execute_model(*args): """Move input args from CPU to device and execute the model.""" - def _copy_to_device(x: torch.Tensor) -> torch.Tensor: - if clone: - # When x is a slice of a CPU tensor, XLA may copy the whole - # original tensor to TPU instead of only copying x. - # To avoid this, we copy x after cloning. 
- x = x.clone() - return x.to(self.device) - new_args = [] for arg in args: if isinstance(arg, torch.Tensor): - arg = _copy_to_device(arg) + arg = arg.to(self.device) elif isinstance(arg, AttentionMetadata): - arg.slot_mapping = _copy_to_device(arg.slot_mapping) + arg.slot_mapping = arg.slot_mapping.to(self.device) if getattr(arg, "block_tables", None) is not None: - arg.block_tables = _copy_to_device(arg.block_tables) + arg.block_tables = arg.block_tables.to(self.device) if getattr(arg, "context_lens", None) is not None: - arg.context_lens = _copy_to_device(arg.context_lens) + arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) return self.model(*new_args) @@ -563,13 +555,9 @@ def _copy_to_device(x: torch.Tensor) -> torch.Tensor: output_token_ids = _execute_model( model_input.token_ids[None, start_idx:end_idx], model_input.position_ids[None, start_idx:end_idx], - model_input.attn_metadata, - model_input.input_lens[i:i + 1], - model_input.t[i:i + 1], - model_input.p[i:i + 1], - model_input.num_samples, - kv_caches, - clone=True) + model_input.attn_metadata, model_input.input_lens[i:i + 1], + model_input.t[i:i + 1], model_input.p[i:i + 1], + model_input.num_samples, kv_caches) # Retrieve the outputs to CPU. next_token_ids += output_token_ids.cpu().tolist() start_idx = end_idx diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c9d0375321d14..d447a3de29169 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -16,7 +16,9 @@ SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) logger = init_logger(__name__) @@ -221,7 +223,7 @@ def execute_worker(self, worker_input: WorkerInput) -> None: def _get_worker_input_from_broadcast( self - ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput, Dict[ + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ str, torch.Tensor]]]: """ Get the worker input from the broadcasted tensor dict. """ assert self.do_metadata_broadcast @@ -241,7 +243,7 @@ def _get_worker_input_from_broadcast( def _get_driver_input_and_broadcast( self, execute_model_req: ExecuteModelRequest - ) -> Tuple[ModelRunnerInputBase, WorkerInput, Dict[str, torch.Tensor]]: + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: """ Get the driver input and broadcast it to other workers. """ assert self.is_driver_worker @@ -266,7 +268,7 @@ def _get_driver_input_and_broadcast( def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput, Dict[ + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ str, torch.Tensor]]]: """ Prepare the inputs to ModelRunner and workers. @@ -469,8 +471,8 @@ def extract_previous_hidden_states( # When called from non-driver worker, data is dict but when called from # driver worker, data is ExecuteModelRequest. if isinstance(data, dict): - if "previous_hidden_states" in data: - output["previous_hidden_states"] = data["previous_hidden_states"] + output["previous_hidden_states"] = data.get("previous_hidden_states", + None) elif data.previous_hidden_states is not None: output["previous_hidden_states"] = data.previous_hidden_states\ .hidden_states
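Finally, the `_pythonize_sampler_output` helper introduced in `multi_step_model_runner.py` relies on a pre-allocated pinned (page-locked) CPU buffer: sampled token ids are copied device-to-host into that buffer and only then converted to Python lists. A minimal sketch of that transfer pattern, falling back to an ordinary CPU tensor when CUDA is unavailable:

```python
import torch

num_queries = 4
device = "cuda" if torch.cuda.is_available() else "cpu"

# Sampled token ids as they would come out of the sampler, one per query.
sampled_token_ids = torch.randint(0, 32000, (num_queries, 1), device=device)

# Pre-allocated staging buffer; pinned (page-locked) memory keeps the
# device-to-host copy cheap and allows it to run on a side stream.
pinned = torch.zeros((num_queries, 1),
                     dtype=torch.long,
                     pin_memory=torch.cuda.is_available())

pinned.copy_(sampled_token_ids, non_blocking=False)  # device -> host sync copy
token_lists = pinned.tolist()  # pure-CPU "pythonization"
print(token_lists)
```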